diff --git a/configs/datasets/wikibench/wikibench_gen.py b/configs/datasets/wikibench/wikibench_gen.py
new file mode 100644
index 00000000..006b2ded
--- /dev/null
+++ b/configs/datasets/wikibench/wikibench_gen.py
@@ -0,0 +1,4 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .wikibench_gen_f96ece import wikibench_datasets  # noqa: F401, F403
diff --git a/configs/datasets/wikibench/wikibench_gen_f96ece.py b/configs/datasets/wikibench/wikibench_gen_f96ece.py
new file mode 100644
index 00000000..08a096c9
--- /dev/null
+++ b/configs/datasets/wikibench/wikibench_gen_f96ece.py
@@ -0,0 +1,60 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import CircularEvaluator, AccEvaluator
+from opencompass.datasets import WikiBenchDataset
+from opencompass.utils.text_postprocessors import first_option_postprocess
+
+
+single_choice_prompts = {
+    "single_choice_cn": "以下是一道单项选择题,请你根据你了解的知识给出正确的答案选项。\n下面是你要回答的题目:\n{question}\n答案选项:",
+}
+
+wikibench_sets = {
+    "wiki": ["single_choice_cn"],
+}
+
+# When True the dataset name gets a "circular_" prefix (which the loader keys
+# on) and CircularEvaluator is used instead of AccEvaluator.
+do_circular = True
+
+wikibench_datasets = []
+
+for _split in wikibench_sets:
+    for _name in wikibench_sets[_split]:
+        wikibench_infer_cfg = dict(
+            ice_template=dict(
+                type=PromptTemplate,
+                template=dict(
+                    begin="",
+                    round=[
+                        dict(role="HUMAN", prompt=single_choice_prompts[_name]),
+                        dict(role="BOT", prompt="{answer}"),
+                    ],
+                ),
+                ice_token="",
+            ),
+            retriever=dict(type=ZeroRetriever),
+            inferencer=dict(type=GenInferencer),
+        )
+        wikibench_eval_cfg = dict(
+            evaluator=dict(type=CircularEvaluator if do_circular else AccEvaluator),
+            pred_postprocessor=dict(type=first_option_postprocess, options="ABCD"),
+        )
+
+        wikibench_datasets.append(
+            dict(
+                type=WikiBenchDataset,
+                path=f"./data/WikiBench/{_name}.jsonl",
+                name=("circular_" + _name) if do_circular else _name,
+                # Only the suffix is conditional; without the parentheses the
+                # whole abbr collapses to "" when do_circular is False.
+                abbr="wikibench-" + _split + "-" + _name + ("circular" if do_circular else ""),
+                reader_cfg=dict(
+                    input_columns=["question"],
+                    output_column="answer",
+                ),
+                infer_cfg=wikibench_infer_cfg,
+                eval_cfg=wikibench_eval_cfg,
+            )
+        )
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index effbbf41..c9acb7a6 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -86,6 +86,7 @@ from .triviaqarc import *  # noqa: F401, F403
 from .truthfulqa import *  # noqa: F401, F403
 from .tydiqa import *  # noqa: F401, F403
 from .wic import *  # noqa: F401, F403
+from .wikibench import *  # noqa: F401, F403
 from .winograd import *  # noqa: F401, F403
 from .winogrande import *  # noqa: F401, F403
 from .wnli import wnliDataset  # noqa: F401, F403
diff --git a/opencompass/datasets/wikibench.py b/opencompass/datasets/wikibench.py
new file mode 100644
index 00000000..6ca06e68
--- /dev/null
+++ b/opencompass/datasets/wikibench.py
@@ -0,0 +1,95 @@
+import copy
+import json
+
+from datasets import Dataset
+
+from opencompass.registry import LOAD_DATASET
+
+from .base import BaseDataset
+
+
+def get_number(options):
+    """Render options as lettered lines.
+
+    Each option becomes one line of the form ``<letter>. <option>``
+    terminated by a newline, with letters assigned in order from ``A``.
+
+    Args:
+        options: Iterable of option texts.
+
+    Returns:
+        str: The concatenated lines, or ``''`` when ``options`` is empty.
+    """
+    return ''.join(f'{chr(code)}. {option}\n'
+                   for code, option in enumerate(options, start=ord('A')))
+
+
+@LOAD_DATASET.register_module()
+class WikiBenchDataset(BaseDataset):
+    """Loader for WikiBench JSON-lines files."""
+
+    @staticmethod
+    def load(path: str, name: str) -> Dataset:
+        """Read a JSON-lines file and build the dataset.
+
+        The substring found in ``name`` selects how each record is prepared:
+
+        * ``'cloze'``: keep only the stripped ``question`` and ``answer``.
+        * ``'circular'``: emit one copy of the record per rotation in
+          ``ABCD``/``BCDA``/``CDAB``/``DABC`` with ``options`` reordered, the
+          lettered options appended to the question, and ``answer`` set to
+          ``'<line index>--<remapped letter>--<rotation>'``.
+        * otherwise: append the lettered options to the stripped question.
+
+        Blank lines in the file are skipped.
+
+        Args:
+            path: Path to the JSON-lines file.
+            name: Dataset name; ``'cloze'`` is checked before ``'circular'``.
+
+        Returns:
+            Dataset: Built with ``Dataset.from_list`` from the records.
+        """
+        circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']
+        is_cloze = 'cloze' in name
+        is_circular = 'circular' in name
+
+        data = []
+        # Decode explicitly as UTF-8 instead of relying on the platform's
+        # default encoding.
+        with open(path, 'r', encoding='utf-8') as infile:
+            for idx, raw_line in enumerate(infile):
+                if not raw_line.strip():
+                    # Tolerate blank lines such as a trailing newline.
+                    continue
+                entry = json.loads(raw_line)
+                if is_cloze:
+                    data.append({
+                        'question': entry['question'].strip(),
+                        'answer': entry['answer'].strip()
+                    })
+                elif is_circular:
+                    for pattern in circular_patterns:
+                        item = copy.deepcopy(entry)
+                        original_options = item['options']
+                        # Slot k of the rotated list holds the option that
+                        # was originally labelled pattern[k].
+                        item['options'] = [
+                            original_options[ord(letter) - ord('A')]
+                            for letter in pattern
+                        ]
+                        # Translate the original answer letter to the label
+                        # it carries after the rotation.
+                        relabel = dict(zip(pattern, 'ABCD'))
+                        new_answer = relabel[item['answer']]
+                        item['answer'] = f'{idx}--{new_answer}--{pattern}'
+                        item['question'] = (item['question'].strip() + '\n' +
+                                            get_number(item['options']))
+                        data.append(item)
+                else:
+                    # treat as normal single choice question
+                    entry['question'] = (entry['question'].strip() + '\n' +
+                                         get_number(entry['options']))
+                    data.append(entry)
+
+        return Dataset.from_list(data)