OpenCompass/opencompass/datasets/supergpqa/supergpqa.py

import os

from datasets import Dataset, load_dataset

from opencompass.datasets.supergpqa.supergpqa_eval import (
    extract_option_content,
    extract_option_labels,
)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from ..base import BaseDataset


def _parse(item, template, prompt_mode):
    """Render one dataset item into an inference prompt.

    The options are enumerated as 'A) ...', 'B) ...', etc. and substituted
    into the prompt template loaded from the YAML config.
    """
    prompt_format = [
        item['question'] + '\n' + '\n'.join(
            f'{chr(65 + i)}) {option}'
            for i, option in enumerate(item['options'])
        )
    ]
    item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
    item['prompt_mode'] = prompt_mode
    return item
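

# Illustrative sketch of what `_parse` produces for a single item; the question
# and options below are hypothetical, and the surrounding instructions added by
# the YAML prompt template are not reproduced here:
#
#   item = {'question': 'What is 2 + 2?', 'options': ['3', '4', '5', '6']}
#   _parse(item, template, 'zero-shot')
#   # item['infer_prompt'] now embeds:
#   #   What is 2 + 2?
#   #   A) 3
#   #   B) 4
#   #   C) 5
#   #   D) 6

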
@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        """Load the SuperGPQA data and attach the formatted inference
        prompt for ``prompt_mode`` to every item."""
        path = get_data_path(path, local_mode=True)
        dataset = load_dataset(path, split='train')

        # Resolve the prompt template matching the requested prompt mode.
        if prompt_mode == 'zero-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/zero-shot.yaml',
            )
        elif prompt_mode == 'five-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/five-shot.yaml',
            )
        else:
            raise ValueError(f'Unsupported prompt_mode: {prompt_mode!r}')

        try:
            template = load_yaml(template_path)
        except FileNotFoundError:
            print(f'[ERROR] Missing prompt template: {template_path}')
            return Dataset.from_list([])

        return dataset.map(lambda item: _parse(item, template, prompt_mode))
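

# Illustrative usage sketch; the local dataset path below is hypothetical and
# depends on where the SuperGPQA data has been downloaded:
#
#   ds = SuperGPQADataset.load(path='data/supergpqa', prompt_mode='zero-shot')
#   print(ds[0]['infer_prompt'])

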
@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        """Score predictions against gold answer letters.

        Returns overall accuracy, miss/error rates, per-difficulty
        accuracies, and per-sample details.
        """
mode = test_set[0]['prompt_mode']
acc = 0
count = 0
err = 0
miss = 0
        acc_difficulty = {"hard": 0, "middle": 0, "easy": 0}
        count_difficulty = {"hard": 0, "middle": 0, "easy": 0}
        # Per-category breakdown; only the 'discipline' level is currently
        # aggregated (the field/subfield entries in the loop below are
        # commented out).
        stats = {'discipline': {}, 'field': {}, 'subfield': {}}
        details = []
for i, sample in enumerate(test_set):
sample["pred"] = prediction = predictions[i]
gold = references[i]
            if mode == 'zero-shot':
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    # Fall back to matching option text when no letter label
                    # could be extracted, then map it back to its letter.
                    predict = extract_option_content(
                        prediction, sample["options"]
                    )
                    predict = (
                        chr(sample["options"].index(predict) + 65)
                        if predict
                        else None
                    )
                sample["extracted_answer"] = predict
            elif mode == 'five-shot':
                # Only look at the answer to the current question; few-shot
                # completions may start echoing the next 'Question:' block.
                response = prediction.split('Question:')[0]
                predict = extract_option_labels(response, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(
                        response, sample["options"]
                    )
                    predict = (
                        chr(sample["options"].index(predict) + 65)
                        if predict
                        else None
                    )
                if predict is None:
                    # Fall back to scanning the full, untruncated prediction.
                    predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                    if predict is None:
                        predict = extract_option_content(
                            prediction, sample["options"]
                        )
                        predict = (
                            chr(sample["options"].index(predict) + 65)
                            if predict
                            else None
                        )
                sample["extracted_answer"] = predict
discipline = sample.get("discipline", "unknown")
field = sample.get("field", "unknown")
subfield = sample.get("subfield", "unknown")
difficulty = sample.get("difficulty", "unknown")
for level, key in [
('discipline', discipline),
# ('field', f"{discipline}/{field}"),
# ('subfield', f"{discipline}/{field}/{subfield}"),
]:
if key not in stats[level]:
stats[level][key] = {
"correct": 0,
"total": 0,
"miss": 0,
"error": 0,
"discipline": discipline,
"field": field,
"subfield": subfield,
"difficulty": {
"easy": {"correct": 0, "total": 0},
"middle": {"correct": 0, "total": 0},
"hard": {"correct": 0, "total": 0},
},
}
stats[level][key]["total"] += 1
stats[level][key]["difficulty"][difficulty]["total"] += 1
answer_letter = sample["answer_letter"]
assert answer_letter == gold
if predict and answer_letter == predict:
acc += 1
acc_difficulty[difficulty] += 1
sample["status"] = "correct"
stats[level][key]["correct"] += 1
stats[level][key]["difficulty"][difficulty]["correct"] += 1
            elif predict is None or predict == "":
miss += 1
sample["status"] = "miss"
stats[level][key]["miss"] += 1
elif predict == 'error':
err += 1
sample["status"] = "error"
stats[level][key]["error"] += 1
else:
sample["status"] = "incorrect"
count += 1
count_difficulty[difficulty] += 1
            details.append(
                {
                    'pred': sample['pred'],
                    'answer': sample['answer'],
                    'parsed_answer': sample['extracted_answer'],
                    'correct': sample['status'] == 'correct',
                }
            )
return {
'accuracy': acc / count if count > 0 else 0,
'error_rate': err / count if count > 0 else 0,
'miss_rate': miss / count if count > 0 else 0,
'hard_accuracy': (
acc_difficulty["hard"] / count_difficulty["hard"]
if count_difficulty["hard"] > 0
else 0
),
'middle_accuracy': (
acc_difficulty["middle"] / count_difficulty["middle"]
if count_difficulty["middle"] > 0
else 0
),
'easy_accuracy': (
acc_difficulty["easy"] / count_difficulty["easy"]
if count_difficulty["easy"] > 0
else 0
),
'details': details,
}
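

# Illustrative sketch of calling the evaluator directly (OpenCompass normally
# drives it through its evaluation pipeline); the sample below is hypothetical,
# and the 1.0 accuracy assumes the answer letter 'B' is extracted from the
# prediction text:
#
#   evaluator = SuperGPQAEvaluator()
#   result = evaluator.score(
#       predictions=['The answer is B'],
#       references=['B'],
#       test_set=[{
#           'prompt_mode': 'zero-shot',
#           'options': ['3', '4', '5', '6'],
#           'answer': '4',
#           'answer_letter': 'B',
#           'difficulty': 'easy',
#           'discipline': 'Mathematics',
#       }],
#   )
#   # result['accuracy'] -> 1.0; per-difficulty accuracies and per-sample
#   # 'details' are also returned.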