Update start guide (#4)

Ma Zerun 2023-07-05 18:26:26 +08:00 committed by GitHub
parent dcf11cf8fd
commit 5840c7655c
4 changed files with 196 additions and 31 deletions

View File

@@ -1,17 +1,14 @@
# Overview
# Installation
1. Prepare Torch refer to [PyTorch](https://pytorch.org/).
Notice that OpenCompass requires `pytorch>=1.13`.
1. Use the following commands to set up the OpenCompass environment:
```bash
conda create --name opencompass python=3.8 -y
conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
conda activate opencompass
conda install pytorch torchvision -c pytorch
```
If you want to customize the PyTorch version or related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
2. Install OpenCompass:
```bash
@@ -20,20 +17,107 @@ cd opencompass
pip install -e .
```
3. Install humaneval (option)
3. Install humaneval (Optional)
do this if you want to eval on humaneval dataset.
If you want to perform evaluations on the humaneval dataset, follow these steps.
```
git clone https://github.com/openai/human-eval.git
cd human-eval
pip install -r requirements.txt
pip install -e .
cd ..
```
Remember to remove the comments of Line48-57 and uncomment [line58](https://github.com/openai/human-eval/blob/312c5e5532f0e0470bf47f77a6243e02a61da530/human_eval/execution.py#L58) in the source code.
Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing model-generated code. If you accept these risks, uncomment **line 58** to enable the code execution evaluation.
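If you would like to inspect those lines before editing, you can print them directly from the `human-eval` checkout:
```bash
# Show the warning comments and the commented-out execution line referenced above
sed -n '45,60p' human_eval/execution.py
```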
# Quick tour
# Quick Start
In this section, we will use the example of testing LLaMA-7B on SIQA and PIQA to familiarize you with some
basic features of OpenCompass. Before running, make sure you have installed OpenCompass and have GPU computing
resources that meet the minimum requirements for LLaMA-7B.
## Prepare the Dataset
Create a `data` folder in the repository directory and place the dataset files inside it.
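A minimal sketch of the expected layout is shown below; the sub-directory names are illustrative and must match the paths expected by the dataset configurations you import:
```bash
cd opencompass
mkdir -p data
# Place the dataset files here, for example:
#   data/piqa/...
#   data/siqa/...
```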
## Prepare the Evaluation Configuration File
Create the following configuration file `configs/llama.py`:
```python
from mmengine.config import read_base

with read_base():
    # Read the required dataset configurations directly from the preset dataset configurations
    from .datasets.piqa.piqa_ppl import piqa_datasets
    from .datasets.siqa.siqa_gen import siqa_datasets

# Concatenate the datasets to be evaluated into the datasets field
datasets = [*piqa_datasets, *siqa_datasets]

# Evaluate models supported by HuggingFace's `AutoModelForCausalLM` using `HuggingFaceCausalLM`
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        # Initialization parameters for `HuggingFaceCausalLM`
        path='huggyllama/llama-7b',
        tokenizer_path='huggyllama/llama-7b',
        tokenizer_kwargs=dict(padding_side='left', truncation_side='left'),
        max_seq_len=2048,
        # Common parameters for all models, not specific to HuggingFaceCausalLM's initialization parameters
        abbr='llama-7b',            # Model abbreviation for result display
        max_out_len=100,            # Maximum number of generated tokens
        batch_size=16,
        run_cfg=dict(num_gpus=1),   # Run configuration for specifying resource requirements
    )
]
```
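Before launching, you can optionally load this file the same way `run.py` does and check that it contains what you expect. A minimal sketch, run from the repository root and assuming mmengine is available (it is installed with OpenCompass):
```python
from mmengine.config import Config

# Load configs/llama.py, resolving the read_base() imports, and print a quick summary
cfg = Config.fromfile('configs/llama.py')
print(len(cfg['datasets']), 'dataset configs loaded')  # PIQA and SIQA entries
print([m['abbr'] for m in cfg['models']])              # ['llama-7b']
```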
## Start the Evaluation
First, we can start the task in **debug mode** to check for any exceptions in model loading, dataset reading, or incorrect cache usage.
```shell
python run.py configs/llama.py -w outputs/llama --debug
```
However, in `--debug` mode, tasks are executed sequentially. After confirming that everything is correct, you
can disable the `--debug` mode to fully utilize multiple GPUs.
```shell
python run.py configs/llama.py -w outputs/llama
```
Here are some parameters related to evaluation that can help you configure more efficient inference tasks based on your environment:
- `-w outputs/llama`: Directory to save evaluation logs and results.
- `-r`: Restart the previous (interrupted) evaluation.
- `--mode all`: Specify which stage of the task to run.
  - all: Perform a complete evaluation, including inference and evaluation.
  - infer: Perform inference on each dataset.
  - eval: Perform evaluation based on the inference results.
  - viz: Display evaluation results only.
- `--max-partition-size 2000`: Dataset partition size. Some datasets may be large, and using this parameter can split them into multiple sub-tasks to efficiently utilize resources. However, if the partition is too fine, the overall speed may be slower due to longer model loading times.
- `--max-num-workers 32`: Maximum number of parallel tasks. In distributed environments such as Slurm, this parameter specifies the maximum number of submitted tasks. In a local environment, it specifies the maximum number of tasks executed in parallel. Note that the actual number of parallel tasks depends on the available GPU resources and may not be equal to this number.
If you are not evaluating on your local machine but on a Slurm cluster, you can specify the following parameters (a combined example follows the list):
- `--slurm`: Submit tasks using Slurm on the cluster.
- `--partition my_part`: Slurm cluster partition.
- `--retry 2`: Number of retries for failed tasks.
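For example, combining the flags above, a local run and a Slurm run might look like the following (the output path and `my_part` are placeholders):
```shell
# Local: resume an interrupted evaluation, splitting large datasets into sub-tasks
python run.py configs/llama.py -w outputs/llama -r --max-partition-size 2000 --max-num-workers 32

# Slurm: submit tasks to partition my_part, retrying failed tasks twice
python run.py configs/llama.py -w outputs/llama --slurm --partition my_part --max-num-workers 32 --retry 2
```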
## Obtaining Evaluation Results
After the evaluation is complete, the evaluation results table will be printed as follows:
```text
dataset    version    metric    mode    llama-7b
---------  ---------  --------  ------  ----------
piqa       1cf9f0     accuracy  ppl     77.75
siqa       e78df3     accuracy  gen     36.08
```
Additionally, the text and CSV format result files will be saved in the `summary` folder of the result directory.
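Assuming the default timestamped layout under the `-w` directory, they can be listed with:
```shell
# Each run is stored in a timestamped sub-directory of the -w path (layout assumed)
ls outputs/llama/*/summary/
```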

View File

@@ -1,17 +1,14 @@
# Overview
# Installation
1. Prepare PyTorch by referring to [PyTorch](https://pytorch.org/).
Note that OpenCompass requires `pytorch>=1.13`.
1. Use the following commands to set up the OpenCompass environment:
```bash
conda create --name opencompass python=3.8 -y
conda create --name opencompass python=3.10 pytorch torchvision pytorch-cuda -c nvidia -c pytorch -y
conda activate opencompass
conda install pytorch torchvision -c pytorch
```
If you want to customize the PyTorch version or the related CUDA version, please refer to the [official documentation](https://pytorch.org/get-started/locally/) to set up the PyTorch environment. Note that OpenCompass requires `pytorch>=1.13`.
2. Install OpenCompass:
```bash
@@ -29,9 +26,96 @@ git clone https://github.com/openai/human-eval.git
cd human-eval
pip install -r requirements.txt
pip install -e .
cd ..
```
Remember to remove the comments on lines 48-57 of the source code and uncomment [line 58](https://github.com/openai/human-eval/blob/312c5e5532f0e0470bf47f77a6243e02a61da530/human_eval/execution.py#L58).
Please read the comments in `human_eval/execution.py` **lines 48-57** to understand the potential risks of executing model-generated code. If you accept these risks, uncomment **line 58** to enable the code execution evaluation.
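If you would like to inspect those lines before editing, you can print them directly from the `human-eval` checkout:
```bash
# Show the warning comments and the commented-out execution line referenced above
sed -n '45,60p' human_eval/execution.py
```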
# Quick Start
In this section, we will use the example of testing LLaMA-7B on SIQA and PIQA to familiarize you with some basic features of OpenCompass. Before running, make sure you have installed OpenCompass and have GPU computing resources on your machine or cluster that meet the minimum requirements for LLaMA-7B.
## Prepare the Dataset
Create a `data` folder in the repository directory and place the dataset files inside it.
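A minimal sketch of the expected layout is shown below; the sub-directory names are illustrative and must match the paths expected by the dataset configurations you import:
```bash
cd opencompass
mkdir -p data
# Place the dataset files here, for example:
#   data/piqa/...
#   data/siqa/...
```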
## Prepare the Evaluation Configuration File
Create the following configuration file `configs/llama.py`:
```python
from mmengine.config import read_base

with read_base():
    # Read the required dataset configurations directly from the preset dataset configurations
    from .datasets.piqa.piqa_ppl import piqa_datasets
    from .datasets.siqa.siqa_gen import siqa_datasets

# Concatenate the datasets to be evaluated into the datasets field
datasets = [*piqa_datasets, *siqa_datasets]

# Use HuggingFaceCausalLM to evaluate models supported by HuggingFace's AutoModelForCausalLM
from opencompass.models import HuggingFaceCausalLM

models = [
    dict(
        type=HuggingFaceCausalLM,
        # Initialization parameters for HuggingFaceCausalLM
        path='huggyllama/llama-7b',
        tokenizer_path='huggyllama/llama-7b',
        tokenizer_kwargs=dict(padding_side='left', truncation_side='left'),
        max_seq_len=2048,
        # Parameters common to all model types, not HuggingFaceCausalLM initialization parameters
        abbr='llama-7b',            # Model abbreviation used for result display
        max_out_len=100,            # Maximum number of generated tokens
        batch_size=16,              # Batch size
        run_cfg=dict(num_gpus=1),   # Run configuration used to specify resource requirements
    )
]
```
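Before launching, you can optionally load this file the same way `run.py` does and check that it contains what you expect. A minimal sketch, run from the repository root and assuming mmengine is available (it is installed with OpenCompass):
```python
from mmengine.config import Config

# Load configs/llama.py, resolving the read_base() imports, and print a quick summary
cfg = Config.fromfile('configs/llama.py')
print(len(cfg['datasets']), 'dataset configs loaded')  # PIQA and SIQA entries
print([m['abbr'] for m in cfg['models']])              # ['llama-7b']
```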
## Start the Evaluation
First, we can start the task in **debug mode** to check for exceptions in model loading or dataset reading, such as a cache not being read correctly.
```shell
python run.py configs/llama.py -w outputs/llama --debug
```
In `--debug` mode, tasks can only be executed sequentially. After confirming that everything is correct, you can disable `--debug` mode to make full use of multiple GPUs.
```shell
python run.py configs/llama.py -w outputs/llama
```
Here are some parameters related to evaluation that can help you configure more efficient inference tasks based on your environment:
- `-w outputs/llama`: Directory to save evaluation logs and results.
- `-r`: Restart the previous (interrupted) evaluation.
- `--mode all`: Specify which stage of the task to run.
  - all: Perform a complete evaluation, including inference and evaluation.
  - infer: Perform inference on each dataset.
  - eval: Perform evaluation based on the inference results.
  - viz: Display evaluation results only.
- `--max-partition-size 2000`: Dataset partition size. Some datasets may be large, and this parameter splits them into multiple sub-tasks to utilize resources efficiently. However, if the partitions are too fine-grained, the overall speed may be slower due to longer model loading times.
- `--max-num-workers 32`: Maximum number of parallel tasks. In distributed environments such as Slurm, this parameter specifies the maximum number of submitted tasks; in a local environment, it specifies the maximum number of tasks executed in parallel. Note that the actual number of parallel tasks depends on resources such as available GPUs and may not equal this number.
If you are not evaluating on your local machine but on a Slurm cluster, you can specify the following parameters (a combined example follows the list):
- `--slurm`: Submit tasks to the cluster using Slurm.
- `--partition my_part`: Slurm cluster partition.
- `--retry 2`: Number of retries for failed tasks.
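For example, combining the flags above, a local run and a Slurm run might look like the following (the output path and `my_part` are placeholders):
```shell
# Local: resume an interrupted evaluation, splitting large datasets into sub-tasks
python run.py configs/llama.py -w outputs/llama -r --max-partition-size 2000 --max-num-workers 32

# Slurm: submit tasks to partition my_part, retrying failed tasks twice
python run.py configs/llama.py -w outputs/llama --slurm --partition my_part --max-num-workers 32 --retry 2
```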
## Obtaining Evaluation Results
After the evaluation is complete, the evaluation results table will be printed as follows:
```text
dataset    version    metric    mode    llama-7b
---------  ---------  --------  ------  ----------
piqa       1cf9f0     accuracy  ppl     77.75
siqa       e78df3     accuracy  gen     36.08
```
Additionally, the text and CSV format result files will be saved in the `summary` folder of the result directory.
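Assuming the default timestamped layout under the `-w` directory, they can be listed with:
```shell
# Each run is stored in a timestamped sub-directory of the -w path (layout assumed)
ls outputs/llama/*/summary/
```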

View File

@@ -61,7 +61,6 @@ class LocalRunner(BaseRunner):
gpus = np.ones(torch.cuda.device_count(), dtype=np.bool_)
pbar = tqdm(total=len(tasks))
lock = Lock()
logger = get_logger()
def submit(task, index):
task = TASKS.build(dict(type=self.task_cfg.type, cfg=task))
@@ -113,8 +112,8 @@ class LocalRunner(BaseRunner):
# Dump task config to file
mmengine.mkdir_or_exist('tmp/')
param_file = f'tmp/{os.getpid()}_{index}_params.json'
mmengine.dump(task.cfg, param_file)
param_file = f'tmp/{os.getpid()}_{index}_params.py'
task.cfg.dump(param_file)
# Build up slurm command
task_cmd_template = task.get_command_template()
@@ -127,12 +126,9 @@ class LocalRunner(BaseRunner):
logger.debug(f'Running command: {cmd}')
# Run command
if self.debug:
stdout = None
else:
out_path = task.get_log_path(file_extension='out')
mmengine.mkdir_or_exist(osp.split(out_path)[0])
stdout = open(out_path, 'w', encoding='utf-8')
out_path = task.get_log_path(file_extension='out')
mmengine.mkdir_or_exist(osp.split(out_path)[0])
stdout = open(out_path, 'w', encoding='utf-8')
result = subprocess.run(cmd,
shell=True,
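For context on the `param_file` change above: dumping the task config as a `.py` file lets worker processes reload it losslessly with `Config.fromfile`, whereas the previous JSON dump only kept JSON-serializable values. A minimal sketch with illustrative config contents and an assumed file name:
```python
import os

from mmengine.config import Config

os.makedirs('tmp', exist_ok=True)

# Illustrative stand-in for the task config built by the runner
cfg = Config(dict(task=dict(type='OpenICLInferTask'), work_dir='outputs/demo'))

cfg.dump('tmp/0_0_params.py')                    # what the updated code does
reloaded = Config.fromfile('tmp/0_0_params.py')  # a sub-process can rebuild the config from this file
assert reloaded.work_dir == 'outputs/demo'
```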

View File

@@ -88,10 +88,11 @@ class Summarizer:
dataset_eval_mode = {}
for dataset in dataset_cfgs:
inferencer = dataset.get('infer_cfg', {}).get('inferencer', {}).get('type', '')
inferencer = inferencer if isinstance(inferencer, str) else inferencer.__name__
dataset_abbr = dataset_abbr_from_cfg(dataset)
if inferencer == 'GenInferencer':
if 'GenInferencer' in inferencer:
dataset_eval_mode[dataset_abbr] = 'gen'
elif inferencer == 'PPLInferencer':
elif 'PPLInferencer' in inferencer:
dataset_eval_mode[dataset_abbr] = 'ppl'
else:
dataset_eval_mode[dataset_abbr] = 'unknown'
@@ -130,7 +131,7 @@ class Summarizer:
else:
raw_results[model_abbr][sg['name']] = {'error': 'missing datasets: {}'.format(set(sg['subsets']) - set(results.keys()))}
prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d) for d in dataset_cfgs}
prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6] for d in dataset_cfgs}
# format table
summarizer_dataset_abbrs = []
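The inferencer handling above allows the config's `type` field to be either a string or the class object, as both occur in mmengine-style configs. A self-contained illustration; the class below is a stand-in, not the real OpenCompass inferencer:
```python
class GenInferencer:  # stand-in for the real inferencer class
    pass

for type_field in ('GenInferencer', GenInferencer):
    # Normalize the `type` field to a class name string, then match by substring
    name = type_field if isinstance(type_field, str) else type_field.__name__
    mode = 'gen' if 'GenInferencer' in name else 'ppl' if 'PPLInferencer' in name else 'unknown'
    print(name, '->', mode)
```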