This commit is contained in:
yufeng zhao 2025-04-30 12:29:54 +00:00
parent 37cbaf8d92
commit 71173c4fef

View File

@ -0,0 +1,149 @@
import re
from typing import Dict, List
import numpy as np
import sympy
from datasets import load_dataset
from opencompass.datasets.base import BaseDataset
from opencompass.datasets.PHYBench.EED.EED import EED
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
@LOAD_DATASET.register_module()
class PHYBenchDataset(BaseDataset):
@staticmethod
def load(path: str):
dataset = load_dataset(path, split='train')
# only use first 100 examples
return dataset.select(range(100))
def extract_last_latex(prediction: str) -> str:
# 1) Find all \boxed{ occurrences and manually extract balanced content
boxed_positions = [
m.start() for m in re.finditer(r'\\boxed\s*\{', prediction)
]
boxed_contents = []
for pos in boxed_positions:
# find the opening brace
brace_start = prediction.find('{', pos)
if brace_start == -1:
continue
# scan forward to find matching closing brace
depth = 0
for i in range(brace_start, len(prediction)):
if prediction[i] == '{':
depth += 1
elif prediction[i] == '}':
depth -= 1
if depth == 0:
# extract between braces
boxed_contents.append(prediction[brace_start +
1:i].strip())
break
if boxed_contents:
return boxed_contents[-1]
# 2) fallback: other delimiters
cleaned = re.sub(r'^###.*$', '', prediction, flags=re.MULTILINE)
cleaned = re.sub(r'[*\\-]{3,}', '', cleaned)
cleaned = re.sub(r'(^|\n)[ \t]*[-*+] ', r'\1', cleaned)
patterns = [
r'\$\$(.*?)\$\$',
r'\\\[(.*?)\\\]',
r'\$(.*?)\$',
r'\\\((.*?)\\\)',
]
fragments = []
for pat in patterns:
for mm in re.finditer(pat, cleaned, re.DOTALL):
fragments.append(mm.group(1).strip())
if fragments:
return fragments[-1]
# 3) final fallback
m2 = re.search(r'Final\s*Answer\s*:?\s*(.+)$', prediction, re.DOTALL)
return m2.group(1).strip() if m2 else prediction.strip()
def _calculate_eed_score(pred_str: str, ref_str: str) -> float:
"""Calculate the Expression Edit Distance (EED) score.
Args:
pred_str (str): Predicted answer string (LaTeX format)
ref_str (str): Reference answer string (LaTeX format)
Returns:
float: EED score between 0 and 100
"""
try:
# Normalize the inputs first
# remove the first $$ and the last $$ from the ref_str
clean_pred = extract_last_latex(pred_str)
if '$$' in ref_str:
clean_ref = ref_str.split('$$')[1].strip()
else:
clean_ref = extract_last_latex(ref_str)
# only compare the rhs of rightmost =
clean_pred = clean_pred.split('=')[-1].strip()
clean_ref = clean_ref.split('=')[-1].strip()
# try to convert the latex to sympy expression
try:
clean_pred_expr = sympy.latex(sympy.sympify(clean_pred))
clean_ref_expr = sympy.latex(sympy.sympify(clean_ref))
except Exception:
clean_pred_expr = None
clean_ref_expr = None
eed_result = EED(clean_ref, clean_pred)
if clean_pred_expr and clean_ref_expr:
clean_eed_result = EED(clean_ref_expr, clean_pred_expr)
final_eed_result = max(clean_eed_result[0], eed_result[0])
else:
final_eed_result = eed_result[0]
return final_eed_result
except Exception:
return 0
@ICL_EVALUATORS.register_module()
class PHYBenchEvaluator(BaseEvaluator):
def __init__(self):
super().__init__()
def score(self,
predictions: List[str],
references: List[str],
test_set: List[Dict] = None) -> Dict:
"""Evaluate predictions for PHYBench based on Accuracy and EED
Score."""
if len(predictions) != len(references):
return {'error': 'Number of predictions and references mismatch.'}
correct_count = 0
total_count = len(predictions)
eed_scores = []
for idx, (pred_str, ref_str) in enumerate(zip(predictions,
references)):
eed = _calculate_eed_score(pred_str, ref_str)
eed_scores.append(eed)
if abs(eed - 100) < 1e-6:
correct_count += 1
accuracy = (correct_count /
total_count) * 100 if total_count > 0 else 0
average_eed_score = np.mean(eed_scores) if eed_scores else 0
# Return results as a dictionary
return {'accuracy': accuracy, 'eed_score': average_eed_score}