diff --git a/dataset-index.yml b/dataset-index.yml index 36e6847a..5ee6cc06 100644 --- a/dataset-index.yml +++ b/dataset-index.yml @@ -1113,3 +1113,9 @@ paper: https://arxiv.org/pdf/2203.14371 configpath: opencompass/configs/datasets/medmcqa/medmcqa_gen.py configpath_llmjudge: opencompass/configs/datasets/medmcqa/medmcqa_llmjudge_gen.py +- phybench: + name: PHYBench + category: Science /Physics + paper: https://arxiv.org/abs/2504.16074 + configpath: opencompass/configs/datasets/PHYBench/phybench_gen.py + configpath_llmjudge: '' \ No newline at end of file diff --git a/opencompass/configs/datasets/PHYBench/EED.py b/opencompass/configs/datasets/PHYBench/EED.py new file mode 100644 index 00000000..8a3ada35 --- /dev/null +++ b/opencompass/configs/datasets/PHYBench/EED.py @@ -0,0 +1,361 @@ +from sympy import * +from sympy.core.function import AppliedUndef +from sympy.core.numbers import Pi, Exp1,I,Infinity,NegativeInfinity +import numpy as np +import timeout_decorator +from extended_zss import ext_distance +from latex_pre_process import * +from sympy.simplify import * +""" +Guide: +You only need to use EED and install the following packages: +- sympy +- numpy +- latex2sympy2_extended +- timeout_decorator +""" + +""" +There are four main categories: + +Constants: such as integers, decimals, or mathematical constants like π and e. +Variables: letters like x, y, z, or specified terms in problems (e.g., ħ, c, G). +Functions: sine, cosine, exponential, logarithm, etc. +Operators: basic binary operations including addition, multiplication, and exponentiation. 
# The costs can be modified if you think their values are different
insert_cost = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
delete_cost = {"number": 1, "symbol": 1, "operator": 1, "function": 1}
update_cost = {"number": 1, "symbol": 1, "operator": 1, "function": 1}

change_type_cost = 1  # cost of an update between different node types, can be set higher

bar_size = 5  # minimum subtree size that triggers the cluster discount
discount_slope = 0.6  # marginal cost per node beyond bar_size (also scales the score)

simplify_time_limit = 30  # seconds allowed for simplify
equals_time_limit = 10  # seconds allowed for equals


def _node_kind(node):
    """Return the category prefix of a node label (the text before the first "_")."""
    return node.label.split("_")[0]


def update_func(x, y):
    """Return the cost of relabelling node ``x`` into node ``y``.

    Identical labels cost 0, labels of the same category cost that
    category's ``update_cost``, and anything else costs ``change_type_cost``.
    """
    if x.label == y.label:
        return 0
    kind = _node_kind(x)
    if kind == _node_kind(y):
        return update_cost[kind]
    return change_type_cost


def remove_func(x):
    """Return the cost of deleting the single node ``x``."""
    return delete_cost[_node_kind(x)]


def remove_tree_func(x):
    """Return the cost of deleting the whole subtree rooted at ``x``.

    Leaves cost the same as ``remove_func``.  Larger subtrees cost their size,
    capped by the linear discount ``discount_slope * (size - bar_size) + bar_size``.
    """
    if not x.children:
        return remove_func(x)
    s = calc_tree_size(x)
    return min(s, discount_slope * (s - bar_size) + bar_size)


def insert_func(x):
    """Return the cost of inserting the single node ``x``."""
    return insert_cost[_node_kind(x)]


def insert_tree_func(x):
    """Return the cost of inserting the whole subtree rooted at ``x``.

    Delegates to ``remove_tree_func``, so subtree insertion and deletion
    are priced identically.
    """
    return remove_tree_func(x)


def calc_tree_size(node):
    """Return the size of a subtree, defined as its total insertion cost.

    The result is stored on ``node.subtree_size`` and reused on later calls,
    so a tree must not be mutated (and the cost tables must not change)
    between calls on the same node.

    Args:
        node: A node exposing ``label``, ``children`` and ``subtree_size``.

    Returns:
        The sum of ``insert_cost`` over the node and all of its descendants.
    """
    # Consult the cache first; 0 means "not computed yet".
    if node.subtree_size:
        return node.subtree_size

    total = insert_cost[_node_kind(node)]
    for child in node.children:
        total += calc_tree_size(child)

    node.subtree_size = total
    return total


