2024-06-28 14:16:34 +08:00
|
|
|
# flake8: noqa
|
|
|
|
# yapf: disable
|
|
|
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
2024-09-18 14:35:30 +08:00
|
|
|
from opencompass.openicl import BaseEvaluator
|
2024-06-28 14:16:34 +08:00
|
|
|
from opencompass.registry import LOAD_DATASET
|
2024-08-30 10:03:40 +08:00
|
|
|
from opencompass.utils import get_data_path
|
2024-06-28 14:16:34 +08:00
|
|
|
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
2024-09-18 14:35:30 +08:00
|
|
|
# Option letters; position k labels the k-th entry of a record's 'options' list.
CHOICES = list('ABCDEFGHIJKLMNOP')
|
2024-06-28 14:16:34 +08:00
|
|
|
|
|
|
|
def _parse(item):
    """Render one MMLU-Pro record's options and tidy its CoT text.

    Mutates *item* in place and returns it (the shape ``datasets.map``
    expects). Fields written:

    * ``answer_string``: the rendered ``"<letter>. <text>\\n"`` line whose
      letter equals ``item['answer']``, or ``''`` if no option matches.
    * ``options_str``: all rendered option lines joined, outer whitespace
      stripped. Options equal to ``'N/A'`` are skipped, but later options
      keep the letter of their original position.
    * ``cot_content``: the same text with a leading
      ``"A: Let's think step by step."`` removed, then stripped.
    """
    item['answer_string'] = ''
    rendered_lines = []
    for position, text in enumerate(item['options']):
        if text == 'N/A':
            # Placeholder slot: omit it, but do not shift later letters.
            continue
        letter = CHOICES[position]
        line = f'{letter}. {text}\n'
        rendered_lines.append(line)
        if letter == item['answer']:
            item['answer_string'] = line
    item['options_str'] = ''.join(rendered_lines).strip()
    cot_prefix = "A: Let's think step by step."
    item['cot_content'] = item['cot_content'].removeprefix(cot_prefix).strip()
    return item
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class MMLUProDataset(BaseDataset):
    """MMLU-Pro loader that keeps a single subject category."""

    @staticmethod
    def load(path: str, category: str):
        """Load MMLU-Pro and keep only the rows of one category.

        Args:
            path: Dataset location. It is resolved with ``get_data_path``
                before being passed to ``datasets.load_dataset``.
            category: Value of the ``category`` column to keep.

        Returns:
            The object returned by ``load_dataset``, filtered to *category*,
            with every row passed through ``_parse``.
        """
        raw = load_dataset(get_data_path(path))
        return raw.filter(lambda row: row['category'] == category).map(_parse)
|
2024-09-18 14:35:30 +08:00
|
|
|
|
|
|
|
class MMLUProBaseEvaluator(BaseEvaluator):
    """Score MMLU-Pro predictions against ``"<letter>. <text>"`` references.

    A prediction is correct when it is the reference's option letter, or
    when it equals the reference's option text (whitespace-stripped).
    """

    def is_equal(self, pred, refer):
        """Return True if *pred* matches the reference option.

        Args:
            pred: Cleaned prediction, e.g. ``"B"`` or the option text.
            refer: Reference string such as ``"B. Paris\\n"``.

        Returns:
            True when *pred* is one of ``CHOICES`` and equals the reference
            letter, or when *pred* equals the stripped option text.
            Malformed references (non-strings, or strings with no ``". "``
            separator) never match.
        """
        if not isinstance(refer, str):
            return False
        # Split only at the FIRST '. ' so option texts that themselves
        # contain '. ' (sentences, decimals) stay intact. A plain
        # split('. ') would produce more than two parts, the unpack would
        # fail, and correct answers would be scored as wrong.
        refer_option, sep, refer_string = refer.partition('. ')
        if not sep:
            return False
        if pred in CHOICES and refer_option == pred:
            return True
        return refer_string.strip() == pred

    def score(self, predictions, references):
        """Compute accuracy (in percent) and per-sample details.

        Only the first line of each prediction, stripped, is compared.

        Args:
            predictions: Model output strings.
            references: Reference strings, parallel to *predictions*.

        Returns:
            ``{'accuracy': float, 'details': list}``, or ``{'error': str}``
            when the two inputs differ in length. Accuracy is ``0.0`` for
            empty input instead of raising ``ZeroDivisionError``.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        correct = 0
        count = 0
        details = []
        for pred, refer in zip(predictions, references):
            # Treat only the first line of the model output as the answer.
            pred = pred.split('\n')[0].strip()
            detail = {'pred': pred, 'answer': refer, 'correct': False}
            count += 1
            if self.is_equal(pred, refer):
                correct += 1
                detail['correct'] = True
            details.append(detail)
        accuracy = 100 * correct / count if count else 0.0
        return {'accuracy': accuracy, 'details': details}
|