diff --git a/configs/datasets/gsm_hard/gsmhard_gen.py b/configs/datasets/gsm_hard/gsmhard_gen.py new file mode 100644 index 00000000..c03bd16c --- /dev/null +++ b/configs/datasets/gsm_hard/gsmhard_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .gsmhard_gen_8a1400 import gsmhard_datasets # noqa: F401, F403 diff --git a/configs/datasets/gsm_hard/gsmhard_gen_8a1400.py b/configs/datasets/gsm_hard/gsmhard_gen_8a1400.py new file mode 100644 index 00000000..abcef16c --- /dev/null +++ b/configs/datasets/gsm_hard/gsmhard_gen_8a1400.py @@ -0,0 +1,36 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import GSMHardDataset, mathbench_postprocess + +gsmhard_reader_cfg = dict(input_columns=['question'], output_column='answer') + +gsmhard_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template=dict( + begin="", + round=[ + dict(role='HUMAN', prompt="Question: {question}\nAnswer:"), + dict(role="BOT", prompt="The answer is {answer}"), + ], + ), + ice_token="", +), + + retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +gsmhard_eval_cfg = dict(evaluator=dict(type=AccEvaluator), + pred_postprocessor=dict(type=mathbench_postprocess, name='en')) + +gsmhard_datasets = [ + dict( + abbr='gsm-hard', + type=GSMHardDataset, + path='./data/gsm-hard/test.jsonl', + reader_cfg=gsmhard_reader_cfg, + infer_cfg=gsmhard_infer_cfg, + eval_cfg=gsmhard_eval_cfg) +] diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index 306cee4a..70a4e52e 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -37,6 +37,7 @@ from .game24 import * # noqa: F401, F403 from .GaokaoBench import * # noqa: F401, F403 from .govrepcrs import * # noqa: F401, F403 from .gsm8k import * # noqa: F401, F403 +from .gsm_hard import * # noqa: F401, F403 from .hellaswag import * # noqa: F401, F403 from .huggingface import * # noqa: F401, F403 from .humaneval import * # noqa: F401, F403 diff --git a/opencompass/datasets/gsm_hard.py b/opencompass/datasets/gsm_hard.py new file mode 100644 index 00000000..5a3f31e8 --- /dev/null +++ b/opencompass/datasets/gsm_hard.py @@ -0,0 +1,24 @@ +import json + +from datasets import Dataset + +from opencompass.registry import LOAD_DATASET + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class GSMHardDataset(BaseDataset): + + @staticmethod + def load(path): + dataset = [] + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = json.loads(line.strip()) + dataset.append({ + 'question': line['input'], + 'answer': str(line['target']) + }) + dataset = Dataset.from_list(dataset) + return dataset