def score_calc(tree_dist, tree_size):
    """Map a tree edit distance to a score.

    Returns 100 for a zero distance, otherwise
    ``100 * discount_slope - 100 * tree_dist / tree_size`` clamped at 0.
    """
    if tree_dist == 0.:
        return 100
    return max(0, 100 * discount_slope - 100 * tree_dist / tree_size)


def simplify_with_timeout(expr):
    """Run ``simplify(expr)``, raising ``TimeoutError`` after ``simplify_time_limit`` seconds.

    The timeout wrapper is built at call time so that the module-level limit
    is actually honoured, including when it is changed after import.
    """
    guarded = timeout_decorator.timeout(
        simplify_time_limit, timeout_exception=TimeoutError)(simplify)
    return guarded(expr)


def time_simplify(expr):
    """Simplify ``expr``; on timeout return ``expr`` unchanged."""
    try:
        return simplify_with_timeout(expr)
    except TimeoutError:
        return expr


def equal_with_timeout(expr1, expr2):
    """Run ``expr1.equals(expr2)``, raising ``TimeoutError`` after ``equals_time_limit`` seconds."""

    def _equals():
        return expr1.equals(expr2)

    guarded = timeout_decorator.timeout(
        equals_time_limit, timeout_exception=TimeoutError)(_equals)
    return guarded()


def time_equal(expr1, expr2):
    """Compare two expressions with ``equals``; on timeout return ``False``."""
    try:
        return equal_with_timeout(expr1, expr2)
    except TimeoutError:
        return False
def sympy_to_tree(expr):
    """Convert a SymPy expression into a ``TreeNode`` tree.

    Every node label is ``"<category>_<name>"`` where the category is one of
    ``number``, ``symbol``, ``operator`` or ``function``; the children are the
    converted ``expr.args``.

    Supported types:
        - Numbers: Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity
        - Symbols: Symbol
        - Operators: Add, Mul, Pow
        - Functions: any instance of ``sympy.Function``

    Args:
        expr: The SymPy expression to convert.

    Returns:
        TreeNode: The root of the converted tree.

    Raises:
        ValueError: If the expression contains an unsupported SymPy type.
    """
    # Constants
    if isinstance(expr, (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)):
        return TreeNode(label="number_" + str(expr), children=[])

    # Symbols
    if isinstance(expr, Symbol):
        return TreeNode(label="symbol_" + str(expr), children=[])

    # Operators and functions share the same recursive shape.
    if isinstance(expr, (Add, Mul, Pow)):
        label = "operator_" + type(expr).__name__
    elif isinstance(expr, Function):
        label = "function_" + expr.func.__name__
    else:
        # The caller reports failures itself, so carry the offending
        # expression in the exception instead of printing it here.
        raise ValueError(
            f"Unsupported SymPy type: {type(expr)}, Expression: {expr}")

    children = [sympy_to_tree(arg) for arg in expr.args]
    return TreeNode(label=label, children=children)


class TreeNode:
    """A labelled node of an expression tree."""

    def __init__(self, label, children=None, node_type='other'):
        self.label = label
        self.children = children if children is not None else []
        self.node_type = node_type
        # 0 means "not computed yet"; code outside this class may fill it in.
        self.subtree_size = 0

    def get_children(self):
        """Return the list of child nodes."""
        return self.children

    def __str__(self):
        return self.label

    def __repr__(self):
        return f"{type(self).__name__}({self.label!r}, children={len(self.children)})"


def print_tree(node, indent=0):
    """Print a tree structure, one node per line, children indented below."""
    print(' ' * indent + f'└─ {node.label}')
    for child in node.children:
        print_tree(child, indent + 1)


class LaTeXError(Exception):
    """Raised when a LaTeX string cannot be converted to a SymPy expression."""

    def __init__(self, message="LaTeXError"):
        super().__init__(message)


class SymPyError(Exception):
    """Raised when SymPy simplification or tree construction fails."""

    def __init__(self, message="SymPyError"):
        super().__init__(message)
class TreeError(Exception):
    """Error type reserved for expression-tree failures."""

    def __init__(self, message="TreeError"):
        super().__init__(message)


