import os
import re

from datasets import Dataset, load_dataset

from opencompass.datasets.supergpqa.supergpqa_eval import (
    extract_option_content, extract_option_labels)
from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_logger

from ..base import BaseDataset


def _parse(item, template, prompt_mode):
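    """Build the multiple-choice inference prompt for a single item.

    Options are rendered as 'A) ...', 'B) ...' lines beneath the question
    and substituted into the prompt template; the prompt mode is recorded
    on the item.
    """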
    prompt_format = [
        item['question'] + '\n' + '\n'.join([
            f'{chr(65+i)}) {option}'
            for i, option in enumerate(item['options'])
        ])
    ]
    item['infer_prompt'] = template['prompt_format'][0].format(*prompt_format)
    item['prompt_mode'] = prompt_mode
    return item


@LOAD_DATASET.register_module()
class SuperGPQADataset(BaseDataset):
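    """SuperGPQA dataset loader.

    Loads the dataset from `path` and attaches an inference prompt built
    from the zero-shot or five-shot YAML template under
    `supergpqa_dataset_config/prompt/`.

    Example (sketch; the actual `path` depends on how the data is hosted
    in your setup):

        dataset = SuperGPQADataset.load(path='m-a-p/SuperGPQA',
                                        prompt_mode='zero-shot')
    """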

    @staticmethod
    def load(path: str, prompt_mode: str, **kwargs):
        dataset = load_dataset(path, split='train')

        # get prompt template
        template_path = None
        if prompt_mode == 'zero-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/zero-shot.yaml',
            )
        elif prompt_mode == 'five-shot':
            template_path = os.path.join(
                os.path.dirname(__file__),
                'supergpqa_dataset_config/prompt/five-shot.yaml',
            )
        try:
            template = load_yaml(template_path)
        except FileNotFoundError:
            print(f'[ERROR] Missing prompt template: {template_path}')
            return Dataset.from_list([])

        dataset = dataset.map(lambda item: _parse(item, template, prompt_mode))
        return dataset


@ICL_EVALUATORS.register_module()
class SuperGPQAEvaluator(BaseEvaluator):
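    """Rule-based evaluator for SuperGPQA multiple-choice predictions.

    Extracts an option letter (A-J) from each prediction, compares it with
    the gold answer letter, and reports overall accuracy, per-difficulty
    accuracy, and miss/error rates in the returned dict.
    """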

    def __init__(self):
        super().__init__()

    def score(self, predictions, references, test_set):
        mode = test_set[0]['prompt_mode']
        acc = 0
        count = 0
        err = 0
        miss = 0
        acc_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        count_difficulty = {'hard': 0, 'middle': 0, 'easy': 0}
        stats = {'discipline': {}, 'field': {}, 'subfield': {}}
        details = []
        for i, sample in enumerate(test_set):
            sample['pred'] = prediction = predictions[i]
            gold = references[i]
            if mode == 'zero-shot':
                predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(prediction,
                                                     sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                sample['extracted_answer'] = predict
            elif mode == 'five-shot':
                response = prediction.split('Question:')[0]
                predict = extract_option_labels(response, 'ABCDEFGHIJ')
                if predict is None:
                    predict = extract_option_content(response,
                                                     sample['options'])
                    predict = (chr(sample['options'].index(predict) +
                                   65) if predict else None)
                if predict is None:
                    predict = extract_option_labels(prediction, 'ABCDEFGHIJ')
                    if predict is None:
                        predict = extract_option_content(
                            prediction, sample['options'])
                        predict = (chr(sample['options'].index(predict) +
                                       65) if predict else None)
                sample['extracted_answer'] = predict

            discipline = sample.get('discipline', 'unknown')
            field = sample.get('field', 'unknown')
            subfield = sample.get('subfield', 'unknown')
            difficulty = sample.get('difficulty', 'unknown')

            for level, key in [
                ('discipline', discipline),
                # ('field', f"{discipline}/{field}"),
                # ('subfield', f"{discipline}/{field}/{subfield}"),
            ]:
                if key not in stats[level]:
                    stats[level][key] = {
                        'correct': 0,
                        'total': 0,
                        'miss': 0,
                        'error': 0,
                        'discipline': discipline,
                        'field': field,
                        'subfield': subfield,
                        'difficulty': {
                            'easy': {
                                'correct': 0,
                                'total': 0
                            },
                            'middle': {
                                'correct': 0,
                                'total': 0
                            },
                            'hard': {
                                'correct': 0,
                                'total': 0
                            },
                        },
                    }

                stats[level][key]['total'] += 1
                stats[level][key]['difficulty'][difficulty]['total'] += 1

                answer_letter = sample['answer_letter']
                assert answer_letter == gold
                if predict and answer_letter == predict:
                    acc += 1
                    acc_difficulty[difficulty] += 1
                    sample['status'] = 'correct'
                    stats[level][key]['correct'] += 1
                    stats[level][key]['difficulty'][difficulty]['correct'] += 1
                elif predict is None or predict == '':
                    miss += 1
                    sample['status'] = 'miss'
                    stats[level][key]['miss'] += 1
                elif predict == 'error':
                    err += 1
                    sample['status'] = 'error'
                    stats[level][key]['error'] += 1
                else:
                    sample['status'] = 'incorrect'
                count += 1
                count_difficulty[difficulty] += 1

            details.append({
                'pred': sample['pred'],
                'answer': sample['answer'],
                'parsed_answer': sample['extracted_answer'],
                # Only a 'correct' status counts as a correct prediction.
                'correct': sample['status'] == 'correct',
            })

        return {
            'accuracy': acc / count if count > 0 else 0,
            'error_rate': err / count if count > 0 else 0,
            'miss_rate': miss / count if count > 0 else 0,
            'hard_accuracy':
            (acc_difficulty['hard'] / count_difficulty['hard']
             if count_difficulty['hard'] > 0 else 0),
            'middle_accuracy':
            (acc_difficulty['middle'] / count_difficulty['middle']
             if count_difficulty['middle'] > 0 else 0),
            'easy_accuracy':
            (acc_difficulty['easy'] / count_difficulty['easy']
             if count_difficulty['easy'] > 0 else 0),
            'details': details,
        }


def _generic_llmjudge_postprocess(judgement: str):
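    """Extract the judge's grade letter from its response.

    Returns the first 'A' or 'B' found in the judgement, defaulting to 'B'
    (incorrect) when neither letter is present.
    """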
    match = re.search(r'(A|B)', judgement)
    # Default to 'B' ("INCORRECT") if no match is found.
    grade_letter = match.group(0) if match else 'B'
    return grade_letter


def supergpqa_llmjudge_postprocess(
    output: dict,
    output_path: str,
    dataset: Dataset,
) -> dict:
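    """Aggregate LLM-judge outputs into SuperGPQA accuracy statistics.

    A judgement of 'A' counts as correct and 'B' as incorrect. Returns the
    overall accuracy plus per-discipline/field/subfield accuracies,
    flattened into keys such as 'SuperGPQA-{discipline}-{field}-{subfield}',
    together with per-sample details.
    """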
    # Get the original dataset
    original_dataset = dataset.reader.dataset['test']

    judged_answers = []
    original_responses = []
    references = []
    details = []

    # Initialize statistics dictionaries
    stats = {'discipline': {}, 'field': {}, 'subfield': {}}

    total_correct = 0
    total_count = 0

    # Process each sample
    for k, v in output.items():
        idx = int(k)  # Convert key to integer for indexing
        original_responses.append(v['prediction'])
        processed_judge = _generic_llmjudge_postprocess(v['prediction'])

        # Get category information from the dataset
        sample = original_dataset[idx]
        discipline = sample.get('discipline', 'unknown')
        field = sample.get('field', 'unknown')
        subfield = sample.get('subfield', 'unknown')

        # Initialize category stats if not exists
        for level, key in [
            ('discipline', discipline),
            ('field', f'{discipline}/{field}'),
            ('subfield', f'{discipline}/{field}/{subfield}'),
        ]:
            if key not in stats[level]:
                stats[level][key] = {'correct': 0, 'total': 0}

        # Record the judgment
        if processed_judge is not None:
            judged_answers.append(processed_judge)
            try:
                gold = v['gold']
                references.append(gold)
            except KeyError:
                get_logger().warning(
                    f'No gold answer for {k}, use empty string as reference!')
                gold = ''
                references.append('')

            # Check if the answer is correct (A means correct)
            is_correct = processed_judge == 'A'
            total_count += 1

            if is_correct:
                total_correct += 1
                # Update category stats
                for level, key in [
                    ('discipline', discipline),
                    ('field', f'{discipline}/{field}'),
                    ('subfield', f'{discipline}/{field}/{subfield}'),
                ]:
                    stats[level][key]['correct'] += 1

            # Update category totals
            for level, key in [
                ('discipline', discipline),
                ('field', f'{discipline}/{field}'),
                ('subfield', f'{discipline}/{field}/{subfield}'),
            ]:
                stats[level][key]['total'] += 1

            # Add to details
            details.append({
                'id': k,
                'question': sample['question'],
                'options': sample['options'],
                'origin_prompt': v['origin_prompt'],
                'llm_judge': processed_judge,
                'gold': gold,
                'is_correct': is_correct,
                'discipline': discipline,
                'field': field,
                'subfield': subfield,
            })

    # Calculate overall accuracy with two decimal places
    overall_accuracy = (round((total_correct / total_count * 100), 2)
                        if total_count > 0 else 0.00)

    # Initialize results dictionary
    results = {
        'accuracy': overall_accuracy,
        'total_correct': total_correct,
        'total_count': total_count,
        'details': details,
    }

    # Calculate accuracy for each category and flatten into results
    for level in stats:
        for key, value in stats[level].items():
            if value['total'] > 0:
                # Calculate accuracy with two decimal places
                accuracy = round((value['correct'] / value['total'] * 100), 2)

                # Create a flattened key for the category
                flat_key = f'SuperGPQA-{level}'
                if level == 'discipline':
                    flat_key = f'SuperGPQA-{key}'
                elif level == 'field':
                    discipline, field = key.split('/')
                    flat_key = f'SuperGPQA-{discipline}-{field}'
                elif level == 'subfield':
                    discipline, field, subfield = key.split('/')
                    flat_key = f'SuperGPQA-{discipline}-{field}-{subfield}'

                # Add to results
                results[flat_key] = accuracy

    return results