# OpenCompass/opencompass/datasets/supergpqa/supergpqa.py

import os

from datasets import Dataset, load_dataset

from opencompass.datasets.supergpqa.supergpqa_eval import (
    extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from ..base import BaseDataset

def _parse(item, template, prompt_mode):
    prompt_format = [
        item['question'] + '\n' + '\n'.join([
            f'{chr(65+i)}) {option}'
            for i, option in enumerate(item['options'])
        ])
    ]
    item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
    item['prompt_mode'] = prompt_mode
    return item

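# A minimal sketch of what `_parse` produces, with hypothetical values
# (the real templates live under supergpqa_dataset_config/prompt/; their
# 'prompt_format' entry is assumed here to contain a single '{}' slot):
#
#   item = {'question': 'What is 2 + 2?', 'options': ['3', '4', '5']}
#   template = {'prompt_format': ['Answer the question:\n{}\nAnswer:']}
#   _parse(item, template, 'zero-shot')['infer_prompt']
#   # -> 'Answer the question:\nWhat is 2 + 2?\nA) 3\nB) 4\nC) 5\nAnswer:'
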
@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        path = get_data_path(path, local_mode=True)
        dataset = load_dataset(path, split='train')

        # get prompt template
        template_path = None
        if prompt_mode == 'zero-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/zero-shot.yaml',
            )
        elif prompt_mode == 'five-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/five-shot.yaml',
            )
        try:
            template = load_yaml(template_path)
        except FileNotFoundError:
            print(f'[ERROR] Missing prompt template: {template_path}')
            return Dataset.from_list([])

        dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
        return dataset

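# A hedged usage sketch (the path value is illustrative, not the exact
# one used by the OpenCompass config files):
#
#   dataset = SuperGPQADataset.load(
#       path='path/to/supergpqa',  # hypothetical local dataset path
#       prompt_mode='zero-shot',
#   )
#   print(dataset[0]['infer_prompt'])
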
@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        mode = test_set[0]['prompt_mode']
        acc = 0
        count = 0
        err = 0
        miss = 0
        acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        stats = {'discipline': {}, 'field': {}, 'subfield': {}}
        details = []
        for i, sample in enumerate(test_set):
            sample['pred'] = prediction = predictions[i]
            gold = references[i]
            if mode == 'zero-shot':
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(prediction,
                                                     sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                sample['extracted_answer'] = predict
            elif mode == 'five-shot':
                response = prediction.split('Question:')[0]
                predict = extract_option_labels(response, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(response,
                                                     sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                if predict is None:
                    predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                    if predict is None:
                        predict = extract_option_content(
                            prediction, sample['options'])
                        predict = (chr(sample['options'].index(predict) +
                                       65) if predict else None)
                sample['extracted_answer'] = predict

            discipline = sample.get('discipline', 'unknown')
            field = sample.get('field', 'unknown')
            subfield = sample.get('subfield', 'unknown')
            # NOTE: the difficulty counters below only track
            # easy/middle/hard, so samples are expected to carry one of
            # those labels; the 'unknown' fallback would raise a KeyError.
            difficulty = sample.get('difficulty', 'unknown')

            for level, key in [
                ('discipline', discipline),
                # ('field', f"{discipline}/{field}"),
                # ('subfield', f"{discipline}/{field}/{subfield}"),
            ]:
                if key not in stats[level]:
                    stats[level][key] = {
                        'correct': 0,
                        'total': 0,
                        'miss': 0,
                        'error': 0,
                        'discipline': discipline,
                        'field': field,
                        'subfield': subfield,
                        'difficulty': {
                            'easy': {
                                'correct': 0,
                                'total': 0
                            },
                            'middle': {
                                'correct': 0,
                                'total': 0
                            },
                            'hard': {
                                'correct': 0,
                                'total': 0
                            },
                        },
                    }

                stats[level][key]['total'] += 1
                stats[level][key]['difficulty'][difficulty]['total'] += 1

                answer_letter = sample['answer_letter']
                assert answer_letter == gold
                if predict and answer_letter == predict:
                    acc += 1
                    acc_difficulty[difficulty] += 1
                    sample['status'] = 'correct'
                    stats[level][key]['correct'] += 1
                    stats[level][key]['difficulty'][difficulty][
                        'correct'] += 1
                elif predict is None or predict == '':
                    miss += 1
                    sample['status'] = 'miss'
                    stats[level][key]['miss'] += 1
                elif predict == 'error':
                    err += 1
                    sample['status'] = 'error'
                    stats[level][key]['error'] += 1
                else:
                    sample['status'] = 'incorrect'
                count += 1
                count_difficulty[difficulty] += 1

            details.append({
                'pred': sample['pred'],
                'answer': sample['answer'],
                'parsed_answer': sample['extracted_answer'],
                # `status` is always a non-empty string at this point, so
                # the former `True if sample['status'] else False` marked
                # every sample correct; compare against 'correct' instead.
                'correct': sample['status'] == 'correct',
            })
2025-03-11 17:07:47 +08:00
2025-03-07 17:36:00 +08:00
return {
2025-03-11 17:32:35 +08:00
'accuracy':
acc / count if count > 0 else 0,
'error_rate':
err / count if count > 0 else 0,
'miss_rate':
miss / count if count > 0 else 0,
'hard_accuracy':
(acc_difficulty['hard'] /
count_difficulty['hard'] if count_difficulty['hard'] > 0 else 0),
'middle_accuracy':
(acc_difficulty['middle'] / count_difficulty['middle']
if count_difficulty['middle'] > 0 else 0),
'easy_accuracy':
(acc_difficulty['easy'] /
count_difficulty['easy'] if count_difficulty['easy'] > 0 else 0),
'details':
details,
2025-03-11 17:07:47 +08:00
}
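
# A minimal, hypothetical sketch of how the evaluator is driven (in
# practice OpenCompass wires predictions/references/test_set through its
# ICL pipeline; the literal values below are illustrative only):
#
#   evaluator = SuperGPQAEvaluator()
#   result = evaluator.score(
#       predictions=['The answer is B.'],
#       references=['B'],
#       test_set=dataset,  # items produced by SuperGPQADataset.load
#   )
#   print(result['accuracy'], result['miss_rate'])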