class DistError(Exception):
    """Raised when the tree edit distance cannot be computed."""

    def __init__(self, message="DistanceError"):
        super().__init__(message)


def EED(answer_latex, test_latex, debug_mode=False):
    """Score a generated LaTeX expression against a reference expression.

    Both strings are converted to SymPy, simplified, and compared. If they
    are not symbolically equal, their expression trees are compared with the
    extended tree edit distance and the distance is mapped to a score.

    Args:
        answer_latex: LaTeX of the reference (ground-truth) expression.
        test_latex: LaTeX of the expression under test.
        debug_mode: If True, processing failures raise instead of returning
            the failure tuple.

    Returns:
        tuple: ``(score, relative_distance, answer_tree_size, distance)``.
            - score: similarity score from 0 to 100.
            - relative_distance: distance divided by the answer tree size.
            - answer_tree_size: size of the answer expression tree.
            - distance: raw distance between the two trees.
        Rejected or unprocessable input yields ``(0, -1, -1, -1)``.
        Byte-identical strings yield ``(100, 0.0, -1, 0)`` and symbolically
        equal expressions yield ``(100, 0.0, 0, 0)``.

    Raises:
        LaTeXError: LaTeX conversion failed (only if ``debug_mode``).
        SymPyError: Simplification or tree construction failed (only if
            ``debug_mode``).
        DistError: Distance calculation failed (only if ``debug_mode``).
    """
    if not test_latex:
        return 0, -1, -1, -1
    # Integrals and sums are not supported; both must give the same failure tuple.
    if '\\int' in test_latex or '\\int' in answer_latex:
        return 0, -1, -1, -1
    if '\\sum' in test_latex or '\\sum' in answer_latex:
        return 0, -1, -1, -1
    if answer_latex == test_latex:
        return 100, 0.0, -1, 0
    # A much longer test expression is assumed not to be equivalent.
    if len(test_latex) > 3 * len(answer_latex):
        return 0, -1, -1, -1

    try:
        answer_exp = master_convert(answer_latex)
        test_exp = master_convert(test_latex)
    except Exception as err:  # best effort: any conversion failure scores zero
        print(f"Failed to convert input latex to sympy expression,please check it")
        if debug_mode:
            raise LaTeXError(
                f"Fail to convert latex.\n GT:{answer_latex}\n GEN:{test_latex}") from err
        return 0, -1, -1, -1

    try:
        # posify treats symbols as positive for simplification; the returned
        # mapping restores the original symbols afterwards.
        answer_exp, rep1 = posify(answer_exp)
        answer_exp = time_simplify(answer_exp)

        test_exp, rep2 = posify(test_exp)
        test_exp = time_simplify(test_exp)

        answer_exp = answer_exp.subs(rep1)
        test_exp = test_exp.subs(rep2)

        zero_exp = time_simplify(expand(answer_exp - test_exp))

        if answer_exp == test_exp or zero_exp == 0:
            return 100, 0., 0, 0

        if time_equal(answer_exp, test_exp):
            return 100, 0., 0, 0
    except Exception as err:
        print("Something happened during simplification,returning zero")
        if debug_mode:
            raise SymPyError(
                f"Failed to simplify the sympy expression. "
                f"Expressions: answer_exp={answer_exp}, test_exp={test_exp}") from err
        return 0, -1, -1, -1

    try:
        tree_answer = sympy_to_tree(answer_exp)
        tree_test = sympy_to_tree(test_exp)
    except Exception as err:
        print("Failed to build expression tree,returning zero")
        if debug_mode:
            raise SymPyError(
                f"Failed to build the sympy expression tree.\n GT:{answer_exp}\n GEN:{test_exp}") from err
        return 0, -1, -1, -1

    # Compute the distance once, inside the guarded region.
    try:
        distance = ext_distance(
            tree_test,
            tree_answer,
            get_children=lambda x: x.get_children(),
            single_insert_cost=insert_func,
            insert_cost=insert_tree_func,
            single_remove_cost=remove_func,
            remove_cost=remove_tree_func,
            update_cost=update_func,
        )
    except Exception as err:
        print("Failed to calculate distance")
        if debug_mode:
            raise DistError(
                f"Failed to calculate the distance between trees.\n GT:{answer_latex}\n GEN:{test_latex}") from err
        return 0, -1, calc_tree_size(tree_answer), -1

    tree_size = calc_tree_size(tree_answer)
    rel_distance = distance / tree_size
    score = score_calc(distance, tree_size)

    return score, rel_distance, tree_size, distance
