# OpenCompass/opencompass/datasets/gaia.py
# (61 lines, 2.3 KiB, Python; revision dated 2025-04-10 11:04:38 +08:00)

import json
from os import environ
import os
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
from opencompass.utils.datasets_info import DATASETS_MAPPING
# (revision dated 2025-04-10 11:04:38 +08:00)
from .base import BaseDataset
@LOAD_DATASET.register_module()
class GAIADataset(BaseDataset):
    """Loader for the GAIA benchmark.

    Rows are read either from HuggingFace (when the ``DATASET_SOURCE``
    environment variable equals ``'HF'``) or from a local JSON-lines
    metadata file located via ``DATASETS_MAPPING``. Every row exposes the
    keys ``question``, ``answerKey``, ``file_name``, ``file_path`` and
    ``level``.
    """

    @staticmethod
    def load(path, local_mode: bool = False):
        """Build a ``Dataset`` of GAIA questions.

        Args:
            path: Key into ``DATASETS_MAPPING``; the entry supplies the
                ``'hf_id'`` (HuggingFace source) or ``'local'`` (local
                source) location.
            local_mode: Accepted for interface compatibility; this loader
                does not read it.

        Returns:
            A ``Dataset`` created with ``Dataset.from_list``. It is empty
            when the HuggingFace load fails.
        """
        if os.environ.get('DATASET_SOURCE') == 'HF':
            rows = GAIADataset._load_from_hf(path)
        else:
            rows = GAIADataset._load_from_local(path)
        return Dataset.from_list(rows)

    @staticmethod
    def _load_from_hf(path):
        """Return GAIA rows from the HuggingFace validation split."""
        from datasets import load_dataset

        try:
            hf_id = DATASETS_MAPPING[path]['hf_id']
            # Reading the GAIA dataset through ModelScope is problematic,
            # so it is fetched from HuggingFace instead.
            ds = load_dataset(hf_id, '2023_all', split='validation')
            return [
                {
                    'question': item['Question'],
                    'answerKey': item['Final answer'],
                    'file_path': item['file_path'],
                    'file_name': item['file_name'],
                    'level': item['Level'],
                }
                for item in ds
            ]
        except Exception as e:
            # Best effort, as before: report the failure and fall back to
            # an empty dataset rather than raising. The message now names
            # the real source (it used to say "local file").
            print(f"Error loading GAIA dataset from HuggingFace: {e}")
            return []

    @staticmethod
    def _load_from_local(path):
        """Return GAIA rows from the local JSON-lines metadata file."""
        # Default to '' so an unset COMPASS_DATA_CACHE resolves the mapped
        # path as-is instead of crashing in os.path.join(None, ...).
        cache_root = os.environ.get('COMPASS_DATA_CACHE', '')
        local_path = os.path.join(cache_root, DATASETS_MAPPING[path]['local'])
        # Attachments are assumed to sit next to the metadata file (TODO
        # confirm against the data layout). The previous code appended the
        # file name to the metadata file's own path, which can never exist
        # because that path is opened as a regular file below.
        attachments_dir = os.path.dirname(local_path)

        rows = []
        with open(local_path, 'r', encoding='utf-8') as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    # Tolerate blank lines in the JSONL file.
                    continue
                record = json.loads(raw)
                file_name = record['file_name']
                rows.append({
                    'question': record['Question'],
                    'answerKey': record['Final answer'],
                    'file_name': file_name,
                    'level': record['Level'],
                    # Only questions with an attachment get a file path.
                    'file_path': (os.path.join(attachments_dir, file_name)
                                  if file_name else ''),
                })
        return rows