mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
phybench
This commit is contained in:
parent
37cbaf8d92
commit
71173c4fef
149
opencompass/datasets/PHYBench/PHYBench.py
Normal file
149
opencompass/datasets/PHYBench/PHYBench.py
Normal file
@ -0,0 +1,149 @@
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import sympy
|
||||
from datasets import load_dataset
|
||||
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
from opencompass.datasets.PHYBench.EED.EED import EED
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class PHYBenchDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str):
|
||||
dataset = load_dataset(path, split='train')
|
||||
# only use first 100 examples
|
||||
return dataset.select(range(100))
|
||||
|
||||
|
||||
def extract_last_latex(prediction: str) -> str:
|
||||
# 1) Find all \boxed{ occurrences and manually extract balanced content
|
||||
boxed_positions = [
|
||||
m.start() for m in re.finditer(r'\\boxed\s*\{', prediction)
|
||||
]
|
||||
boxed_contents = []
|
||||
for pos in boxed_positions:
|
||||
# find the opening brace
|
||||
brace_start = prediction.find('{', pos)
|
||||
if brace_start == -1:
|
||||
continue
|
||||
# scan forward to find matching closing brace
|
||||
depth = 0
|
||||
for i in range(brace_start, len(prediction)):
|
||||
if prediction[i] == '{':
|
||||
depth += 1
|
||||
elif prediction[i] == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
# extract between braces
|
||||
boxed_contents.append(prediction[brace_start +
|
||||
1:i].strip())
|
||||
break
|
||||
|
||||
if boxed_contents:
|
||||
return boxed_contents[-1]
|
||||
|
||||
# 2) fallback: other delimiters
|
||||
cleaned = re.sub(r'^###.*$', '', prediction, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r'[*\\-]{3,}', '', cleaned)
|
||||
cleaned = re.sub(r'(^|\n)[ \t]*[-*+] ', r'\1', cleaned)
|
||||
|
||||
patterns = [
|
||||
r'\$\$(.*?)\$\$',
|
||||
r'\\\[(.*?)\\\]',
|
||||
r'\$(.*?)\$',
|
||||
r'\\\((.*?)\\\)',
|
||||
]
|
||||
fragments = []
|
||||
for pat in patterns:
|
||||
for mm in re.finditer(pat, cleaned, re.DOTALL):
|
||||
fragments.append(mm.group(1).strip())
|
||||
if fragments:
|
||||
return fragments[-1]
|
||||
|
||||
# 3) final fallback
|
||||
m2 = re.search(r'Final\s*Answer\s*:?\s*(.+)$', prediction, re.DOTALL)
|
||||
return m2.group(1).strip() if m2 else prediction.strip()
|
||||
|
||||
|
||||
def _calculate_eed_score(pred_str: str, ref_str: str) -> float:
|
||||
"""Calculate the Expression Edit Distance (EED) score.
|
||||
|
||||
Args:
|
||||
pred_str (str): Predicted answer string (LaTeX format)
|
||||
ref_str (str): Reference answer string (LaTeX format)
|
||||
|
||||
Returns:
|
||||
float: EED score between 0 and 100
|
||||
"""
|
||||
try:
|
||||
# Normalize the inputs first
|
||||
# remove the first $$ and the last $$ from the ref_str
|
||||
|
||||
clean_pred = extract_last_latex(pred_str)
|
||||
if '$$' in ref_str:
|
||||
clean_ref = ref_str.split('$$')[1].strip()
|
||||
else:
|
||||
clean_ref = extract_last_latex(ref_str)
|
||||
# only compare the rhs of rightmost =
|
||||
clean_pred = clean_pred.split('=')[-1].strip()
|
||||
clean_ref = clean_ref.split('=')[-1].strip()
|
||||
|
||||
# try to convert the latex to sympy expression
|
||||
try:
|
||||
clean_pred_expr = sympy.latex(sympy.sympify(clean_pred))
|
||||
clean_ref_expr = sympy.latex(sympy.sympify(clean_ref))
|
||||
except Exception:
|
||||
clean_pred_expr = None
|
||||
clean_ref_expr = None
|
||||
eed_result = EED(clean_ref, clean_pred)
|
||||
if clean_pred_expr and clean_ref_expr:
|
||||
clean_eed_result = EED(clean_ref_expr, clean_pred_expr)
|
||||
final_eed_result = max(clean_eed_result[0], eed_result[0])
|
||||
else:
|
||||
final_eed_result = eed_result[0]
|
||||
return final_eed_result
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
|
||||
class PHYBenchEvaluator(BaseEvaluator):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def score(self,
|
||||
predictions: List[str],
|
||||
references: List[str],
|
||||
test_set: List[Dict] = None) -> Dict:
|
||||
"""Evaluate predictions for PHYBench based on Accuracy and EED
|
||||
Score."""
|
||||
|
||||
if len(predictions) != len(references):
|
||||
return {'error': 'Number of predictions and references mismatch.'}
|
||||
|
||||
correct_count = 0
|
||||
total_count = len(predictions)
|
||||
eed_scores = []
|
||||
|
||||
for idx, (pred_str, ref_str) in enumerate(zip(predictions,
|
||||
references)):
|
||||
|
||||
eed = _calculate_eed_score(pred_str, ref_str)
|
||||
eed_scores.append(eed)
|
||||
|
||||
if abs(eed - 100) < 1e-6:
|
||||
correct_count += 1
|
||||
|
||||
accuracy = (correct_count /
|
||||
total_count) * 100 if total_count > 0 else 0
|
||||
average_eed_score = np.mean(eed_scores) if eed_scores else 0
|
||||
|
||||
# Return results as a dictionary
|
||||
return {'accuracy': accuracy, 'eed_score': average_eed_score}
|
Loading…
Reference in New Issue
Block a user