import collections
import numpy as np
from numpy import zeros, ones


class Node(object):
    """A simple labelled tree node."""

    def __init__(self, label, children=None):
        self.label = label
        self.children = children or list()

    @staticmethod
    def get_children(node):
        """Return the children of ``node``."""
        return node.children

    @staticmethod
    def get_label(node):
        """Return the label of ``node``."""
        return node.label

    def addkid(self, node, before=False):
        """Add ``node`` as the first (``before=True``) or last child; return self."""
        if before:
            self.children.insert(0, node)
        else:
            self.children.append(node)
        return self

    def get(self, label):
        """Return the first node labelled ``label`` in this subtree, or None.

        The search is depth-first, visiting children in order.
        """
        if self.label == label:
            return self
        for child in self.children:
            found = child.get(label)
            if found is not None:
                return found
        return None


class AnnotatedTree(object):
    """Post-order annotation of a tree.

    Attributes:
        nodes: post-order enumeration of the nodes.
        ids: matching list of pre-order visit ids.
        lmds: post-order index of the left-most leaf descendant of each node.
        keyroots: sorted post-order indices of the keyroots.
    """

    def __init__(self, root, get_children):
        self.get_children = get_children

        self.root = root
        self.nodes = list()  # a post-order enumeration of the nodes in the tree
        self.ids = list()  # a matching list of ids
        self.lmds = list()  # left most descendants of each node
        self.keyroots = None  # the keyroots in the original paper

        # First pass: depth-first walk recording each node's ancestor ids.
        stack = [(root, collections.deque())]
        pstack = list()
        j = 0
        while stack:
            n, anc = stack.pop()
            nid = j
            for c in self.get_children(n):
                a = collections.deque(anc)
                a.appendleft(nid)
                stack.append((c, a))
            pstack.append(((n, nid), anc))
            j += 1

        # Second pass: popping pstack yields the nodes in post-order.
        lmds = dict()
        keyroots = dict()
        i = 0
        while pstack:
            (n, nid), anc = pstack.pop()
            self.nodes.append(n)
            self.ids.append(nid)
            if not self.get_children(n):
                lmd = i
                # Propagate this leaf upwards until an ancestor that already
                # has a left-most descendant is met.
                for a in anc:
                    if a not in lmds:
                        lmds[a] = i
                    else:
                        break
            else:
                # An internal node's entry was recorded when its left-most
                # leaf was visited; no debugger hook is left in library code.
                lmd = lmds[nid]
            self.lmds.append(lmd)
            keyroots[lmd] = i  # later (higher) nodes overwrite earlier ones
            i += 1
        self.keyroots = sorted(keyroots.values())


def ext_distance(A, B, get_children, single_insert_cost, insert_cost,
                 single_remove_cost, remove_cost, update_cost):
    '''Compute the extended tree edit distance between trees A and B.

    Besides the single-node insert/remove/update operations, whole subtrees
    may be inserted or removed in one step at their own cost.

    Args:
        A(Node): Root node of tree 1
        B(Node): Root node of tree 2
        get_children(Func): returns the children of a node
        single_insert_cost(Func): cost of inserting a single node
        insert_cost(Func): cost of inserting a subtree
        single_remove_cost(Func): cost of removing a single node
        remove_cost(Func): cost of removing a subtree
        update_cost(Func): cost of updating a node of A into a node of B

    Return:
        Distance(float): the tree editing distance
    '''
    A, B = AnnotatedTree(A, get_children), AnnotatedTree(B, get_children)
    size_a = len(A.nodes)
    size_b = len(B.nodes)
    treedists = zeros((size_a, size_b), float)
    # Forest-distance table shared by all keyroot pairs; each call writes
    # every cell it later reads.
    fd = 1000 * ones((size_a + 1, size_b + 1), float)

    def treedist(x, y):
        Al = A.lmds
        Bl = B.lmds
        An = A.nodes
        Bn = B.nodes

        fd[Al[x]][Bl[y]] = 0
        for i in range(Al[x], x + 1):
            node = An[i]
            fd[i + 1][Bl[y]] = fd[Al[i]][Bl[y]] + remove_cost(node)

        for j in range(Bl[y], y + 1):
            node = Bn[j]
            fd[Al[x]][j + 1] = fd[Al[x]][Bl[j]] + insert_cost(node)

        for i in range(Al[x], x + 1):
            for j in range(Bl[y], y + 1):
                node1 = An[i]
                node2 = Bn[j]
                best = min(
                    fd[i][j + 1] + single_remove_cost(node1),
                    fd[i + 1][j] + single_insert_cost(node2),
                    fd[Al[i]][j + 1] + remove_cost(node1),
                    fd[i + 1][Bl[j]] + insert_cost(node2),
                )

                if Al[x] == Al[i] and Bl[y] == Bl[j]:
                    # Both forests are whole trees: record the tree distance.
                    treedists[i][j] = min(best, fd[i][j] + update_cost(node1, node2))
                    fd[i + 1][j + 1] = treedists[i][j]
                else:
                    fd[i + 1][j + 1] = min(best, fd[Al[i]][Bl[j]] + treedists[i][j])

    for x in A.keyroots:
        for y in B.keyroots:
            treedist(x, y)

    return treedists[-1][-1]
