OpenCompass/opencompass/datasets/supergpqa/supergpqa.py

import os
from datasets import Dataset, load_dataset
from opencompass.datasets.supergpqa.supergpqa_eval import (
    extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path
from ..base import BaseDataset


def _parse(item, template, prompt_mode):
    """Format a raw SuperGPQA item into an inference prompt.

    Options are rendered as 'A) ...', 'B) ...', etc. and substituted into
    the first entry of the template's ``prompt_format`` list.
    """
    prompt_format = [
        item['question'] + '\n' + '\n'.join([
            f'{chr(65+i)}) {option}'
            for i, option in enumerate(item['options'])
        ])
    ]
    item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
    item['prompt_mode'] = prompt_mode
    return item
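
# A minimal usage sketch of `_parse`, assuming a toy template with a single
# '{}' placeholder (the real templates live under
# supergpqa_dataset_config/prompt/ and their exact wording is not shown here):
#
#     item = {'question': 'What is 2 + 2?', 'options': ['3', '4', '5']}
#     template = {'prompt_format': ['Answer the question.\n{}\nAnswer:']}
#     _parse(item, template, 'zero-shot')['infer_prompt']
#     # -> 'Answer the question.\nWhat is 2 + 2?\nA) 3\nB) 4\nC) 5\nAnswer:'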


@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        """Load SuperGPQA and attach an inference prompt to every item.

        ``prompt_mode`` selects the prompt template and must be either
        'zero-shot' or 'five-shot'.
        """
        path = get_data_path(path, local_mode=True)
        dataset = load_dataset(path, split='train')

        # Pick the prompt template that matches the requested prompt mode.
        template_path = None
        if prompt_mode == 'zero-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/zero-shot.yaml',
            )
        elif prompt_mode == 'five-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/five-shot.yaml',
            )
        try:
            template = load_yaml(template_path)
        except FileNotFoundError:
            print(f'[ERROR] Missing prompt template: {template_path}')
            return Dataset.from_list([])

        dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
        return dataset
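
# A hedged sketch of how this loader might be referenced from an OpenCompass
# dataset config. The abbr/path values and reader columns below are
# illustrative assumptions, not the shipped SuperGPQA config:
#
#     supergpqa_datasets = [
#         dict(
#             abbr='supergpqa',                  # placeholder name
#             type=SuperGPQADataset,
#             path='opencompass/supergpqa',      # placeholder data path
#             prompt_mode='zero-shot',           # or 'five-shot'
#             reader_cfg=dict(input_columns=['infer_prompt'],
#                             output_column='answer_letter'),
#         )
#     ]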


@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        """Score predictions against gold answer letters.

        Reports overall accuracy plus per-difficulty accuracy, and
        miss/error rates for unparseable or malformed predictions.
        """
        mode = test_set[0]['prompt_mode']
        acc = 0
        count = 0
        err = 0
        miss = 0
        acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        stats = {'discipline': {}, 'field': {}, 'subfield': {}}
        details = []
        for i, sample in enumerate(test_set):
            sample['pred'] = prediction = predictions[i]
            gold = references[i]

            if mode == 'zero-shot':
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    # Fall back to matching the option text, then map it back
                    # to its letter.
                    predict = extract_option_content(
                        prediction, sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                sample['extracted_answer'] = predict
            elif mode == 'five-shot':
                # Parse only the text before the next 'Question:' first;
                # fall back to the full completion if that yields nothing.
                response = prediction.split('Question:')[0]
                predict = extract_option_labels(response, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(
                        response, sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                if predict is None:
                    predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                    if predict is None:
                        predict = extract_option_content(
                            prediction, sample['options'])
                        predict = (chr(sample['options'].index(predict) +
                                       65) if predict else None)
                sample['extracted_answer'] = predict

            discipline = sample.get('discipline', 'unknown')
            field = sample.get('field', 'unknown')
            subfield = sample.get('subfield', 'unknown')
            difficulty = sample.get('difficulty', 'unknown')

            for level, key in [
                ('discipline', discipline),
                # ('field', f"{discipline}/{field}"),
                # ('subfield', f"{discipline}/{field}/{subfield}"),
            ]:
                if key not in stats[level]:
                    # Per-discipline accumulator, bucketed by difficulty.
                    stats[level][key] = {
                        'correct': 0,
                        'total': 0,
                        'miss': 0,
                        'error': 0,
                        'discipline': discipline,
                        'field': field,
                        'subfield': subfield,
                        'difficulty': {
                            'easy': {'correct': 0, 'total': 0},
                            'middle': {'correct': 0, 'total': 0},
                            'hard': {'correct': 0, 'total': 0},
                        },
                    }
                stats[level][key]['total'] += 1
                stats[level][key]['difficulty'][difficulty]['total'] += 1

                answer_letter = sample['answer_letter']
                assert answer_letter == gold
                if predict and answer_letter == predict:
                    acc += 1
                    acc_difficulty[difficulty] += 1
                    sample['status'] = 'correct'
                    stats[level][key]['correct'] += 1
                    stats[level][key]['difficulty'][difficulty]['correct'] += 1
                elif predict is None or predict == '':
                    miss += 1
                    sample['status'] = 'miss'
                    stats[level][key]['miss'] += 1
                elif predict == 'error':
                    err += 1
                    sample['status'] = 'error'
                    stats[level][key]['error'] += 1
                else:
                    sample['status'] = 'incorrect'
                count += 1
                count_difficulty[difficulty] += 1

            details.append({
                'pred': sample['pred'],
                'answer': sample['answer'],
                'parsed_answer': sample['extracted_answer'],
                # `status` is always a non-empty string here, so compare it
                # to 'correct' rather than truth-testing it.
                'correct': sample['status'] == 'correct',
            })

        return {
            'accuracy': acc / count if count > 0 else 0,
            'error_rate': err / count if count > 0 else 0,
            'miss_rate': miss / count if count > 0 else 0,
            'hard_accuracy': (acc_difficulty['hard'] /
                              count_difficulty['hard']
                              if count_difficulty['hard'] > 0 else 0),
            'middle_accuracy': (acc_difficulty['middle'] /
                                count_difficulty['middle']
                                if count_difficulty['middle'] > 0 else 0),
            'easy_accuracy': (acc_difficulty['easy'] /
                              count_difficulty['easy']
                              if count_difficulty['easy'] > 0 else 0),
            'details': details,
        }
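
# Shape of the dict returned by `score` (metric values below are illustrative
# placeholders, not measured results):
#
#     {
#         'accuracy': 0.62,
#         'error_rate': 0.01,
#         'miss_rate': 0.05,
#         'hard_accuracy': 0.41,
#         'middle_accuracy': 0.63,
#         'easy_accuracy': 0.78,
#         'details': [
#             {'pred': '...', 'answer': '...',
#              'parsed_answer': 'A', 'correct': True},
#             ...
#         ],
#     }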