diff --git a/configs/datasets/cmb/cmb_gen.py b/configs/datasets/cmb/cmb_gen.py index 5d379ad1..39279d8c 100644 --- a/configs/datasets/cmb/cmb_gen.py +++ b/configs/datasets/cmb/cmb_gen.py @@ -1,4 +1,4 @@ from mmengine.config import read_base with read_base(): - from .cmb_gen_72cbb7 import cmb_datasets # noqa: F401, F403 + from .cmb_gen_dfb5c4 import cmb_datasets # noqa: F401, F403 diff --git a/configs/datasets/cmb/cmb_gen_72cbb7.py b/configs/datasets/cmb/cmb_gen_72cbb7.py deleted file mode 100644 index 48729b9f..00000000 --- a/configs/datasets/cmb/cmb_gen_72cbb7.py +++ /dev/null @@ -1,43 +0,0 @@ -from opencompass.openicl.icl_prompt_template import PromptTemplate -from opencompass.openicl.icl_retriever import FixKRetriever -from opencompass.openicl.icl_inferencer import GenInferencer -from opencompass.datasets import CMBDataset - - -cmb_datasets = [] - -cmb_reader_cfg = dict( - input_columns=["exam_type", "exam_class", "question_type", "question", "option_str"], - output_column=None, - train_split="val", - test_split="test" -) - -cmb_infer_cfg = dict( - ice_template=dict( - type=PromptTemplate, - template=dict( - begin="", - round=[ - dict( - role="HUMAN", - prompt=f"以下是中国{{exam_type}}中{{exam_class}}考试的一道{{question_type}},不需要做任何分析和解释,直接输出答案选项。\n{{question}}\n{{option_str}} \n 答案: ", - ), - dict(role="BOT", prompt="{answer}"), - ], - ), - ice_token="", - ), - retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), - inferencer=dict(type=GenInferencer), -) - -cmb_datasets.append( - dict( - type=CMBDataset, - path="./data/CMB/", - abbr="cmb", - reader_cfg=cmb_reader_cfg, - infer_cfg=cmb_infer_cfg - ) -) \ No newline at end of file diff --git a/configs/datasets/cmb/cmb_gen_dfb5c4.py b/configs/datasets/cmb/cmb_gen_dfb5c4.py new file mode 100644 index 00000000..2547010d --- /dev/null +++ b/configs/datasets/cmb/cmb_gen_dfb5c4.py @@ -0,0 +1,49 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import CMBDataset +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.utils.text_postprocessors import multiple_select_postprocess + + +cmb_datasets = [] +for split in ["val", "test"]: + cmb_reader_cfg = dict( + input_columns=["exam_type", "exam_class", "question_type", "question", "option_str"], + output_column="answer", + train_split=split, + test_split=split, + ) + + cmb_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict( + role="HUMAN", + prompt=f"以下是中国{{exam_type}}中{{exam_class}}考试的一道{{question_type}},不需要做任何分析和解释,直接输出答案选项。\n{{question}}\n{{option_str}} \n 答案: ", + ), + dict(role="BOT", prompt="{answer}"), + ], + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=10), + ) + + cmb_eval_cfg = dict( + evaluator=dict(type=AccEvaluator), + pred_postprocessor=dict(type=multiple_select_postprocess), + ) + + cmb_datasets.append( + dict( + abbr="cmb" if split == "val" else "cmb_test", + type=CMBDataset, + path="./data/CMB/", + reader_cfg=cmb_reader_cfg, + infer_cfg=cmb_infer_cfg, + eval_cfg=cmb_eval_cfg, + ) + ) diff --git a/opencompass/datasets/cmb.py b/opencompass/datasets/cmb.py index 684c88f5..f2dd321c 100644 --- a/opencompass/datasets/cmb.py +++ b/opencompass/datasets/cmb.py @@ -13,18 +13,19 @@ class CMBDataset(BaseDataset): @staticmethod def load(path: str): - with open(osp.join(path, 'test.json'), 'r', encoding='utf-8') as f: - test_data = json.load(f) with open(osp.join(path, 'val.json'), 'r', encoding='utf-8') as f: val_data = json.load(f) - - for da in test_data: - da['option_str'] = '\n'.join( - [f'{k}. {v}' for k, v in da['option'].items() if len(v) > 1]) - for da in val_data: - da['option_str'] = '\n'.join( - [f'{k}. {v}' for k, v in da['option'].items() if len(v) > 1]) - - test_dataset = Dataset.from_list(test_data) + for d in val_data: + d['option_str'] = '\n'.join( + [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1]) + d['answer'] = 'NULL' val_dataset = Dataset.from_list(val_data) - return DatasetDict({'test': test_dataset, 'val': val_dataset}) + + with open(osp.join(path, 'test.json'), 'r', encoding='utf-8') as f: + test_data = json.load(f) + for d in test_data: + d['option_str'] = '\n'.join( + [f'{k}. {v}' for k, v in d['option'].items() if len(v) > 1]) + test_dataset = Dataset.from_list(test_data) + + return DatasetDict({'val': val_dataset, 'test': test_dataset}) diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index ec668f4d..f36da458 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -98,3 +98,9 @@ def first_number_postprocess(text: str) -> float: # if a match is found, return it. Otherwise, return None. return float(match.group(1)) if match else None + + +@TEXT_POSTPROCESSORS.register_module('multiple-select') +def multiple_select_postprocess(text: str) -> str: + ret = set([t for t in text if t.isupper()]) + return ''.join(sorted(ret))