diff --git a/configs/datasets/gpqa/gpqa_gen.py b/configs/datasets/gpqa/gpqa_gen.py
index f666825f..f1e8784f 100644
--- a/configs/datasets/gpqa/gpqa_gen.py
+++ b/configs/datasets/gpqa/gpqa_gen.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .gpqa_gen_4baadb import gpqa_datasets
+    from .gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets
diff --git a/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py b/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py
new file mode 100644
index 00000000..a82b01d3
--- /dev/null
+++ b/configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py
@@ -0,0 +1,53 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.datasets import GPQADataset_Simple_Eval, GPQA_Simple_Eval_postprocess, GPQAEvaluator
+
+# openai_simple_eval prompt
+align_prompt = """
+Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+{question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""".strip()
+
+gpqa_reader_cfg = dict(
+    input_columns=['question', 'A', 'B', 'C', 'D'],
+    output_column='answer')
+
+gpqa_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(
+            round=[
+                dict(role='HUMAN', prompt=align_prompt),
+            ], )),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer))
+
+gpqa_eval_cfg = dict(evaluator=dict(type=GPQAEvaluator),
+                     pred_postprocessor=dict(type=GPQA_Simple_Eval_postprocess))
+
+gpqa_datasets = []
+# Maps the split name (used in ``abbr``) to its CSV file under ``path``.
+gpqa_subsets = {
+    # 'extended': 'gpqa_extended.csv',
+    # 'main': 'gpqa_main.csv',
+    'diamond': 'gpqa_diamond.csv'
+}
+
+for split in gpqa_subsets:
+    gpqa_datasets.append(
+        dict(
+            abbr='GPQA_' + split,
+            type=GPQADataset_Simple_Eval,
+            path='./data/gpqa/',
+            name=gpqa_subsets[split],
+            reader_cfg=gpqa_reader_cfg,
+            infer_cfg=gpqa_infer_cfg,
+            eval_cfg=gpqa_eval_cfg)
+    )
diff --git a/opencompass/datasets/gpqa.py b/opencompass/datasets/gpqa.py
index 20ba4c39..a4c88f37 100644
--- a/opencompass/datasets/gpqa.py
+++ b/opencompass/datasets/gpqa.py
@@ -1,10 +1,13 @@
 import csv
 import os
+import random
+import re
+from typing import Optional
 
 from datasets import Dataset
 
 from opencompass.openicl import BaseEvaluator
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 
 from .base import BaseDataset
 
@@ -57,3 +60,77 @@ class GPQAEvaluator(BaseEvaluator):
             details.append(detail)
         result = {'accuracy': 100 * correct / count, 'details': details}
         return result
+
+
+@LOAD_DATASET.register_module()
+class GPQADataset_Simple_Eval(BaseDataset):
+    """GPQA loader that repeats each question with shuffled options.
+
+    Every question is emitted four times, each time with its four options
+    in a seeded pseudo-random order, and ``answer`` is set to the letter
+    the correct option lands on.
+    """
+
+    @staticmethod
+    def load(path: str, name: str):
+        """Build the dataset from the CSV file ``name`` under ``path``.
+
+        Args:
+            path (str): Directory that contains the CSV file.
+            name (str): File name of the CSV to read.
+
+        Returns:
+            Dataset: Rows with ``question``, ``answer``, ``options``,
+            ``permutation`` and ``A``-``D`` fields.
+        """
+        n_repeats = 4
+        data = []
+        with open(os.path.join(path, name), 'r', encoding='utf-8') as f:
+            reader = csv.reader(f, delimiter=',')
+            for row in reader:
+                # Skip the header row (column 7 holds the literal title).
+                if row[7] == 'Question':
+                    continue
+                data.append({
+                    'question': row[7],
+                    # The first option is the correct one.
+                    'options': [row[8], row[9], row[10], row[11]],
+                })
+
+        # Fixed seed so every run produces the same shuffles.
+        rng = random.Random(0)
+        data_list = []
+        for item in data * n_repeats:
+            permutation = rng.sample(range(4), 4)
+            options = [item['options'][i] for i in permutation]
+            entry = {
+                'question': item['question'],
+                # Source index 0 is the correct option, so its position in
+                # the permutation gives the answer letter.
+                'answer': 'ABCD'[permutation.index(0)],
+                'options': options,
+                'permutation': permutation,
+            }
+            entry.update(zip('ABCD', options))
+            data_list.append(entry)
+
+        dataset = Dataset.from_list(data_list)
+        return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def GPQA_Simple_Eval_postprocess(text: str) -> Optional[str]:
+    """Extract the answer letter from an ``ANSWER: X`` style response.
+
+    Args:
+        text (str): Raw model output.
+
+    Returns:
+        Optional[str]: The letter captured by the first match, exactly as
+        the model wrote it, or ``None`` when no match is found.
+    """
+    ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([A-D])'
+    match = re.search(ANSWER_PATTERN, text)
+    if match:
+        return match.group(1)
+    return None