Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)

SuperGPQA subset metrics

parent b9de8b0e2b
commit 03531e7a2f
@@ -1,5 +1,5 @@
 from opencompass.datasets.supergpqa.supergpqa import (
-    SuperGPQADataset,
+    SuperGPQADataset, supergpqa_llmjudge_postprocess
 )
 from opencompass.openicl.icl_inferencer import GenInferencer
 from opencompass.openicl.icl_prompt_template import PromptTemplate
@@ -87,7 +87,7 @@ eval_cfg = dict(
             reader_cfg=reader_cfg,
         ),
         judge_cfg=dict(),
-        dict_postprocessor=dict(type=generic_llmjudge_postprocess),
+        dict_postprocessor=dict(type=supergpqa_llmjudge_postprocess),
     ),
 )
 supergpqa_dataset = dict(
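
For orientation, the judge output that the new postprocessor (added later in this commit) consumes is keyed by sample index, with each entry carrying the judge's raw verdict and the gold option letter. A minimal sketch of the assumed shape, with made-up values:

judge_output = {
    '0': {'prediction': 'A', 'gold': 'C', 'origin_prompt': '...'},  # judge verdict A = correct
    '1': {'prediction': 'B', 'gold': 'D', 'origin_prompt': '...'},  # judge verdict B = incorrect
}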
@@ -1,4 +1,5 @@
 import os
+import re

 from datasets import Dataset, load_dataset

@@ -7,6 +8,7 @@ from opencompass.datasets.supergpqa.supergpqa_eval import (
 from opencompass.datasets.supergpqa.supergpqa_utils import load_yaml
 from opencompass.openicl.icl_evaluator import BaseEvaluator
 from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
+from opencompass.utils import get_logger

 from ..base import BaseDataset

@@ -180,3 +182,133 @@ class SuperGPQAEvaluator(BaseEvaluator):
             'details':
             details,
         }
+
+
+def _generic_llmjudge_postprocess(judgement: str):
+    match = re.search(r'(A|B)', judgement)
+    grade_letter = (match.group(0) if match else 'B'
+                    )  # Default to "INCORRECT" if no match
+    return grade_letter
+
+
+def supergpqa_llmjudge_postprocess(
+    output: dict,
+    output_path: str,
+    dataset: Dataset,
+) -> dict:
+    # Get the original dataset
+    original_dataset = dataset.reader.dataset['test']
+
+    judged_answers = []
+    original_responses = []
+    references = []
+    details = []
+
+    # Initialize statistics dictionaries
+    stats = {'discipline': {}, 'field': {}, 'subfield': {}}
+
+    total_correct = 0
+    total_count = 0
+
+    # Process each sample
+    for k, v in output.items():
+        idx = int(k)  # Convert key to integer for indexing
+        original_responses.append(v['prediction'])
+        processed_judge = _generic_llmjudge_postprocess(v['prediction'])
+
+        # Get category information from the dataset
+        sample = original_dataset[idx]
+        discipline = sample.get('discipline', 'unknown')
+        field = sample.get('field', 'unknown')
+        subfield = sample.get('subfield', 'unknown')
+
+        # Initialize category stats if not exists
+        for level, key in [
+            ('discipline', discipline),
+            ('field', f'{discipline}/{field}'),
+            ('subfield', f'{discipline}/{field}/{subfield}'),
+        ]:
+            if key not in stats[level]:
+                stats[level][key] = {'correct': 0, 'total': 0}
+
+        # Record the judgment
+        if processed_judge is not None:
+            judged_answers.append(processed_judge)
+            try:
+                gold = v['gold']
+                references.append(gold)
+            except KeyError:
+                get_logger().warning(
+                    f'No gold answer for {k}, use empty string as reference!')
+                gold = ''
+                references.append('')
+
+            # Check if the answer is correct (A means correct)
+            is_correct = processed_judge == 'A'
+            total_count += 1
+
+            if is_correct:
+                total_correct += 1
+                # Update category stats
+                for level, key in [
+                    ('discipline', discipline),
+                    ('field', f'{discipline}/{field}'),
+                    ('subfield', f'{discipline}/{field}/{subfield}'),
+                ]:
+                    stats[level][key]['correct'] += 1
+
+            # Update category totals
+            for level, key in [
+                ('discipline', discipline),
+                ('field', f'{discipline}/{field}'),
+                ('subfield', f'{discipline}/{field}/{subfield}'),
+            ]:
+                stats[level][key]['total'] += 1
+            # Add to details
+            details.append({
+                'id': k,
+                'question': sample['question'],
+                'options': sample['options'],
+                'origin_prompt': v['origin_prompt'],
+                'llm_judge': processed_judge,
+                'gold': gold,
+                'is_correct': is_correct,
+                'discipline': discipline,
+                'field': field,
+                'subfield': subfield,
+            })
+
+    # Calculate overall accuracy with two decimal places
+    overall_accuracy = (round(
+        (total_correct / total_count * 100), 2) if total_count > 0 else 0.00)
+
+    # Initialize results dictionary
+    results = {
+        'accuracy': overall_accuracy,
+        'total_correct': total_correct,
+        'total_count': total_count,
+        'details': details,
+    }
+
+    # Calculate accuracy for each category and flatten into results
+    for level in stats:
+        for key, value in stats[level].items():
+            if value['total'] > 0:
+                # Calculate accuracy with two decimal places
+                accuracy = round((value['correct'] / value['total'] * 100), 2)

+                # Create a flattened key for the category
+                flat_key = f'SuperGPQA-{level}'
+                if level == 'discipline':
+                    flat_key = f'SuperGPQA-{key}'
+                elif level == 'field':
+                    discipline, field = key.split('/')
+                    flat_key = f'SuperGPQA-{discipline}-{field}'
+                elif level == 'subfield':
+                    discipline, field, subfield = key.split('/')
+                    flat_key = f'SuperGPQA-{discipline}-{field}-{subfield}'
+
+                # Add to results
+                results[flat_key] = accuracy
+
+    return results
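
A minimal, self-contained sketch (not part of the commit) of the flattened per-subset keys the function above emits; the stats values here are made up, and the key scheme mirrors the flat_key logic:

stats = {
    'discipline': {'Science': {'correct': 8, 'total': 10}},
    'field': {'Science/Physics': {'correct': 5, 'total': 6}},
    'subfield': {'Science/Physics/Optics': {'correct': 3, 'total': 3}},
}
results = {}
for level, buckets in stats.items():
    for key, value in buckets.items():
        # Same flattening as above: SuperGPQA-<discipline>[-<field>[-<subfield>]]
        results[f"SuperGPQA-{key.replace('/', '-')}"] = round(
            value['correct'] / value['total'] * 100, 2)
print(results)
# {'SuperGPQA-Science': 80.0, 'SuperGPQA-Science-Physics': 83.33,
#  'SuperGPQA-Science-Physics-Optics': 100.0}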
@@ -84,6 +84,8 @@ class GenericLLMEvaluator(BaseEvaluator):
         references: Optional[List] = None,
     ) -> Dict:
         """Apply to single-model scoring."""
+        assert len(predictions) == len(
+            references), 'predictions and references must have the same length'
         # -------------- Build Inferencer ----------------
         self.build_inferencer()

@@ -127,7 +129,7 @@ class GenericLLMEvaluator(BaseEvaluator):
             prompt_template=self.prompt_template)

         output = mmengine.load(self.output_path)
-        return self.output_postprocess(output)
+        return self.output_postprocess(output, dataset)

     def pred_postprocess(self, predictions: List) -> Dict:
         if self.pred_postprocessor is None:
@@ -137,15 +139,24 @@ class GenericLLMEvaluator(BaseEvaluator):
             proc = TEXT_POSTPROCESSORS.get(kwargs.pop('type'))
             return [proc(pred, **kwargs) for pred in predictions]

-    def output_postprocess(self, output: Dict) -> Dict:
+    def output_postprocess(self, output: Dict, dataset=None) -> Dict:
         """Postprocess output by adding necessary statistics or data into
         it."""
+        import inspect
+
         if self.dict_postprocessor is None:
             return output
         else:
             kwargs = self.dict_postprocessor
             proc = DICT_POSTPROCESSORS.get(kwargs.pop('type'))
-            return proc(output, self.output_path, **kwargs)
+            sig = inspect.signature(proc)
+            if 'dataset' in sig.parameters:
+                return proc(output,
+                            self.output_path,
+                            dataset=dataset,
+                            **kwargs)
+            else:
+                return proc(output, self.output_path, **kwargs)

     @property
     def default_judge_cfg(self):
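
The inspect.signature check above keeps older dict postprocessors working: dataset is forwarded only when the postprocessor declares that parameter. A standalone sketch of the same dispatch pattern, using hypothetical postprocessor names:

import inspect

def old_style(output, output_path):                # hypothetical legacy postprocessor
    return {'n': len(output)}

def new_style(output, output_path, dataset=None):  # hypothetical dataset-aware postprocessor
    return {'n': len(output), 'dataset_size': len(dataset or [])}

def call_postprocessor(proc, output, output_path, dataset=None):
    # Forward `dataset` only if the callable accepts it, as output_postprocess does above.
    if 'dataset' in inspect.signature(proc).parameters:
        return proc(output, output_path, dataset=dataset)
    return proc(output, output_path)

print(call_postprocessor(old_style, {'0': 'A'}, 'judge.json'))                     # {'n': 1}
print(call_postprocessor(new_style, {'0': 'A'}, 'judge.json', dataset=[1, 2, 3]))  # {'n': 1, 'dataset_size': 3}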
@@ -89,6 +89,14 @@ class BaseEvaluator:
         original_dataset: Dataset,
         **score_kwargs,
     ):
+        # Check if predictions and references have the
+        # same length if both are provided
+        if 'predictions' in score_kwargs and 'references' in score_kwargs:
+            if len(score_kwargs['predictions']) != len(
+                    score_kwargs['references']):
+                raise ValueError(
+                    'Predictions and references must have the same length')
+
         real_size = len(original_dataset) // n
         all_details = []
         all_results = []
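
A tiny demo of the guard added above (simplified and outside the class); it only shows that mismatched lengths now fail fast with a clear message:

def _check_lengths(**score_kwargs):
    # Mirrors the new check: both keys present and lengths differ -> error.
    if 'predictions' in score_kwargs and 'references' in score_kwargs:
        if len(score_kwargs['predictions']) != len(score_kwargs['references']):
            raise ValueError('Predictions and references must have the same length')

try:
    _check_lengths(predictions=['A', 'B', 'C'], references=['A', 'B'])
except ValueError as err:
    print(err)  # Predictions and references must have the same length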