diff --git a/dataset-index.yml b/dataset-index.yml
index 36e6847a..d43629df 100644
--- a/dataset-index.yml
+++ b/dataset-index.yml
@@ -539,6 +539,12 @@
     paper: https://aclanthology.org/D19-1632.pdf
     configpath: opencompass/configs/datasets/flores/flores_gen.py
     configpath_llmjudge: ''
+- gaia:
+    name: GAIA
+    category: Tool Utilization
+    paper: https://arxiv.org/abs/2311.12983
+    configpath: opencompass/configs/datasets/GAIA/gaia_gen.py
+    configpath_llmjudge: ''
 - game24:
     name: Game24
     category: Math
diff --git a/opencompass/configs/datasets/GAIA/gaia_gen.py b/opencompass/configs/datasets/GAIA/gaia_gen.py
new file mode 100644
--- /dev/null
+++ b/opencompass/configs/datasets/GAIA/gaia_gen.py
@@ -0,0 +1,54 @@
+from opencompass.openicl.icl_prompt_template import PromptTemplate
+from opencompass.openicl.icl_retriever import ZeroRetriever
+from opencompass.openicl.icl_inferencer import GenInferencer
+from opencompass.openicl.icl_evaluator import AccEvaluator
+from opencompass.datasets import GAIADataset, gaia_postprocess
+
+gaia_reader_cfg = dict(
+    input_columns=['question', 'file_path'],
+    output_column='answerKey',
+    test_split='test')
+
+gaia_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template=dict(round=[
+            dict(
+                role='HUMAN',
+                prompt=
+                '''You are a general AI assistant. I will ask you a question. Report your thoughts, and
+finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
+YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated
+list of numbers and/or strings.
+If you are asked for a number, don’t use comma to write your number neither use units such as $ or
+percent sign unless specified otherwise.
+If you are asked for a string, don’t use articles, neither abbreviations (e.g. for cities), and write the
+digits in plain text unless specified otherwise.
+If you are asked for a comma separated list, apply the above rules depending of whether the element
+to be put in the list is a number or a string.\nGAIA Question: {question}\nFile Path: {file_path}\n'''
+            ),
+        ]),
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer),
+)
+
+gaia_eval_cfg = dict(
+    evaluator=dict(type=AccEvaluator),
+    pred_role='BOT',
+    # Score only the text after "FINAL ANSWER:"; GAIA answers are free-form,
+    # not single option letters.
+    pred_postprocessor=dict(type=gaia_postprocess),
+)
+
+gaia_datasets = [
+    dict(
+        abbr='gaia-validation',
+        type=GAIADataset,
+        path='opencompass/gaia',
+        local_mode=False,
+        reader_cfg=gaia_reader_cfg,
+        infer_cfg=gaia_infer_cfg,
+        eval_cfg=gaia_eval_cfg,
+    )
+]
diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py
index b1753221..672e9684 100644
--- a/opencompass/datasets/__init__.py
+++ b/opencompass/datasets/__init__.py
@@ -51,6 +51,7 @@ from .ds1000_interpreter import *  # noqa: F401, F403
 from .eprstmt import *  # noqa: F401, F403
 from .FinanceIQ import *  # noqa: F401, F403
 from .flores import *  # noqa: F401, F403
+from .gaia import *  # noqa: F401, F403
 from .game24 import *  # noqa: F401, F403
 from .gaokao_math import *  # noqa: F401, F403
 from .GaokaoBench import *  # noqa: F401, F403
diff --git a/opencompass/datasets/gaia.py b/opencompass/datasets/gaia.py
new file mode 100644
--- /dev/null
+++ b/opencompass/datasets/gaia.py
@@ -0,0 +1,97 @@
+import json
+import os
+from os import environ
+
+from datasets import Dataset
+
+from opencompass.registry import LOAD_DATASET
+from opencompass.utils.datasets_info import DATASETS_MAPPING
+
+from .base import BaseDataset
+
+
+def gaia_postprocess(text: str) -> str:
+    """Return the answer that follows the last ``FINAL ANSWER:`` marker.
+
+    The GAIA prompt lets the model reason freely and asks it to finish with
+    ``FINAL ANSWER: [YOUR FINAL ANSWER]``; only that trailing answer is
+    meant to be compared with the reference.
+
+    Args:
+        text: Raw model completion.
+
+    Returns:
+        The first line after the last marker with surrounding whitespace
+        removed, or the whole stripped completion when the marker is absent.
+    """
+    _, marker, answer = text.rpartition('FINAL ANSWER:')
+    if not marker:
+        return text.strip()
+    answer = answer.strip()
+    return answer.splitlines()[0].strip() if answer else ''
+
+
+@LOAD_DATASET.register_module()
+class GAIADataset(BaseDataset):
+    """Loader for the GAIA validation split."""
+
+    @staticmethod
+    def load(path, local_mode: bool = False):
+        """Load GAIA rows from Hugging Face or from a local JSONL file.
+
+        Args:
+            path: Key into ``DATASETS_MAPPING``, e.g. ``'opencompass/gaia'``.
+            local_mode: Accepted so existing configs keep working; this
+                loader does not read it.
+
+        Returns:
+            A ``datasets.Dataset`` with the columns ``question``,
+            ``answerKey``, ``file_path``, ``file_name`` and ``level``.
+
+        Raises:
+            KeyError: If ``path`` is not a key of ``DATASETS_MAPPING``.
+            OSError: If the local metadata file cannot be opened.
+        """
+        if environ.get('DATASET_SOURCE') == 'HF':
+            from datasets import load_dataset
+
+            # The ModelScope copy of GAIA cannot be read reliably, so the
+            # remote source is always Hugging Face. Download errors propagate.
+            hf_id = DATASETS_MAPPING[path]['hf_id']
+            ds = load_dataset(hf_id, '2023_all', split='validation')
+            return Dataset.from_list([{
+                'question': item['Question'],
+                'answerKey': item['Final answer'],
+                'file_path': item['file_path'],
+                'file_name': item['file_name'],
+                'level': item['Level'],
+            } for item in ds])
+
+        # Local source. An unset COMPASS_DATA_CACHE falls back to a path
+        # relative to the working directory instead of a TypeError.
+        data_root = environ.get('COMPASS_DATA_CACHE', '')
+        metadata_path = os.path.join(data_root,
+                                     DATASETS_MAPPING[path]['local'])
+        # Attachments are expected to sit beside metadata.jsonl, so paths
+        # are built from its directory, not from the metadata file itself.
+        attachment_dir = os.path.dirname(metadata_path)
+
+        rows = []
+        with open(metadata_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    # Tolerate blank lines in the JSONL file.
+                    continue
+                record = json.loads(line)
+                file_name = record['file_name']
+                rows.append({
+                    'question': record['Question'],
+                    'answerKey': record['Final answer'],
+                    # Questions without an attachment keep an empty path.
+                    'file_path': (os.path.join(attachment_dir, file_name)
+                                  if file_name else ''),
+                    'file_name': file_name,
+                    'level': record['Level'],
+                })
+        return Dataset.from_list(rows)
diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py
index af814eb8..f51ec674 100644
--- a/opencompass/utils/datasets_info.py
+++ b/opencompass/utils/datasets_info.py
@@ -1,4 +1,10 @@
 DATASETS_MAPPING = {
+    # GAIA Datasets
+    "opencompass/gaia": {
+        "ms_id": None,
+        "hf_id": "gaia-benchmark/GAIA",
+        "local": "./data/gaia/2023/validation/metadata.jsonl",
+    },
     # ADVGLUE Datasets
     "opencompass/advglue-dev": {
         "ms_id": None,