From 71173c4fef1235081adfffd987cb1e328fde1bbe Mon Sep 17 00:00:00 2001
From: yufeng zhao
Date: Wed, 30 Apr 2025 12:29:54 +0000
Subject: [PATCH] Add PHYBench dataset and EED-based evaluator

---
 opencompass/datasets/PHYBench/PHYBench.py | 155 ++++++++++++++++++++++
 1 file changed, 155 insertions(+)
 create mode 100644 opencompass/datasets/PHYBench/PHYBench.py

diff --git a/opencompass/datasets/PHYBench/PHYBench.py b/opencompass/datasets/PHYBench/PHYBench.py
new file mode 100644
index 00000000..bff38d26
--- /dev/null
+++ b/opencompass/datasets/PHYBench/PHYBench.py
@@ -0,0 +1,155 @@
+import re
+from typing import Dict, List, Optional
+
+import numpy as np
+import sympy
+from datasets import load_dataset
+
+from opencompass.datasets.base import BaseDataset
+from opencompass.datasets.PHYBench.EED.EED import EED
+from opencompass.openicl.icl_evaluator import BaseEvaluator
+from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
+
+
+@LOAD_DATASET.register_module()
+class PHYBenchDataset(BaseDataset):
+
+    @staticmethod
+    def load(path: str):
+        dataset = load_dataset(path, split='train')
+        # Only use the first 100 examples; guard against smaller splits.
+        return dataset.select(range(min(100, len(dataset))))
+
+
+def extract_last_latex(prediction: str) -> str:
+    """Extract the last LaTeX expression from a model prediction."""
+    # 1) Find all \boxed{ occurrences and extract their balanced content.
+    boxed_positions = [
+        m.start() for m in re.finditer(r'\\boxed\s*\{', prediction)
+    ]
+    boxed_contents = []
+    for pos in boxed_positions:
+        # Find the opening brace.
+        brace_start = prediction.find('{', pos)
+        if brace_start == -1:
+            continue
+        # Scan forward to find the matching closing brace.
+        depth = 0
+        for i in range(brace_start, len(prediction)):
+            if prediction[i] == '{':
+                depth += 1
+            elif prediction[i] == '}':
+                depth -= 1
+                if depth == 0:
+                    # Extract the content between the braces.
+                    boxed_contents.append(prediction[brace_start +
+                                                     1:i].strip())
+                    break
+
+    if boxed_contents:
+        return boxed_contents[-1]
+
+    # 2) Fallback: strip markdown headings, rules and bullets, then
+    # look for other math delimiters.
+    cleaned = re.sub(r'^###.*$', '', prediction, flags=re.MULTILINE)
+    cleaned = re.sub(r'[*\\-]{3,}', '', cleaned)
+    cleaned = re.sub(r'(^|\n)[ \t]*[-*+] ', r'\1', cleaned)
+
+    patterns = [
+        r'\$\$(.*?)\$\$',
+        r'\\\[(.*?)\\\]',
+        r'\$(.*?)\$',
+        r'\\\((.*?)\\\)',
+    ]
+    fragments = []
+    for pat in patterns:
+        for mm in re.finditer(pat, cleaned, re.DOTALL):
+            fragments.append((mm.start(), mm.group(1).strip()))
+    if fragments:
+        # Return the fragment that appears last in the text, not
+        # merely the last match of the last pattern tried.
+        return max(fragments, key=lambda f: f[0])[1]
+
+    # 3) Final fallback: anything after a 'Final Answer:' marker.
+    m2 = re.search(r'Final\s*Answer\s*:?\s*(.+)$', prediction,
+                   re.DOTALL | re.IGNORECASE)
+    return m2.group(1).strip() if m2 else prediction.strip()
+
+
+def _calculate_eed_score(pred_str: str, ref_str: str) -> float:
+    """Calculate the Expression Edit Distance (EED) score.
+
+    Args:
+        pred_str (str): Predicted answer string (LaTeX format)
+        ref_str (str): Reference answer string (LaTeX format)
+
+    Returns:
+        float: EED score between 0 and 100
+    """
+    try:
+        # Normalize the inputs: if the reference is wrapped in
+        # $$ ... $$, take the content between the first pair.
+        clean_pred = extract_last_latex(pred_str)
+        if '$$' in ref_str:
+            clean_ref = ref_str.split('$$')[1].strip()
+        else:
+            clean_ref = extract_last_latex(ref_str)
+        # Only compare the right-hand side of the rightmost '='.
+        clean_pred = clean_pred.split('=')[-1].strip()
+        clean_ref = clean_ref.split('=')[-1].strip()
+
+        # Try to round-trip the LaTeX through sympy to canonicalize
+        # it. parse_latex needs the optional antlr4 runtime; if it is
+        # unavailable, skip the canonicalized comparison.
+        try:
+            from sympy.parsing.latex import parse_latex
+            clean_pred_expr = sympy.latex(parse_latex(clean_pred))
+            clean_ref_expr = sympy.latex(parse_latex(clean_ref))
+        except Exception:
+            clean_pred_expr = None
+            clean_ref_expr = None
+        # EED returns a tuple; index 0 is the score.
+        eed_result = EED(clean_ref, clean_pred)
+        if clean_pred_expr and clean_ref_expr:
+            clean_eed_result = EED(clean_ref_expr, clean_pred_expr)
+            final_eed_result = max(clean_eed_result[0], eed_result[0])
+        else:
+            final_eed_result = eed_result[0]
+        return final_eed_result
+    except Exception:
+        return 0.0
+
+
+@ICL_EVALUATORS.register_module()
+class PHYBenchEvaluator(BaseEvaluator):
+
+    def __init__(self):
+        super().__init__()
+
+    def score(self,
+              predictions: List[str],
+              references: List[str],
+              test_set: Optional[List[Dict]] = None) -> Dict:
+        """Evaluate predictions for PHYBench based on accuracy and EED
+        score."""
+        if len(predictions) != len(references):
+            return {'error': 'Number of predictions and references mismatch.'}
+
+        correct_count = 0
+        total_count = len(predictions)
+        eed_scores = []
+
+        for pred_str, ref_str in zip(predictions, references):
+            eed = _calculate_eed_score(pred_str, ref_str)
+            eed_scores.append(eed)
+
+            # An EED score of 100 means an exact expression match.
+            if abs(eed - 100) < 1e-6:
+                correct_count += 1
+
+        accuracy = (correct_count /
+                    total_count) * 100 if total_count > 0 else 0
+        average_eed_score = float(np.mean(eed_scores)) if eed_scores else 0.0
+
+        # Return results as a dictionary
+        return {'accuracy': accuracy, 'eed_score': average_eed_score}
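
Usage note (not part of the patch): a minimal sketch of how the new dataset
and evaluator might be wired into an OpenCompass config. The HF dataset id
('Eureka-Lab/PHYBench'), the column names ('content', 'answer'), and the
plain-string template are assumptions for illustration; check the actual
dataset card and adjust to your setup.

    # Illustrative OpenCompass config sketch; dataset id, columns and
    # template below are assumed, not taken from this patch.
    from opencompass.openicl.icl_inferencer import GenInferencer
    from opencompass.openicl.icl_prompt_template import PromptTemplate
    from opencompass.openicl.icl_retriever import ZeroRetriever

    from opencompass.datasets.PHYBench.PHYBench import (PHYBenchDataset,
                                                        PHYBenchEvaluator)

    phybench_datasets = [
        dict(
            abbr='phybench',
            type=PHYBenchDataset,
            path='Eureka-Lab/PHYBench',  # assumed HF dataset id
            reader_cfg=dict(
                input_columns=['content'],  # assumed column names
                output_column='answer'),
            infer_cfg=dict(
                prompt_template=dict(type=PromptTemplate,
                                     template='{content}'),
                retriever=dict(type=ZeroRetriever),
                inferencer=dict(type=GenInferencer)),
            eval_cfg=dict(evaluator=dict(type=PHYBenchEvaluator)),
        )
    ]

Since PHYBenchEvaluator.score counts a prediction as correct only when its
EED score is exactly 100, the reported 'accuracy' is a strict match rate,
while 'eed_score' gives the softer average partial credit.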