diff --git a/.gitignore b/.gitignore index b6befd46..4eb5978b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ configs/eval_debug*.py configs/viz_*.py data work_dirs -models +models/* configs/internal/ # Byte-compiled / optimized / DLL files __pycache__/ @@ -94,3 +94,6 @@ configs/cky/ # path of turbomind's model after runing `lmdeploy.serve.turbomind.deploy` turbomind/ + +# ignore the config file for criticbench evaluation +configs/sft_cfg/criticbench_eval/* diff --git a/configs/datasets/humaneval/humaneval_gen_6d1cc2.py b/configs/datasets/humaneval/humaneval_gen_6d1cc2.py new file mode 100644 index 00000000..9740039e --- /dev/null +++ b/configs/datasets/humaneval/humaneval_gen_6d1cc2.py @@ -0,0 +1,36 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import HumanevalDataset, HumanEvaluator, humaneval_postprocess + +humaneval_reader_cfg = dict( + input_columns=['prompt'], output_column='task_id', train_split='test') + +# TODO: allow empty output-column +humaneval_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role='HUMAN', + prompt='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nComplete the following python function.:\n{prompt}\n\n### Response:\n'), + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +humaneval_eval_cfg = dict( + evaluator=dict(type=HumanEvaluator), + pred_role='BOT', + k=[1, 10, 100], # the parameter only for humaneval + pred_postprocessor=dict(type=humaneval_postprocess), +) + +humaneval_datasets = [ + dict( + abbr='openai_humaneval', + type=HumanevalDataset, + path='./data/humaneval/human-eval-v2-20210705.jsonl', + reader_cfg=humaneval_reader_cfg, + infer_cfg=humaneval_infer_cfg, + eval_cfg=humaneval_eval_cfg) +] diff --git a/configs/datasets/mbpp/mbpp_gen_caa7ab.py b/configs/datasets/mbpp/mbpp_gen_caa7ab.py new file mode 100644 index 00000000..9c24f7ac --- /dev/null +++ b/configs/datasets/mbpp/mbpp_gen_caa7ab.py @@ -0,0 +1,65 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import MBPPDataset, MBPPEvaluator + +mbpp_reader_cfg = dict( + input_columns=['text', 'test_list'], output_column='test_list_2') + +mbpp_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role="HUMAN", + prompt= + "You are an expert Python programmer, and here is your task: Write a function to find the similar elements from the given two tuple lists. Your code should pass these tests:\n\n assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\n assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4) \n assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14) \n\nYour code should start with a [BEGIN] tag and end with a [DONE] tag.\n" + ), + dict( + role="BOT", + prompt= + "[BEGIN]\ndef similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res)\n[DONE] \n\n " + ), + dict( + role="HUMAN", + prompt= + "You are an expert Python programmer, and here is your task: Write a python function to identify non-prime numbers. 
Your code should pass these tests:\n\n assert is_not_prime(2) == False \n assert is_not_prime(10) == True \n assert is_not_prime(35) == True \n\nYour code should start with a [BEGIN] tag and end with a [DONE] tag.\n" + ), + dict( + role="BOT", + prompt= + "[BEGIN]\nimport math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result\n[DONE] \n\n " + ), + dict( + role="HUMAN", + prompt= + "You are an expert Python programmer, and here is your task: Write a function to find the largest integers from a given list of numbers using heap queue algorithm. Your code should pass these tests:\n\n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \n assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35] \n\nYour code should start with a [BEGIN] tag and end with a [DONE] tag.\n" + ), + dict( + role="BOT", + prompt= + "[BEGIN]\nimport heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums\n[DONE] \n\n " + ), + dict( + role="HUMAN", + prompt= + "You are an expert Python programmer, and here is your task: {text} Your code should pass these tests:\n\n {test_list} \n\nYour code should start with a [BEGIN] tag and end with a [DONE] tag.\n" + ), + dict(role="BOT", prompt="[BEGIN]\n"), + + ], )), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +mbpp_eval_cfg = dict(evaluator=dict(type=MBPPEvaluator), pred_role="BOT") + +mbpp_datasets = [ + dict( + type=MBPPDataset, + abbr='mbpp', + path='./data/mbpp/mbpp.jsonl', + reader_cfg=mbpp_reader_cfg, + infer_cfg=mbpp_infer_cfg, + eval_cfg=mbpp_eval_cfg) +] diff --git a/configs/models/hf_internlm/hf_internlm_chat_7b_v11.py b/configs/models/hf_internlm/hf_internlm_chat_7b_v11.py new file mode 100644 index 00000000..7471e68c --- /dev/null +++ b/configs/models/hf_internlm/hf_internlm_chat_7b_v11.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + round=[ + dict(role='HUMAN', begin='<|User|>:', end='\n'), + dict(role='BOT', begin='<|Bot|>:', end='\n', generate=True), + ], +) + +models = [ + dict( + type=HuggingFaceCausalLM, + abbr='internlm-chat-7b-v1.1-hf', + path="internlm/internlm-chat-7b-v1_1", + tokenizer_path='internlm/internlm-chat-7b-v1_1', + model_kwargs=dict( + trust_remote_code=True, + device_map='auto', + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + use_fast=False, + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + meta_template=_meta_template, + run_cfg=dict(num_gpus=1, num_procs=1), + end_str='', + ) +] diff --git a/configs/models/mistral/hf_mistral_7b_instruct_v02.py b/configs/models/mistral/hf_mistral_7b_instruct_v02.py new file mode 100644 index 00000000..0dec9321 --- /dev/null +++ b/configs/models/mistral/hf_mistral_7b_instruct_v02.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + begin="", + round=[ + dict(role="HUMAN", begin='[INST]', end='[/INST]'), + dict(role="BOT", begin="", end='', generate=True), + ], + eos_token_id=2 +) + +models = [ + dict( + abbr='mistral-7b-instruct-v0.2-hf', + type=HuggingFaceCausalLM, + path='mistralai/Mistral-7B-Instruct-v0.2', + tokenizer_path='mistralai/Mistral-7B-Instruct-v0.2', + model_kwargs=dict( + device_map='auto', + 
trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/mixtral/hf_mixtral_8x7b_instruct_v01.py b/configs/models/mixtral/hf_mixtral_8x7b_instruct_v01.py new file mode 100644 index 00000000..cdfff9f0 --- /dev/null +++ b/configs/models/mixtral/hf_mixtral_8x7b_instruct_v01.py @@ -0,0 +1,34 @@ +from opencompass.models import HuggingFaceCausalLM + + +_meta_template = dict( + begin="", + round=[ + dict(role="HUMAN", begin='[INST]', end='[/INST]'), + dict(role="BOT", begin="", end='', generate=True), + ], + eos_token_id=2 +) + +models = [ + dict( + abbr='mixtral-8x7b-instruct-v0.1', + type=HuggingFaceCausalLM, + path='mistralai/Mixtral-8x7B-Instruct-v0.1', + tokenizer_path='mistralai/Mixtral-8x7B-Instruct-v0.1', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + meta_template=_meta_template, + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/models/mixtral/hf_mixtral_8x7b_v01.py b/configs/models/mixtral/hf_mixtral_8x7b_v01.py new file mode 100644 index 00000000..e6e3c217 --- /dev/null +++ b/configs/models/mixtral/hf_mixtral_8x7b_v01.py @@ -0,0 +1,24 @@ +from opencompass.models import HuggingFaceCausalLM + + +models = [ + dict( + abbr='mixtral-8x7b-v0.1', + type=HuggingFaceCausalLM, + path='mistralai/Mixtral-8x7B-v0.1', + tokenizer_path='mistralai/Mixtral-8x7B-v0.1', + model_kwargs=dict( + device_map='auto', + trust_remote_code=True, + ), + tokenizer_kwargs=dict( + padding_side='left', + truncation_side='left', + trust_remote_code=True, + ), + max_out_len=100, + max_seq_len=2048, + batch_size=8, + run_cfg=dict(num_gpus=1, num_procs=1), + ) +] diff --git a/configs/summarizers/groups/cibench.py b/configs/summarizers/groups/cibench.py index d49f0774..ceb41914 100644 --- a/configs/summarizers/groups/cibench.py +++ b/configs/summarizers/groups/cibench.py @@ -1,4 +1,4 @@ _cibench = ['Pandas', 'Matplotlib', 'Opencv', 'SciPy', 'Seaborn', 'PyTorch'] -_cibench = ['cibench_' + i for i in _cibench] -cibench_summary_groups = [{'name': 'cibench', 'subsets': _cibench}] +_cibench = ['cibench_generation_' + i for i in _cibench] +cibench_summary_groups = [{'name': 'cibench_generation', 'subsets': _cibench}] diff --git a/configs/summarizers/groups/mathbench.py b/configs/summarizers/groups/mathbench.py index ccd9db5f..b4877ef1 100644 --- a/configs/summarizers/groups/mathbench.py +++ b/configs/summarizers/groups/mathbench.py @@ -1,5 +1,6 @@ +from copy import deepcopy -mathbench_summary_groups = [ +naive_mathbench_summary_groups = [ { 'name': 'mathbench-college', 'subsets': [ @@ -73,3 +74,15 @@ mathbench_summary_groups = [ ], } ] + +agent_mathbench_summary_groups = [] +for item in naive_mathbench_summary_groups: + item = deepcopy(item) + item['name'] = item['name'] + '-agent' + if isinstance(item['subsets'][0], str): + item['subsets'] = [i + '-agent' for i in item['subsets']] + else: + item['subsets'] = [[i[0] + '-agent', i[1]] for i in item['subsets']] + agent_mathbench_summary_groups.append(item) + +mathbench_summary_groups = naive_mathbench_summary_groups + agent_mathbench_summary_groups diff --git a/configs/summarizers/groups/plugineval.py 
b/configs/summarizers/groups/plugineval.py new file mode 100644 index 00000000..6c9b5c78 --- /dev/null +++ b/configs/summarizers/groups/plugineval.py @@ -0,0 +1,34 @@ +plugineval_summary_groups = [ + { + 'name': 'plugin_eval-instruct_v1', + 'metric': 'format_metric', + 'subsets': [ + ['plugin_eval-instruct_v1', 'string_format_metric'], + ['plugin_eval-instruct_v1', 'json_format_metric'], + ] + }, + { + 'name': 'plugin_eval-instruct_v1', + 'metric': 'args_em_metric', + 'subsets': [ + ['plugin_eval-instruct_v1', 'string_args_em_metric'], + ['plugin_eval-instruct_v1', 'json_args_em_metric'], + ] + }, + { + 'name': 'plugin_eval', + 'subsets': [ + ['plugin_eval-instruct_v1', 'format_metric'], + ['plugin_eval-instruct_v1', 'args_em_metric'], + ['plugin_eval-plan_str_v1', 'f1_score'], + ['plugin_eval-plan_json_v1', 'f1_score'], + ['plugin_eval-reason_str_v2', 'thought'], + ['plugin_eval-reason_retrieve_understand_json_v2', 'thought'], + ['plugin_eval-retrieve_str_v2', 'name'], + ['plugin_eval-reason_retrieve_understand_json_v2', 'name'], + ['plugin_eval-understand_str_v2', 'args'], + ['plugin_eval-reason_retrieve_understand_json_v2', 'args'], + ['plugin_eval-review_str_v6', 'review_quality'], + ] + }, +] diff --git a/docs/en/advanced_guides/custom_dataset.md b/docs/en/advanced_guides/custom_dataset.md new file mode 100644 index 00000000..99c2617f --- /dev/null +++ b/docs/en/advanced_guides/custom_dataset.md @@ -0,0 +1,149 @@ +# Custom Dataset Tutorial + +This tutorial is intended for temporary and informal use of datasets. If the dataset requires long-term use or has specific needs for custom reading/inference/evaluation, it is strongly recommended to implement it according to the methods described in [new_dataset.md](./new_dataset.md). + +In this tutorial, we will introduce how to test a new dataset without implementing a config or modifying the OpenCompass source code. We support two types of tasks: multiple choice (`mcq`) and question & answer (`qa`). For `mcq`, both ppl and gen inferences are supported; for `qa`, gen inference is supported. + +## Dataset Format + +We support datasets in both `.jsonl` and `.csv` formats. + +### Multiple Choice (`mcq`) + +For `mcq` datasets, the default fields are as follows: + +- `question`: The stem of the multiple-choice question. +- `A`, `B`, `C`, ...: Single uppercase letters representing the options, with no limit on the number. Defaults to parsing consecutive letters starting from `A` as options. +- `answer`: The correct answer to the multiple-choice question, which must be one of the options used above, such as `A`, `B`, `C`, etc. + +Non-default fields will be read in but are not used by default. To use them, specify them in the `.meta.json` file.
+ +An example of the `.jsonl` format: + +```jsonl +{"question": "165+833+650+615=", "A": "2258", "B": "2263", "C": "2281", "answer": "B"} +{"question": "368+959+918+653+978=", "A": "3876", "B": "3878", "C": "3880", "answer": "A"} +{"question": "776+208+589+882+571+996+515+726=", "A": "5213", "B": "5263", "C": "5383", "answer": "B"} +{"question": "803+862+815+100+409+758+262+169=", "A": "4098", "B": "4128", "C": "4178", "answer": "C"} +``` + +An example of the `.csv` format: + +```csv +question,A,B,C,answer +127+545+588+620+556+199=,2632,2635,2645,B +735+603+102+335+605=,2376,2380,2410,B +506+346+920+451+910+142+659+850=,4766,4774,4784,C +504+811+870+445=,2615,2630,2750,B +``` + +### Question & Answer (`qa`) + +For `qa` datasets, the default fields are as follows: + +- `question`: The stem of the question & answer question. +- `answer`: The correct answer to the question & answer question. It can be missing, indicating the dataset has no correct answer. + +Non-default fields will be read in but are not used by default. To use them, specify in the `.meta.json` file. + +An example of the `.jsonl` format: + +```jsonl +{"question": "752+361+181+933+235+986=", "answer": "3448"} +{"question": "712+165+223+711=", "answer": "1811"} +{"question": "921+975+888+539=", "answer": "3323"} +{"question": "752+321+388+643+568+982+468+397=", "answer": "4519"} +``` + +An example of the `.csv` format: + +```csv +question,answer +123+147+874+850+915+163+291+604=,3967 +149+646+241+898+822+386=,3142 +332+424+582+962+735+798+653+214=,4700 +649+215+412+495+220+738+989+452=,4170 +``` + +## Command Line List + +Custom datasets can be directly called for evaluation through the command line. + +```bash +python run.py \ + --models hf_llama2_7b \ + --custom-dataset-path xxx/test_mcq.csv \ + --custom-dataset-data-type mcq \ + --custom-dataset-infer-method ppl +``` + +```bash +python run.py \ + --models hf_llama2_7b \ + --custom-dataset-path xxx/test_qa.jsonl \ + --custom-dataset-data-type qa \ + --custom-dataset-infer-method gen +``` + +In most cases, `--custom-dataset-data-type` and `--custom-dataset-infer-method` can be omitted. OpenCompass will + +set them based on the following logic: + +- If options like `A`, `B`, `C`, etc., can be parsed from the dataset file, it is considered an `mcq` dataset; otherwise, it is considered a `qa` dataset. +- The default `infer_method` is `gen`. + +## Configuration File + +In the original configuration file, simply add a new item to the `datasets` variable. Custom datasets can be mixed with regular datasets. + +```python +datasets = [ + {"path": "xxx/test_mcq.csv", "data_type": "mcq", "infer_method": "ppl"}, + {"path": "xxx/test_qa.jsonl", "data_type": "qa", "infer_method": "gen"}, +] +``` + +## Supplemental Information for Dataset `.meta.json` + +OpenCompass will try to parse the input dataset file by default, so in most cases, the `.meta.json` file is **not necessary**. However, if the dataset field names are not the default ones, or custom prompt words are required, it should be specified in the `.meta.json` file. + +The file is placed in the same directory as the dataset, with the filename followed by `.meta.json`. An example file structure is as follows: + +```tree +. +├── test_mcq.csv +├── test_mcq.csv.meta.json +├── test_qa.jsonl +└── test_qa.jsonl.meta.json +``` + +Possible fields in this file include: + +- `abbr` (str): Abbreviation of the dataset, serving as its ID. +- `data_type` (str): Type of dataset, options are `mcq` and `qa`. 
+- `infer_method` (str): Inference method, options are `ppl` and `gen`. +- `human_prompt` (str): User prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. If `template` exists, this field will be ignored. +- `bot_prompt` (str): Bot prompt template for generating prompts. Variables in the template are enclosed in `{}`, like `{answer}`, etc. If `template` exists, this field will be ignored. +- `template` (str or dict): Question template for generating prompts. Variables in the template are enclosed in `{}`, like `{question}`, `{opt1}`, etc. The relevant syntax is in [here](../prompt/prompt_template.md) regarding `infer_cfg['prompt_template']['template']`. +- `input_columns` (list): List of input fields for reading data. +- `output_column` (str): Output field for reading data. +- `options` (list): List of options for reading data, valid only when `data_type` is `mcq`. + +For example: + +```json +{ + "human_prompt": "Question: 127 + 545 + 588 + 620 + 556 + 199 =\nA. 2632\nB. 2635\nC. 2645\nAnswer: Let's think step by step, 127 + 545 + 588 + 620 + 556 + 199 = 672 + 588 + 620 + 556 + 199 = 1260 + 620 + 556 + 199 = 1880 + 556 + 199 = 2436 + 199 = 2635. So the answer is B.\nQuestion: {question}\nA. {A}\nB. {B}\nC. {C}\nAnswer: ", + "bot_prompt": "{answer}" +} +``` + +or + +```json +{ + "template": "Question: {my_question}\nX. {X}\nY. {Y}\nZ. {Z}\nW. {W}\nAnswer:", + "input_columns": ["my_question", "X", "Y", "Z", "W"], + "output_column": "my_answer", +} +``` diff --git a/docs/en/index.rst b/docs/en/index.rst index 20f01129..fe3792a4 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -58,6 +58,7 @@ We always welcome *PRs* and *Issues* for the betterment of OpenCompass. :caption: Advanced Guides advanced_guides/new_dataset.md + advanced_guides/custom_dataset.md advanced_guides/new_model.md advanced_guides/evaluation_turbomind.md advanced_guides/evaluation_lightllm.md diff --git a/docs/zh_cn/advanced_guides/custom_dataset.md b/docs/zh_cn/advanced_guides/custom_dataset.md new file mode 100644 index 00000000..2c96486a --- /dev/null +++ b/docs/zh_cn/advanced_guides/custom_dataset.md @@ -0,0 +1,147 @@ +# 自定义数据集 + +本教程仅供临时性的、非正式的数据集使用,如果所用数据集需要长期使用,或者存在定制化读取 / 推理 / 评测需求的,强烈建议按照 [new_dataset.md](./new_dataset.md) 中介绍的方法进行实现。 + +在本教程中,我们将会介绍如何在不实现 config,不修改 OpenCompass 源码的情况下,对一新增数据集进行测试的方法。我们支持的任务类型包括选择 (`mcq`) 和问答 (`qa`) 两种,其中 `mcq` 支持 `ppl` 推理和 `gen` 推理;`qa` 支持 `gen` 推理。 + +## 数据集格式 + +我们支持 `.jsonl` 和 `.csv` 两种格式的数据集。 + +### 选择题 (`mcq`) + +对于选择 (`mcq`) 类型的数据,默认的字段如下: + +- `question`: 表示选择题的题干 +- `A`, `B`, `C`, ...: 使用单个大写字母表示选项,个数不限定。默认只会从 `A` 开始,解析连续的字母作为选项。 +- `answer`: 表示选择题的正确答案,其值必须是上述所选用的选项之一,如 `A`, `B`, `C` 等。 + +对于非默认字段,我们都会进行读入,但默认不会使用。如需使用,则需要在 `.meta.json` 文件中进行指定。 + +`.jsonl` 格式样例如下: + +```jsonl +{"question": "165+833+650+615=", "A": "2258", "B": "2263", "C": "2281", "answer": "B"} +{"question": "368+959+918+653+978=", "A": "3876", "B": "3878", "C": "3880", "answer": "A"} +{"question": "776+208+589+882+571+996+515+726=", "A": "5213", "B": "5263", "C": "5383", "answer": "B"} +{"question": "803+862+815+100+409+758+262+169=", "A": "4098", "B": "4128", "C": "4178", "answer": "C"} +``` + +`.csv` 格式样例如下: + +```csv +question,A,B,C,answer +127+545+588+620+556+199=,2632,2635,2645,B +735+603+102+335+605=,2376,2380,2410,B +506+346+920+451+910+142+659+850=,4766,4774,4784,C +504+811+870+445=,2615,2630,2750,B +``` + +### 问答题 (`qa`) + +对于问答 (`qa`) 类型的数据,默认的字段如下: + +- `question`: 表示问答题的题干 +- `answer`: 表示问答题的正确答案。可缺失,表示该数据集无正确答案。 + 
+对于非默认字段,我们都会进行读入,但默认不会使用。如需使用,则需要在 `.meta.json` 文件中进行指定。 + +`.jsonl` 格式样例如下: + +```jsonl +{"question": "752+361+181+933+235+986=", "answer": "3448"} +{"question": "712+165+223+711=", "answer": "1811"} +{"question": "921+975+888+539=", "answer": "3323"} +{"question": "752+321+388+643+568+982+468+397=", "answer": "4519"} +``` + +`.csv` 格式样例如下: + +```csv +question,answer +123+147+874+850+915+163+291+604=,3967 +149+646+241+898+822+386=,3142 +332+424+582+962+735+798+653+214=,4700 +649+215+412+495+220+738+989+452=,4170 +``` + +## 命令行列表 + +自定义数据集可直接通过命令行来调用开始评测。 + +```bash +python run.py \ + --models hf_llama2_7b \ + --custom-dataset-path xxx/test_mcq.csv \ + --custom-dataset-data-type mcq \ + --custom-dataset-infer-method ppl +``` + +```bash +python run.py \ + --models hf_llama2_7b \ + --custom-dataset-path xxx/test_qa.jsonl \ + --custom-dataset-data-type qa \ + --custom-dataset-infer-method gen +``` + +在绝大多数情况下,`--custom-dataset-data-type` 和 `--custom-dataset-infer-method` 可以省略,OpenCompass 会根据以下逻辑进行设置: + +- 如果从数据集文件中可以解析出选项,如 `A`, `B`, `C` 等,则认定该数据集为 `mcq`,否则认定为 `qa`。 +- 默认 `infer_method` 为 `gen`。 + +## 配置文件 + +在原配置文件中,直接向 `datasets` 变量中添加新的项即可即可。自定义数据集亦可与普通数据集混用。 + +```python +datasets = [ + {"path": "xxx/test_mcq.csv", "data_type": "mcq", "infer_method": "ppl"}, + {"path": "xxx/test_qa.jsonl", "data_type": "qa", "infer_method": "gen"}, +] +``` + +## 数据集补充信息 `.meta.json` + +OpenCompass 会默认尝试对输入的数据集文件进行解析,因此在绝大多数情况下,`.meta.json` 文件都是 **不需要** 的。但是,如果数据集的字段名不是默认的字段名,或者需要自定义提示词,则需要在 `.meta.json` 文件中进行指定。 + +我们会在数据集同级目录下,以文件名+`.meta.json` 的形式放置一个表征数据集使用方法的文件,样例文件结构如下: + +```tree +. +├── test_mcq.csv +├── test_mcq.csv.meta.json +├── test_qa.jsonl +└── test_qa.jsonl.meta.json +``` + +该文件可能字段如下: + +- `abbr` (str): 数据集缩写,作为该数据集的 ID。 +- `data_type` (str): 数据集类型,可选值为 `mcq` 和 `qa`. +- `infer_method` (str): 推理方法,可选值为 `ppl` 和 `gen`. +- `human_prompt` (str): 用户提示词模板,用于生成提示词。模板中的变量使用 `{}` 包裹,如 `{question}`,`{opt1}` 等。如存在 `template`,则该字段会被忽略。 +- `bot_prompt` (str): 机器人提示词模板,用于生成提示词。模板中的变量使用 `{}` 包裹,如 `{answer}` 等。如存在 `template`,则该字段会被忽略。 +- `template` (str or dict): 问题模板,用于生成提示词。模板中的变量使用 `{}` 包裹,如 `{question}`,`{opt1}` 等。相关语法见[此处](../prompt/prompt_template.md) 关于 `infer_cfg['prompt_template']['template']` 的内容。 +- `input_columns` (list): 输入字段列表,用于读入数据。 +- `output_column` (str): 输出字段,用于读入数据。 +- `options` (list): 选项列表,用于读入数据,仅在 `data_type` 为 `mcq` 时有效。 + +样例如下: + +```json +{ + "human_prompt": "Question: 127 + 545 + 588 + 620 + 556 + 199 =\nA. 2632\nB. 2635\nC. 2645\nAnswer: Let's think step by step, 127 + 545 + 588 + 620 + 556 + 199 = 672 + 588 + 620 + 556 + 199 = 1260 + 620 + 556 + 199 = 1880 + 556 + 199 = 2436 + 199 = 2635. So the answer is B.\nQuestion: {question}\nA. {A}\nB. {B}\nC. {C}\nAnswer: ", + "bot_prompt": "{answer}" +} +``` + +或者 + +```json +{ + "template": "Question: {my_question}\nX. {X}\nY. {Y}\nZ. {Z}\nW. 
{W}\nAnswer:", + "input_columns": ["my_question", "X", "Y", "Z", "W"], + "output_column": "my_answer", +} +``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 96beb240..49425830 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -58,6 +58,7 @@ OpenCompass 上手路线 :caption: 进阶教程 advanced_guides/new_dataset.md + advanced_guides/custom_dataset.md advanced_guides/new_model.md advanced_guides/evaluation_turbomind.md advanced_guides/evaluation_lightllm.md diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index 43d9dfdf..e2164ae8 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -27,6 +27,7 @@ from .copa import * # noqa: F401, F403 from .crowspairs import * # noqa: F401, F403 from .crowspairs_cn import * # noqa: F401, F403 from .csl import * # noqa: F401, F403 +from .custom import * # noqa: F401, F403 from .cvalues import * # noqa: F401, F403 from .drcd import * # noqa: F401, F403 from .drop import * # noqa: F401, F403 diff --git a/opencompass/datasets/cibench.py b/opencompass/datasets/cibench.py index 05b16387..e6f121f5 100644 --- a/opencompass/datasets/cibench.py +++ b/opencompass/datasets/cibench.py @@ -346,6 +346,8 @@ class CIBenchEvaluator(BaseEvaluator): def score(self, predictions: List, references: List, steps: List, origin_prompt: List): """Calculate accuracy.""" + if len(steps) != len(references): + return {'error': 'steps and refrs have different length'} cwd = os.getcwd() self.get_output_dir() if self.output_dir: diff --git a/opencompass/datasets/custom.py b/opencompass/datasets/custom.py new file mode 100644 index 00000000..e37bf6bc --- /dev/null +++ b/opencompass/datasets/custom.py @@ -0,0 +1,245 @@ +import csv +import json +import os + +from datasets import Dataset + +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.registry import LOAD_DATASET +from opencompass.utils.text_postprocessors import first_option_postprocess + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class CustomDataset(BaseDataset): + + @staticmethod + def load(path): + if path.endswith('.jsonl'): + with open(path, 'r', encoding='utf-8') as f: + data = [json.loads(line) for line in f] + elif path.endswith('.csv'): + with open(path, 'r', encoding='utf-8') as f: + reader = csv.reader(f) + header = next(reader) + data = [dict(zip(header, row)) for row in reader] + else: + raise ValueError(f'Unsupported file format: {path}') + + return Dataset.from_list(data) + + +def stringfy_types(obj): + for k, v in obj.items(): + if k == 'type': + obj[k] = f'{v.__module__}.{v.__name__}' + elif isinstance(v, dict): + stringfy_types(v) + return obj + + +def make_mcq_gen_config(meta): + if meta.get('template', None) is None: + _human_prompt = 'Question: {question}' + ''.join( + [f'\n{item}. 
{{{item}}}' for item in meta['options']]) + human_prompt = meta.get('human_prompt', _human_prompt) + _bot_prompt = f'Answer: {{{meta["output_column"]}}}' + bot_prompt = meta.get('bot_prompt', _bot_prompt) + template = dict(round=[ + dict(role='HUMAN', prompt=human_prompt), + dict(role='BOT', prompt=bot_prompt), + ]) + else: + template = meta['template'] + + reader_cfg = dict( + input_columns=meta['input_columns'], + output_column=meta['output_column'], + ) + infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=template, + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), + ) + + eval_cfg = dict(evaluator=dict(type=AccEvaluator), + pred_role='BOT', + pred_postprocessor=dict( + type=first_option_postprocess, + options=''.join(meta['options']), + )) + + dataset = dict( + abbr=meta['abbr'], + type=CustomDataset, + path=meta['path'], + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + return dataset + + +def make_qa_gen_config(meta): + if meta.get('template', None) is None: + human_prompt = meta.get('human_prompt', '{question}') + if meta['output_column'] is None: + template = dict(round=[ + dict(role='HUMAN', prompt=human_prompt), + ]) + else: + bot_prompt = meta.get('bot_prompt', f'{{{meta["output_column"]}}}') + template = dict(round=[ + dict(role='HUMAN', prompt=human_prompt), + dict(role='BOT', prompt=bot_prompt), + ]) + else: + template = meta['template'] + + reader_cfg = dict( + input_columns=meta['input_columns'], + output_column=meta['output_column'], + ) + infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=template, + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), + ) + + eval_cfg = dict( + evaluator=dict(type=AccEvaluator), + pred_role='BOT', + ) + + dataset = dict( + abbr=meta['abbr'], + type=CustomDataset, + path=meta['path'], + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + return dataset + + +def make_mcq_ppl_config(meta): + if meta.get('template', None) is None: + _human_prompt = 'Question: {question}' + ''.join( + [f'\n{item}. 
{{{item}}}' for item in meta['options']]) + human_prompt = meta.get('human_prompt', _human_prompt) + _bot_prompt = f'Answer: {{{meta["output_column"]}}}' + bot_prompt = meta.get('bot_prompt', _bot_prompt) + template = { + answer: dict(round=[ + dict(role='HUMAN', prompt=human_prompt), + dict(role='BOT', + prompt=bot_prompt.format( + **{meta['output_column']: answer})), + ], ) + for answer in meta['options'] + } + else: + template = meta['template'] + + reader_cfg = dict( + input_columns=meta['input_columns'], + output_column=meta['output_column'], + ) + infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=template, + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=PPLInferencer), + ) + + eval_cfg = dict(evaluator=dict(type=AccEvaluator)) + + dataset = dict( + abbr=meta['abbr'], + type=CustomDataset, + path=meta['path'], + reader_cfg=reader_cfg, + infer_cfg=infer_cfg, + eval_cfg=eval_cfg, + ) + return dataset + + +def parse_example_dataset(config): + # try to read meta json + path = config['path'] + meta_path = config.get('meta_path', path + '.meta.json') + if os.path.exists(meta_path): + with open(meta_path, 'r', encoding='utf-8') as f: + meta = json.load(f) + else: + meta = {} + + # load sample + if path.endswith('.jsonl'): + with open(path, 'r', encoding='utf-8') as f: + data_item = json.loads(f.readline()) + elif path.endswith('.csv'): + with open(path, 'r', encoding='utf-8') as f: + reader = csv.reader(f) + header = next(reader) + row = next(reader) + data_item = dict(zip(header, row)) + else: + raise ValueError(f'Unsupported ext: {path}, .jsonl or .csv required') + + meta['path'] = path + input_columns = [i for i in data_item.keys() if i != 'answer'] + meta.setdefault('input_columns', input_columns) + output_column = 'answer' if 'answer' in data_item else None + meta.setdefault('output_column', output_column) + options = [] + for i in range(26): + i = chr(ord('A') + i) + if i in data_item: + options.append(i) + else: + break + meta.setdefault('options', options) + abbr = os.path.basename(path).split('.')[0] + meta.setdefault('abbr', abbr) + + if 'data_type' in config: + meta.setdefault('data_type', config['data_type']) + else: + data_type = 'mcq' if len(options) > 1 else 'qa' + meta.setdefault('data_type', data_type) + if 'infer_method' in config: + meta.setdefault('infer_method', config['infer_method']) + else: + meta.setdefault('infer_method', 'gen') + + return meta + + +def make_custom_dataset_config(config): + # considered as a custom dataset + meta = parse_example_dataset(config) + make_config_func = { + ('mcq', 'gen'): make_mcq_gen_config, + ('mcq', 'ppl'): make_mcq_ppl_config, + ('qa', 'gen'): make_qa_gen_config, + }.get((meta['data_type'], meta['infer_method']), None) + if make_config_func is None: + raise ValueError(f'Unsupported dataset data_type: {meta["data_type"]}' + f' and infer_method: {meta["infer_method"]}') + dataset = make_config_func(meta) + dataset = stringfy_types(dataset) + return dataset diff --git a/opencompass/datasets/gsm8k.py b/opencompass/datasets/gsm8k.py index 1b10ea16..88e3fa2f 100644 --- a/opencompass/datasets/gsm8k.py +++ b/opencompass/datasets/gsm8k.py @@ -113,6 +113,8 @@ class Gsm8kAgentEvaluator(BaseEvaluator): def score(self, predictions, references, steps): """Calculate accuracy.""" + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} row_reasoning_scope = 0 action_scope = 0 diff --git a/opencompass/datasets/math.py b/opencompass/datasets/math.py index 
257a7772..faf8910a 100644 --- a/opencompass/datasets/math.py +++ b/opencompass/datasets/math.py @@ -340,6 +340,8 @@ class MATHAgentEvaluator(MATHEvaluator): def score(self, predictions, references, steps): """Calculate accuracy.""" + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} row_reasoning_scope = 0 action_scope = 0 diff --git a/opencompass/openicl/icl_evaluator/icl_circular_evaluator.py b/opencompass/openicl/icl_evaluator/icl_circular_evaluator.py index fbc221d6..95b6b847 100644 --- a/opencompass/openicl/icl_evaluator/icl_circular_evaluator.py +++ b/opencompass/openicl/icl_evaluator/icl_circular_evaluator.py @@ -24,6 +24,8 @@ class CircularEvaluator(BaseEvaluator): Returns: dict: A dict of evaluation results. """ + if len(predictions) != len(references): + return {'error': 'preds and refrs have different length'} self._metrics = {} self._metrics.update({'acc_4': 0, 'acc_1': 0}) diff --git a/opencompass/summarizers/default.py b/opencompass/summarizers/default.py index 62460665..516d5467 100644 --- a/opencompass/summarizers/default.py +++ b/opencompass/summarizers/default.py @@ -19,6 +19,13 @@ from opencompass.utils.prompt import get_prompt_hash METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth'] METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len'] +def model_abbr_from_cfg_used_in_summarizer(model): + if model.get('summarizer_abbr', None): + return model['summarizer_abbr'] + else: + return model_abbr_from_cfg(model) + + class DefaultSummarizer: """Default summarizer in OpenCompass. @@ -49,7 +56,13 @@ class DefaultSummarizer: self.model_cfgs = self.cfg['models'] self.dataset_cfgs = self.cfg['datasets'] self.work_dir = self.cfg['work_dir'] - self.model_abbrs = [model_abbr_from_cfg(model) for model in self.model_cfgs] + model_abbrs = [] + for model in self.model_cfgs: + model_abbr = model_abbr_from_cfg_used_in_summarizer(model) + if model_abbr in model_abbrs: + continue + model_abbrs.append(model_abbr) + self.model_abbrs = model_abbrs def _pick_up_results(self): """The function reads the numerical results of evaluations from the @@ -71,9 +84,9 @@ class DefaultSummarizer: dataset_metrics : Dict[str, List[str]] = {} for model in self.model_cfgs: - model_abbr = model_abbr_from_cfg(model) - parsed_results[model_abbr] = {} - raw_results[model_abbr] = {} + model_abbr = model_abbr_from_cfg_used_in_summarizer(model) + parsed_results.setdefault(model_abbr, {}) + raw_results.setdefault(model_abbr, {}) for dataset in self.dataset_cfgs: dataset_abbr = dataset_abbr_from_cfg(dataset) filepath = get_infer_output_path(model, dataset, osp.join(self.work_dir, 'results')) @@ -165,23 +178,23 @@ class DefaultSummarizer: if all(isinstance(dataset_abbr, (list, tuple)) for dataset_abbr in sg['subsets']): group_metrics = [default_metric] for dataset_abbr, metric in sg['subsets']: - scores.setdefault(default_metric, {})[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][metric] + scores.setdefault(default_metric, {})[dataset_abbr + '@' + metric] = parsed_results[model_abbr][dataset_abbr][metric] eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown')) else: group_metrics = list(functools.reduce(lambda a, b: a & b, [set(dataset_metrics[dataset_abbr]) for dataset_abbr in sg['subsets']])) if need_smart_metric and len(group_metrics) > 1: for metric in group_metrics: for dataset_abbr in sg['subsets']: - scores.setdefault(metric, {})[dataset_abbr] = 
parsed_results[model_abbr][dataset_abbr][metric] + scores.setdefault(metric, {})[dataset_abbr + '@' + metric] = parsed_results[model_abbr][dataset_abbr][metric] eval_modes.append(dataset_eval_mode.get(sg['subsets'][0], 'unknown')) else: group_metrics = [default_metric] for dataset_abbr in sg['subsets']: metric = dataset_metrics[dataset_abbr][0] - scores.setdefault(default_metric, {})[dataset_abbr] = parsed_results[model_abbr][dataset_abbr][metric] + scores.setdefault(default_metric, {})[dataset_abbr + '@' + metric] = parsed_results[model_abbr][dataset_abbr][metric] eval_modes.append(dataset_eval_mode.get(dataset_abbr, 'unknown')) - result = parsed_results[model_abbr].get(sg['name'], {}) + result = {} for metric in scores: if default_metric == 'standard_deviation': avg = sum(scores[metric].values()) / len(scores[metric]) @@ -190,7 +203,11 @@ class DefaultSummarizer: else: if sg.get('weights', []): # check sg['weights'][k] != 0 in case of scores[metric][k] is NaN - numerator = sum(scores[metric][k] * sg['weights'][k] for k in sg['weights'] if sg['weights'][k] != 0) + try: + numerator = sum(scores[metric][k] * sg['weights'][k] for k in sg['weights'] if sg['weights'][k] != 0) + except KeyError: + tmp_scores = {metric: {k.split('@')[0]: v for k, v in scores[metric].items()} for metric in scores} + numerator = sum(tmp_scores[metric][k] * sg['weights'][k] for k in sg['weights'] if sg['weights'][k] != 0) denominator = sum(sg['weights'].values()) else: numerator = sum(scores[metric].values()) @@ -200,9 +217,9 @@ class DefaultSummarizer: eval_mode = eval_modes[0] if len(eval_modes) == 1 else 'mixed' # add to global results - raw_results[model_abbr][sg['name']] = scores - parsed_results[model_abbr][sg['name']] = result - dataset_metrics[sg['name']] = group_metrics + raw_results[model_abbr].setdefault(sg['name'], {}).update(scores) + parsed_results[model_abbr].setdefault(sg['name'], {}).update(result) + dataset_metrics.setdefault(sg['name'], []).extend(group_metrics) dataset_eval_mode[sg['name']] = eval_mode return raw_results, parsed_results, dataset_metrics, dataset_eval_mode diff --git a/opencompass/tasks/openicl_eval.py b/opencompass/tasks/openicl_eval.py index 354154fb..59495917 100644 --- a/opencompass/tasks/openicl_eval.py +++ b/opencompass/tasks/openicl_eval.py @@ -198,7 +198,8 @@ class OpenICLEvalTask(BaseTask): 'incorrect_bpb'] = self.calculate_bpb(pred_dicts) else: result['incorrect_bpb'] = result['correct_bpb'] = -1 - except Exception: + except Exception as e: + self.logger.warning(f'Skip dumping details due to: {e}.') result['incorrect_bpb'] = result['correct_bpb'] = -1 else: result.pop('details', None) @@ -288,13 +289,19 @@ class OpenICLEvalTask(BaseTask): result['predictions'] = str(predictions[i]) result['references'] = str(references[i]) result['correct'] = str(predictions[i]) == str(references[i]) - else: + elif details is not None: results['type'] = 'GEN' result['prompt'] = origin_prediction['origin_prompt'] result['origin_prediction'] = pred_dicts[i]['prediction'] result['predictions'] = details[i]['pred'] result['references'] = details[i]['answer'] result['correct'] = details[i]['correct'] + else: + results['type'] = 'GEN' + result['prompt'] = origin_prediction['origin_prompt'] + result['origin_prediction'] = pred_dicts[i]['prediction'] + result['predictions'] = str(predictions[i]) + result['references'] = str(references[i]) results[str(i)] = result return results diff --git a/opencompass/utils/build.py b/opencompass/utils/build.py index b27a133d..240324e3 100644 --- 
a/opencompass/utils/build.py +++ b/opencompass/utils/build.py @@ -19,5 +19,6 @@ def build_model_from_cfg(model_cfg: ConfigDict): model_cfg.pop('max_out_len', None) model_cfg.pop('batch_size', None) model_cfg.pop('abbr', None) + model_cfg.pop('summarizer_abbr', None) model_cfg.pop('pred_postprocessor', None) return MODELS.build(model_cfg) diff --git a/opencompass/utils/run.py b/opencompass/utils/run.py index 0c5b1c3f..e1babda8 100644 --- a/opencompass/utils/run.py +++ b/opencompass/utils/run.py @@ -4,6 +4,7 @@ from typing import List, Union import tabulate from mmengine.config import Config +from opencompass.datasets.custom import make_custom_dataset_config from opencompass.partitioners import NaivePartitioner, SizePartitioner from opencompass.runners import DLCRunner, LocalRunner, SlurmRunner from opencompass.tasks import OpenICLEvalTask, OpenICLInferTask @@ -56,18 +57,37 @@ def get_config_from_arg(args) -> Config: 3. Huggingface parameter groups and args.datasets """ if args.config: - return Config.fromfile(args.config, format_python_code=False) - if args.datasets is None: - raise ValueError('You must specify "--datasets" if you do not specify ' - 'a config file path.') + config = Config.fromfile(args.config, format_python_code=False) + for i, dataset in enumerate(config['datasets']): + if 'type' not in dataset: + config['datasets'][i] = make_custom_dataset_config(dataset) + return config + # parse dataset args + if not args.datasets and not args.custom_dataset_path: + raise ValueError('You must specify "--datasets" or ' + '"--custom-dataset-path" if you do not specify a ' + 'config file path.') datasets = [] - datasets_dir = os.path.join(args.config_dir, 'datasets') - for dataset in match_cfg_file(datasets_dir, args.datasets): - get_logger().info(f'Loading {dataset[0]}: {dataset[1]}') - cfg = Config.fromfile(dataset[1]) - for k in cfg.keys(): - if k.endswith('_datasets'): - datasets += cfg[k] + if args.datasets: + datasets_dir = os.path.join(args.config_dir, 'datasets') + for dataset in match_cfg_file(datasets_dir, args.datasets): + get_logger().info(f'Loading {dataset[0]}: {dataset[1]}') + cfg = Config.fromfile(dataset[1]) + for k in cfg.keys(): + if k.endswith('_datasets'): + datasets += cfg[k] + else: + dataset = {'path': args.custom_dataset_path} + if args.custom_dataset_infer_method is not None: + dataset['infer_method'] = args.custom_dataset_infer_method + if args.custom_dataset_data_type is not None: + dataset['data_type'] = args.custom_dataset_data_type + if args.custom_dataset_meta_path is not None: + dataset['meta_path'] = args.custom_dataset_meta_path + dataset = make_custom_dataset_config(dataset) + datasets.append(dataset) + + # parse model args if not args.models and not args.hf_path: raise ValueError('You must specify a config file path, ' 'or specify --models and --datasets, or ' @@ -98,7 +118,7 @@ def get_config_from_arg(args) -> Config: pad_token_id=args.pad_token_id, run_cfg=dict(num_gpus=args.num_gpus)) models.append(model) - + # parse summarizer args summarizer = args.summarizer if args.summarizer is not None else 'example' summarizers_dir = os.path.join(args.config_dir, 'summarizers') s = match_cfg_file(summarizers_dir, [summarizer])[0] diff --git a/run.py b/run.py index fd323c58..a991fab4 100644 --- a/run.py +++ b/run.py @@ -138,6 +138,9 @@ def parse_args(): # set hf args hf_parser = parser.add_argument_group('hf_args') parse_hf_args(hf_parser) + # set custom dataset args + custom_dataset_parser = parser.add_argument_group('custom_dataset_args') + 
parse_custom_dataset_args(custom_dataset_parser) args = parser.parse_args() if args.slurm: assert args.partition is not None, ( @@ -199,6 +202,18 @@ def parse_hf_args(hf_parser): hf_parser.add_argument('--pad-token-id', type=int) +def parse_custom_dataset_args(custom_dataset_parser): + """These args are all for the quick construction of custom datasets.""" + custom_dataset_parser.add_argument('--custom-dataset-path', type=str) + custom_dataset_parser.add_argument('--custom-dataset-meta-path', type=str) + custom_dataset_parser.add_argument('--custom-dataset-data-type', + type=str, + choices=['mcq', 'qa']) + custom_dataset_parser.add_argument('--custom-dataset-infer-method', + type=str, + choices=['gen', 'ppl']) + + def main(): args = parse_args() if args.dry_run:
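
The summarizer and build changes above read an optional `summarizer_abbr` key from each model config: `DefaultSummarizer` reports results under that name (and deduplicates models that share it), while `build_model_from_cfg` pops the key so it never reaches the model constructor. Below is a minimal sketch of a model entry using it; the `abbr`, `summarizer_abbr`, and path values are hypothetical and not part of this patch.

```python
from opencompass.models import HuggingFaceCausalLM

# Hypothetical example: two runs of the same checkpoint could keep distinct
# `abbr` values (which still identify the per-run prediction/result files)
# while sharing one `summarizer_abbr`, so their scores appear under a single
# name in the summary table.
models = [
    dict(
        type=HuggingFaceCausalLM,
        abbr='internlm-chat-7b-v1.1-hf-run1',        # per-run ID (hypothetical)
        summarizer_abbr='internlm-chat-7b-v1.1-hf',  # name used by the summarizer
        path='internlm/internlm-chat-7b-v1_1',
        tokenizer_path='internlm/internlm-chat-7b-v1_1',
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1, num_procs=1),
    ),
]
```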