2023-09-27 21:18:48 +08:00
|
|
|
import csv
|
2023-07-05 09:01:25 +08:00
|
|
|
import os.path as osp
|
|
|
|
|
2023-09-27 21:18:48 +08:00
|
|
|
from datasets import Dataset, DatasetDict
|
2023-07-05 09:01:25 +08:00
|
|
|
|
|
|
|
from opencompass.registry import LOAD_DATASET
|
|
|
|
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class CEvalDataset(BaseDataset):
    """Dataset loader that reads per-split CSV files into a ``DatasetDict``.

    The class is registered with ``LOAD_DATASET`` and exposes a single
    static ``load`` entry point.
    """

    @staticmethod
    def load(path: str, name: str):
        """Read the ``dev``, ``val`` and ``test`` CSV files for ``name``.

        For every split the file ``<path>/<split>/<name>_<split>.csv`` is
        opened. Its first row is used as the header and every following row
        becomes one ``dict`` mapping header names to cell strings. Rows that
        lack an ``explanation`` or ``answer`` column get an empty string for
        that key, so all rows carry both keys.

        Args:
            path: Root directory containing one sub-directory per split.
            name: File-name prefix; combined with the split name to build
                ``<name>_<split>.csv``.

        Returns:
            A ``DatasetDict`` keyed by split name. A split whose file has a
            header but no data rows is not present in the result.

        Raises:
            OSError: If one of the split files cannot be opened.
            StopIteration: If a split file is completely empty (no header).
        """
        dataset = {}
        for split in ['dev', 'val', 'test']:
            csv_path = osp.join(path, split, f'{name}_{split}.csv')
            # Decode explicitly as UTF-8 instead of relying on the platform's
            # locale default, which can mis-decode or reject non-ASCII text.
            # NOTE(review): assumes the CSV files are UTF-8 encoded - confirm
            # against the data release.
            with open(csv_path, encoding='utf-8') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    # zip() stops at the shorter of header/row, so surplus
                    # cells or missing trailing cells are silently dropped.
                    item = dict(zip(header, row))
                    # Guarantee a uniform schema across rows and splits.
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    dataset.setdefault(split, []).append(item)
        dataset = {i: Dataset.from_list(dataset[i]) for i in dataset}
        return DatasetDict(dataset)
|