balanced + Args: + s(str): the input string + Return: + bool: True if the brackets are balanced, False otherwise + """ + stack = [] + bracket_pairs = {')': '(', ']': '[', '}': '{'} + + for char in s: + if char in bracket_pairs.values(): + stack.append(char) + elif char in bracket_pairs: + if not stack or stack[-1] != bracket_pairs[char]: + return False + stack.pop() + return len(stack) == 0 + + + +def remove_non_ascii(text): + return text.encode("ascii", errors="ignore").decode() + +import re +def extract_bracket_content(s:str,bracket_position:int) -> str: + start_idx=bracket_position + + stack = [] + content = [] + escaped = False + brace_start=start_idx+1 + brace_depth = 0 + for i in range(brace_start, len(s)): + char = s[i] + if escaped: + content.append(char) + escaped = False + continue + if char == '\\': + escaped = True + content.append(char) + continue + if char == '{': + brace_depth += 1 + content.append(char) + elif char == '}': + if brace_depth == 0: + return ''.join(content),i + brace_depth -= 1 + content.append(char) + else: + content.append(char) + + return None,-1 +def find_first_unescaped_brace(s: str) -> int: + escaped = False + for i, c in enumerate(s): + if c == '\\' and not escaped: + escaped = True + continue + if c == '{' and not escaped: + return i + escaped = False + return -1 + +def extract_command(s: str, brace_pos: int) -> str | None: + """extract the command name from a bracket""" + i = brace_pos - 1 + parameter_mode=False + while i >= 0: + if not parameter_mode and s[i] in ('^','_'): + return s[i] + if not parameter_mode and not s[i] in (' ','\t',']','['): + break + if s[i]==']': + parameter_mode=True + if s[i]=='[' and parameter_mode: + parameter_mode=False + i -= 1 + + # Start point + if i < 0 or s[i] == '\\': + return None + + # Extract command name + command_end = i + i -= 1 + while i >= 0 and s[i].isalpha(): + i -= 1 + if i<-1 or s[i]!='\\': + return None + return s[i+1:command_end+1] + + +def 
remove_command(s,command,keep_inside=False): + def remove_command(s, command, keep_inside=False): + """ + Removes all occurrences of a specified LaTeX-style command from a string. + + This function searches for a given command in the input string `s` and removes it, + along with its associated content enclosed in curly braces `{}`. If `keep_inside` + is set to `True`, the content inside the braces is preserved, and only the command + itself is removed. The function handles nested braces correctly. + + Args: + s (str): The input string from which the command should be removed. + command (str): The LaTeX-style command to be removed (e.g., "\\textbf"). + keep_inside (bool, optional): If `True`, preserves the content inside the braces + while removing the command. Defaults to `False`. + + Returns: + str: The modified string with the specified command removed. + + Examples: + >>> remove_command("This is \\textbf{bold text}.", "\\textbf") + 'This is bold text.' + + >>> remove_command("This is \\textbf{bold text}.", "\\textbf", keep_inside=True) + 'This is bold text.' + + >>> remove_command("Nested \\textbf{bold \\textit{italic text}} example.", "\\textbf") + 'Nested bold \\textit{italic text} example.' 
+ """ + pos=s.find(command) + if pos<0: + return s + end_index=pos+len(command) + level=0 + escaped=False + #print(end_index,s[end_index]) + if end_index < len(s) and s[end_index] == "{": + while end_index str | None: + """ Find the first brace """ + brace_pos = find_first_unescaped_brace(s) + if brace_pos == -1: + return None + return extract_command(s, brace_pos) +def remove_overall_brace(s:str) -> str: + """ + Remove the overall {xxx} brace + """ + pos=find_first_unescaped_brace(s) + if pos==-1: + return s,0 + command=get_first_brace_command(s) + if not command: + + content,final=extract_bracket_content(s,pos) + #print(s[final]) + if final==len(s) or not '}' in s[final+1:]: + return content,1 + return s,0 + + +def exp_frac(s): + + def exp_frac_single(s): + position=s.find("^\\frac")+1 + if position == 0: + return s + level=0 + cnt=0 + idx=position + while idx>> convert_vec_syntax(r"\vec x + \vec\alpha + \vec\Gamma") + '\\vec{x} + \\vec{\\alpha} + \\vec{\\Gamma}' + """ + + pattern = r'\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)' + replacement = r'\\vec{\2}' + return re.sub(pattern, replacement, text) + +def remove_outer_braces(tex_str): + """ + convert {base}_{subscript} to base_{subscript} + Example: + {a}_{xyz} → a_{xyz} + {\theta}_{0} → \theta_{0} + """ + pattern = r'\{(\\(?:[a-zA-Z]+|.)|[^{}])+\}_\{([^}]+)\}' + return re.sub(pattern, r'\1_{\2}', tex_str) + +def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str: + """ + Extract the content after the last occurrence of specific mathematical comparison or assignment operators. + + :param strip_whitespace: If True, removes leading and trailing whitespace from the extracted content. Defaults to True. + (e.g., '=', '\\approx', '\\ge', '\\le', etc.) within the input string `s`. It then extracts + and returns the content that follows the operator. If no operator is found, the entire string + is returned. Optionally, leading and trailing whitespace can be stripped from the extracted content. 
# One comparison/assignment operator.  The lookahead keeps "\le" and "\ge"
# from matching the start of longer commands such as "\left" or "\leftarrow".
_COMPARISON_OPERATOR_RE = re.compile(
    r'=|<|>|\\(?:approx|geq|leq|ge|le)(?![A-Za-z])')


def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str:
    """
    Extract the content after the last comparison or assignment operator.

    The operators are '=', '\\approx', '\\ge', '\\le', '\\geq', '\\leq', '<'
    and '>'.  The rightmost occurrence of any of them wins, and everything
    after the complete operator is returned.  If no operator is found, the
    entire string is returned.

    Args:
        s (str): The input string to process.
        strip_whitespace (bool): Whether to strip leading and trailing
            whitespace from the extracted content. Defaults to True.

    Returns:
        str: The content after the last matching operator, or the entire
        string if no operator is found.
    """
    content = s
    last_match = None
    for last_match in _COMPARISON_OPERATOR_RE.finditer(s):
        pass
    if last_match is not None:
        # Slice after the whole operator, not one character past its start.
        content = s[last_match.end():]
    if strip_whitespace:
        return content.strip()
    return content


def first_pre_process(s, extrac_box=True):
    """
    Perform the first stage of LaTeX string preprocessing.

    Escaped braces ``\\{`` and ``\\}`` become parentheses.  If the brackets
    are then unbalanced, the string is returned as is.  Otherwise the
    ``\\boxed`` command is removed (when ``extrac_box`` is True), overall
    braces are stripped, anything from the first ``\\quad`` on is dropped,
    and only the content after the last equality or comparison operator is
    kept.

    Args:
        s (str): The input LaTeX string to preprocess.
        extrac_box (bool): If True, removes the '\\boxed' command while
            keeping its content. Defaults to True.

    Returns:
        str: The preprocessed LaTeX string.
    """

    def strip_overall_braces(text):
        # Peel at most 10 layers of overall braces.
        exist_overall_brace = True
        cnt = 0
        while exist_overall_brace and cnt < 10:
            text, exist_overall_brace = remove_overall_brace(text)
            cnt += 1
        return text

    s = s.replace('\\{', '(')
    s = s.replace('\\}', ')')
    if not brackets_balanced(s):
        return s
    if extrac_box:
        boxed_content = remove_command(s, '\\boxed', keep_inside=True)
    else:
        boxed_content = s
    boxed_content = strip_overall_braces(boxed_content)

    if '\\quad' in boxed_content:
        boxed_content = boxed_content.split('\\quad')[0]

    last_equal_content = extract_last_equal_content(boxed_content)
    return strip_overall_braces(last_equal_content)
def second_pre_process(s):
    """
    Perform the second stage of LaTeX string preprocessing.

    Standardizes the string for the converter: some commands are removed
    through ``remove_command`` (with or without keeping their content), a
    list of literal fragments is deleted, a list of literal rewrites is
    applied in order, and finally the fraction / vector / exponent rewrite
    helpers run.  One trailing '.' is dropped.

    Args:
        s (str): The input LaTeX string to preprocess.

    Returns:
        str: The preprocessed LaTeX string.
    """
    # Removed with keep_inside=False.
    dropped_commands = ('\\begin', '\\end')
    # Removed with keep_inside=True.
    unwrapped_commands = (
        '\\text',
        '\\mathbf',
        '\\mathrm',
        '\\pmb',
        '\\hat',
        '\\overline',
        '\\boldsymbol',
    )
    # Literal fragments deleted wherever they occur (order matters).
    deleted_fragments = (
        '\\,', '$', ',', '`', 'latex', '\\left', '\\right', '\\text', '\\mathrm',
        '\\Bigr', '\\Bigl', '\n', '\\]', '\\[',
        '\\Big', '\\bigl', '\\bigr', '\\biggl', '\\biggr', '\\displaystyle',
        '\\boldsymbol', '\\infty',
    )
    # (old, new) literal rewrites, applied in this exact order.
    rewrites = (
        ('\\operatorname{asin}', '\\asin'),
        ('\\operatorname{sech}', '\\sech'),
        ('\\operatorname{acos}', '\\acos'),
        ('\\operatorname{sinh}', '\\sinh'),
        ('\\dfrac', '\\frac'),
        ('\\tfrac', '\\frac'),
        ('\\Exp', '\\exp'),
        ('\\times', '\\bar{times}'),
        ('\\partial', '\\bar{partial}'),
        ('\\perp', '\\bar{perp}'),
        ('\\epsilon', '\\varepsilon'),
        ('\\varOmega', '\\Omega'),
        ('I', '\\bar{I}'),
        ('_e', '_{e}'),
        ('e_', '\\bar{e}_'),
        ('E_', '\\bar{E}_'),
        ('\\pm', '+'),
        ('\\mp', '-'),
        ('{+}', '{p}'),
        ('{-}', '{m}'),
        ('_+', '_p'),
        ('_-', '_m'),
    )

    for name in dropped_commands:
        s = remove_command(s, name, keep_inside=False)
    for name in unwrapped_commands:
        s = remove_command(s, name, keep_inside=True)
    for fragment in deleted_fragments:
        s = s.replace(fragment, '')
    for old, new in rewrites:
        s = s.replace(old, new)

    s = convert_latex_fractions(s)
    s = bar_inside_vec(s)
    s = vec_lower_idx(s)
    s = convert_vec_syntax(s)
    s = exp_frac(s)

    if s and s[-1] == '.':
        s = s[:-1]
    return s
class MyConfig:
    r"""Conversion settings passed to ``latex2sympy`` as ``conversion_config``.

    Attributes:
        interpret_as_mixed_fractions (bool): Whether to interpret
            2 \frac{1}{2} as 2/2 or 2 + 1/2
        interpret_simple_eq_as_assignment (bool): Whether to interpret simple
            equations as assignments k=1 -> 1
        interpret_contains_as_eq (bool): Whether to interpret contains as
            equality x \in {1,2,3} -> x = {1,2,3}
        lowercase_symbols (bool): Whether to lowercase all symbols
    """

    interpret_as_mixed_fractions: bool = False
    interpret_simple_eq_as_assignment: bool = False
    interpret_contains_as_eq: bool = True
    lowercase_symbols: bool = False


class MyNormalization:
    """Configuration for latex normalization.

    Each field controls a group of related normalizations:
    - basic_latex: Basic latex command replacements (mathrm, displaystyle, etc.)
    - units: Remove units and their variations
    - malformed_operators: Fix malformed operators (sqrt, frac, etc.)
    - nits: Small formatting fixes (spaces, dots, etc.)
    - boxed: Extract content from boxed environments
    - equations: Handle equation splitting and approximations (deprecated)
    """

    basic_latex: bool = True
    units: bool = False
    malformed_operators: bool = True
    nits: bool = True
    boxed = "all"
    equations: bool = False


def master_convert(s):
    """
    The only function needed to convert a LaTeX string into a SymPy expression.

    The string goes through ``first_pre_process`` and ``second_pre_process``
    and is then parsed by ``latex2sympy`` with ``MyNormalization`` and
    ``MyConfig``.

    Args:
        s (str): The input LaTeX string. It should be a valid LaTeX mathematical expression,
            such as equations, fractions, or symbols, and must have balanced brackets.

    Returns:
        Sym (Sympy Expression): A SymPy expression representing the mathematical content
        of the input string.

    Example:
        >>> master_convert("\\\\frac{1}{2} + x")
        1/2 + x
    """
    preprocessed_stage1 = first_pre_process(s)
    preprocessed_stage2 = second_pre_process(preprocessed_stage1)

    Sym = latex2sympy(
        preprocessed_stage2,
        normalization_config=MyNormalization(),
        conversion_config=MyConfig(),
    )
    return Sym
import sys
import os


# Make the sibling EED module importable from this config directory.
sys.path.append(os.path.dirname(__file__))

from EED import EED
from opencompass.datasets import PhyBenchDataset
from opencompass.openicl import BaseEvaluator
from opencompass.registry import ICL_EVALUATOR


@ICL_EVALUATOR.register_module()
class MathEEDEvaluator(BaseEvaluator):
    """Evaluator that reports the mean EED score of predictions against references."""

    def score(self, predictions, references):
        """Return ``{'accuracy': mean EED score}`` over all prediction/reference pairs.

        A length mismatch is reported as an ``error`` entry instead of being
        silently truncated by ``zip``; an empty input scores 0.0 instead of
        dividing by zero.
        """
        if len(predictions) != len(references):
            return {'error': 'predictions and references have different length'}
        if not predictions:
            return {'accuracy': 0.0}

        # EED takes the reference first and returns the score as element 0.
        scores = [EED(ref, pred)[0] for pred, ref in zip(predictions, references)]
        return {'accuracy': sum(scores) / len(scores)}


# NOTE(review): `type` values below are plain strings ('plain', 'zero_shot',
# 'gen'); confirm the framework resolves these names.
phybench_datasets = [
    dict(
        abbr='phybench-eed',
        type=PhyBenchDataset,
        path='opencompass/PHYBench',
        reader_cfg=dict(
            input_columns=['input'],
            output_column='target',
        ),
        infer_cfg=dict(
            prompt_template=dict(
                type='plain',
                template='Solve the following physics problem and return only the final result as a clean LaTeX expression. No explanation. No text.\n\nQuestion: {{input}}\nAnswer: '
            ),
            retriever=dict(type='zero_shot'),
            inferencer=dict(type='gen', max_out_len=512)
        ),
        eval_cfg=dict(
            evaluator=dict(type=MathEEDEvaluator)
        )
    )
]
import os.path as osp
import json
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path


@LOAD_DATASET.register_module()
class PhyBenchDataset(BaseDataset):
    """Loader for the PHYBench question file."""

    @staticmethod
    def load(path: str, name: str = None, **kwargs):
        """Read ``PHYBench-fullques_v1.json`` under ``path`` into a ``Dataset``.

        Each record's ``content`` becomes the ``input`` column and its
        ``answer`` becomes the ``target`` column.  ``name`` and ``kwargs``
        are accepted but not used.
        """
        data_dir = get_data_path(path)
        json_file = osp.join(data_dir, 'PHYBench-fullques_v1.json')

        with open(json_file, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        columns = {'input': [], 'target': []}
        for record in records:
            columns['input'].append(record['content'])
        for record in records:
            columns['target'].append(record['answer'])

        return Dataset.from_dict(columns)