2024-06-28 14:16:34 +08:00
|
|
|
# flake8: noqa
|
|
|
|
# yapf: disable
|
|
|
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
|
|
|
from opencompass.registry import LOAD_DATASET
|
2024-08-30 10:03:40 +08:00
|
|
|
from opencompass.utils import get_data_path
|
2024-06-28 14:16:34 +08:00
|
|
|
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
|
|
|
|
|
|
|
def _parse(item):
|
|
|
|
choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P']
|
|
|
|
s = ''
|
|
|
|
for i, opt in enumerate(item['options']):
|
|
|
|
if opt == 'N/A':
|
|
|
|
continue
|
|
|
|
s += '{}. {}\n'.format(choices[i], opt)
|
|
|
|
item['options_str'] = s.strip()
|
|
|
|
item['cot_content'] = item['cot_content'].removeprefix("A: Let's think step by step.").strip()
|
|
|
|
return item
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class MMLUProDataset(BaseDataset):
    """Dataset loader registered with ``LOAD_DATASET`` for MMLU-Pro data."""

    @staticmethod
    def load(path: str, category: str):
        """Load the dataset at ``path`` and keep only one category.

        Args:
            path: Dataset location; it is resolved through ``get_data_path``
                before being handed to ``load_dataset``.
            category: Value that each kept row's ``'category'`` field must
                equal.

        Returns:
            The object produced by ``load_dataset`` after filtering on
            ``category`` and mapping every remaining row through ``_parse``.
        """

        def _in_category(row):
            # Exact match against the requested category name.
            return row['category'] == category

        resolved_path = get_data_path(path)
        return load_dataset(resolved_path).filter(_in_category).map(_parse)
|