OpenCompass/opencompass/datasets/mmlu.py

import csv
import json
import os.path as osp

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class MMLUDataset(BaseDataset):
    """Loads the per-subject MMLU CSV files for the dev and test splits."""

    @staticmethod
    def load(path: str, name: str):
        dataset = DatasetDict()
        for split in ['dev', 'test']:
            raw_data = []
            filename = osp.join(path, split, f'{name}_{split}.csv')
            with open(filename, encoding='utf-8') as f:
                reader = csv.reader(f)
                for row in reader:
                    # Each row: question, options A-D, answer letter.
                    assert len(row) == 6
                    raw_data.append({
                        'input': row[0],
                        'A': row[1],
                        'B': row[2],
                        'C': row[3],
                        'D': row[4],
                        'target': row[5],
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
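# A minimal usage sketch (illustrative only): it assumes a local MMLU copy
# laid out as `<path>/{dev,test}/<subject>_{dev,test}.csv`; the directory path
# and subject name below are placeholders, not values shipped with this file.
#
#   ds = MMLUDataset.load(path='data/mmlu', name='anatomy')
#   print(ds['test'][0]['input'], '->', ds['test'][0]['target'])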
class MMLUDatasetClean(BaseDataset):
    """MMLU loader that additionally tags test items with contamination
    labels from the Contamination_Detector project."""

    # Load the contamination annotations of MMLU from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='val'):
        import requests

        assert split == 'test', 'We only use test set for MMLU'
        annotation_cache_path = osp.join(
            path, split, f'MMLU_{split}_contamination_annotations.json')
        # Reuse a cached copy of the annotations if one exists alongside
        # the data.
        if osp.exists(annotation_cache_path):
            with open(annotation_cache_path, 'r') as f:
                annotations = json.load(f)
            return annotations
        # Otherwise download the annotations once and cache them locally.
        link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc2/mmlu_annotations.json'  # noqa
        annotations = json.loads(requests.get(link_of_annotations).text)
        with open(annotation_cache_path, 'w') as f:
            json.dump(annotations, f)
        return annotations
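    # Shape of the downloaded annotation file, as inferred from the lookup in
    # `load` below: a JSON object keyed by '<subject> <row_index>', whose
    # value is a sequence with the contamination verdict in position 0. The
    # concrete subject and label shown here are illustrative assumptions, not
    # values taken from the release asset:
    #
    #   {"abstract_algebra 0": ["clean", ...], ...}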
    @staticmethod
    def load(path: str, name: str):
        dataset = DatasetDict()
        for split in ['dev', 'test']:
            raw_data = []
            filename = osp.join(path, split, f'{name}_{split}.csv')
            if split == 'test':
                annotations = MMLUDatasetClean.load_contamination_annotations(
                    path, split)
            with open(filename, encoding='utf-8') as f:
                reader = csv.reader(f)
                for row_index, row in enumerate(reader):
                    assert len(row) == 6
                    item = {
                        'input': row[0],
                        'A': row[1],
                        'B': row[2],
                        'C': row[3],
                        'D': row[4],
                        'target': row[5],
                    }
                    if split == 'test':
                        # Attach the contamination verdict for annotated test
                        # rows; rows missing from the annotation file are
                        # marked 'not labeled'.
                        row_id = f'{name} {row_index}'
                        if row_id in annotations:
                            is_clean = annotations[row_id][0]
                        else:
                            is_clean = 'not labeled'
                        item['is_clean'] = is_clean
                    raw_data.append(item)
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
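# A minimal usage sketch for the contamination-aware loader (illustrative
# only): the data directory layout and subject name are assumptions about a
# local MMLU copy, and the first call needs network access to fetch the
# annotation file before it is cached.
#
#   from collections import Counter
#   ds = MMLUDatasetClean.load(path='data/mmlu', name='anatomy')
#   print(Counter(item['is_clean'] for item in ds['test']))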