This commit is contained in:
suencgo 2025-05-29 14:20:24 +08:00 committed by GitHub
commit 6c581aa7f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1123 additions and 0 deletions

View File

@ -1113,3 +1113,9 @@
paper: https://arxiv.org/pdf/2203.14371
configpath: opencompass/configs/datasets/medmcqa/medmcqa_gen.py
configpath_llmjudge: opencompass/configs/datasets/medmcqa/medmcqa_llmjudge_gen.py
- phybench:
name: PHYBench
category: Science /Physics
paper: https://arxiv.org/abs/2504.16074
configpath: opencompass/configs/datasets/PHYBench/phybench_gen.py
configpath_llmjudge: ''

View File

@ -0,0 +1,361 @@
from sympy import *
from sympy.core.function import AppliedUndef
from sympy.core.numbers import Pi, Exp1,I,Infinity,NegativeInfinity
import numpy as np
import timeout_decorator
from extended_zss import ext_distance
from latex_pre_process import *
from sympy.simplify import *
"""
Guide:
You only need to use EED and install the following packages:
- sympy
- numpy
- latex2sympy2_extended
- timeout_decorator
"""
"""
There are four main categories:
Constants: such as integers, decimals, or mathematical constants like π and e.
Variables: letters like x, y, z, or specified terms in problems (e.g., ħ, c, G).
Functions: sine, cosine, exponential, logarithm, etc.
Operators: basic binary operations including addition, multiplication, and exponentiation.
"""
# The costs can be modified if you think their values are different
insert_cost={"number":1,"symbol":1,"operator":1,"function":1}
delete_cost={"number":1,"symbol":1,"operator":1,"function":1}
update_cost={"number":1,"symbol":1,"operator":1,"function":1}
change_type_cost=1 #the cost of an update between different types,can be set to higher
bar_size=5 # the minimum size of triggering cluster discount
discount_slope=0.6 #discount
simplify_time_limit=30 #set the time limit of simplify
equals_time_limit=10 #set the time limit of equals
def update_func(x,y):
if x.label==y.label:
return 0
elif x.label.split("_")[0]==y.label.split("_")[0]:
return update_cost[x.label.split("_")[0]]
return change_type_cost
def remove_func(x):
return delete_cost[x.label.split("_")[0]]
def remove_tree_func(x):
if not x.children:
return remove_func(x)
s=calc_tree_size(x)
return min(s,discount_slope*(s-bar_size)+bar_size)
def insert_func(x):
return insert_cost[x.label.split("_")[0]]
def insert_tree_func(x):
return remove_tree_func(x)
def calc_tree_size(node):
"""
Calculate the size of a subtree based on its total insertion cost.
The function computes the size of a subtree by summing up the insertion
costs of the current node and all its descendant nodes. If the subtree
size has already been calculated and stored in `node.subtree_size`, it
returns the cached value to avoid redundant computation.
Args:
node (Node): The root node of the subtree for which the size is to
be calculated
Returns:
int: The total size of the subtree, calculated as the sum of the
insertion costs of the current node and all its descendants.
Notes:
- The `insert_cost` dictionary is assumed to be globally defined
and maps node labels to their respective insertion costs.
- The function modifies the `subtree_size` attribute of the input
node to store the calculated subtree size for future use.
"""
"""The size of a subtree equals to its total insertion cost"""
total = insert_cost[node.label.split("_")[0]]
if node.children and node.subtree_size !=0:
return node.subtree_size
for child in node.children:
total += calc_tree_size(child)
node.subtree_size=total
return total
"""
Scoring function from relative distance
"""
def score_calc(tree_dist,tree_size):
if tree_dist==0.:
return 100
return max(0,100*discount_slope-100*tree_dist/tree_size)
@timeout_decorator.timeout(30, timeout_exception=TimeoutError)
def simplify_with_timeout(expr):
return simplify(expr)
def time_simplify(expr):
try:
result=simplify_with_timeout(expr)
return result
except TimeoutError:
return expr
@timeout_decorator.timeout(10, timeout_exception=TimeoutError)
def equal_with_timeout(expr1,expr2):
return expr1.equals(expr2)
def time_equal(expr1,expr2):
try:
result=equal_with_timeout(expr1,expr2)
return result
except TimeoutError:
return False
def sympy_to_tree(expr):
"""
Convert a SymPy expression into a tree structure.
This function takes a SymPy expression and recursively converts it into a tree
representation using `TreeNode` objects. Each node in the tree is labeled based
on the type of the SymPy expression (e.g., number, symbol, operator, or function),
and its children represent the arguments of the expression.
Args:
expr (sympy.Basic): The SymPy expression to be converted.
Returns:
TreeNode: The root node of the tree representation of the SymPy expression.
Raises:
ValueError: If the SymPy expression contains an unsupported type.
Supported Types:
- Numbers: Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity
- Symbols: Symbol
- Binary Operators: Add, Mul, Pow
- Functions: Any subclass of `sympy.Function`
Example:
>>> from sympy import symbols, sin, pi
>>> x, y = symbols('x y')
>>> expr = x + y * sin(pi)
>>> tree = sympy_to_tree(expr)
>>> print(tree)
"""
#print(expr)
"""Convert the sympy expression to a tree"""
# Symbols and constants
if isinstance(expr, (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)):
return TreeNode(label="number_"+str(expr), children=[])
elif isinstance(expr, (Symbol,)):
return TreeNode(label="symbol_"+str(expr),children=[])
# Binary operators
elif isinstance(expr, (Add, Mul, Pow)):
op_name = type(expr).__name__
children = [sympy_to_tree(arg) for arg in expr.args]
return TreeNode(label="operator_"+op_name, children=children)
elif isinstance(expr, (Function)):
# Functions
func_name = expr.func.__name__
children = [sympy_to_tree(arg) for arg in expr.args]
return TreeNode(label="function_"+func_name, children=children)
else:
#print(expr)
print(f"Unsupported Sympy type: {type(expr).__name__}, Expression: {expr}")
raise ValueError(f"Unsupported SymPy type: {type(expr)}")
class TreeNode:
def __init__(self, label, children=None,node_type='other'):
self.label = label
self.children = children if children is not None else []
self.node_type=node_type
self.subtree_size=0
def get_children(self):
return self.children
def __str__(self):
return self.label
def print_tree(node, indent=0):
"""Print a tree structure"""
print(' ' * indent + f'└─ {node.label}')
for child in node.children:
print_tree(child, indent + 1)
import timeout_decorator
class LaTeXError(Exception):
def __init__(self, message="LaTeXError"):
super().__init__(message)
class SymPyError(Exception):
def __init__(self, message="SymPyError"):
super().__init__(message)
class TreeError(Exception):
def __init__(self, message="TreeError"):
super().__init__(message)
class DistError(Exception):
def __init__(self, message="DistanceError"):
super().__init__(message)
def EED(answer_latex,test_latex,debug_mode=False):
"""
Computes the similarity score and distance metrics between two LaTeX expressions.
This function evaluates the equivalence of two mathematical expressions represented
in LaTeX format. It uses symbolic computation and tree-based distance metrics to
calculate a similarity score and other related metrics.
tuple: A tuple containing the following elements:
- score (float): The similarity score between the two expressions (0 to 100).
- relative_distance (float): The normalized distance between the two expressions.
- answer_tree_size (int): The size of the expression tree for the answer.
- distance (float): The raw distance between the two expression trees.
Notes:
- If either input contains unsupported LaTeX constructs (e.g., integrals or sums),
the function returns default values indicating failure.
- If the test expression is significantly longer than the answer expression,
the function assumes they are not equivalent.
- The function uses symbolic simplification and tree-based distance metrics to
evaluate equivalence.
- In case of errors during processing, the function returns default values unless
`debug_mode` is enabled, in which case it raises specific exceptions.
Exceptions:
- LaTeXError: Raised when LaTeX conversion to symbolic expressions fails (if `debug_mode` is True).
- SymPyError: Raised when symbolic simplification or tree construction fails (if `debug_mode` is True).
- DistError: Raised when distance calculation fails (if `debug_mode` is True).
Args:
answer_latex: the latex expression of answer expression
test_latex: the latex expression of test expression
debug_mode: whether it raise errors or just skip it
Returns:
tuple: A tuple containing the following elements:
- score (float): The similarity score between the two expressions (0 to 100).
- relative_distance (float): The normalized distance between the two expressions.
- answer_tree_size (int): The size of the expression tree for the answer.
- distance (float): The raw distance between the two expression trees.
"""
if not test_latex:
return 0,-1,-1,-1
if '\\int' in test_latex or '\\int' in answer_latex:
return 0,-1,-1,-1
if '\\sum' in test_latex or '\\sum' in answer_latex:
return 0,-1,-1,1
if answer_latex==test_latex:
return 100,0.0,-1,0
if len(test_latex)>3*len(answer_latex):
return 0,-1,-1,-1
try:
answer_exp=master_convert(answer_latex)
test_exp=master_convert(test_latex)
except:
print(f"Failed to convert input latex to sympy expression,please check it")
if debug_mode:
raise LaTeXError(f"Fail to convert latex.\n GT:{answer_latex}\n GEN:{test_latex}")
return 0,-1,-1,-1
try:
answer_exp,rep1=posify(answer_exp)
answer_exp=time_simplify(answer_exp)
test_exp,rep2=posify(test_exp)
test_exp=time_simplify(test_exp)
answer_exp=answer_exp.subs(rep1)
test_exp=test_exp.subs(rep2)
zero_exp=time_simplify(expand(answer_exp-test_exp))
if answer_exp==test_exp or zero_exp==0:
return 100,0.,0,0
if time_equal(answer_exp,test_exp):
return 100,0.,0,0
except:
print("Something happened during simplification,returning zero")
if debug_mode:
raise SymPyError(f"Failed to simplify the sympy expression. Expressions: answer_exp={answer_exp}, test_exp={test_exp}")
return 0,-1,-1,-1
try:
tree_answer=sympy_to_tree(answer_exp)
tree_test=sympy_to_tree(test_exp)
except:
print("Failed to build expression tree,returning zero")
if debug_mode:
raise SymPyError(f"Failed to build the sympy expression tree.\n GT:{answer_exp}\n GEN:{test_exp}")
return 0,-1,-1,-1
distance=ext_distance(
tree_test,
tree_answer,
get_children=lambda x:x.get_children(),
single_insert_cost=insert_func,
insert_cost=insert_tree_func,
single_remove_cost=remove_func,
remove_cost=remove_tree_func,
update_cost=update_func)
try:
distance=ext_distance(
tree_test,
tree_answer,
get_children=lambda x:x.get_children(),
single_insert_cost=insert_func,
insert_cost=insert_tree_func,
single_remove_cost=remove_func,
remove_cost=remove_tree_func,
update_cost=update_func
)
except:
print("Failed to calculate distance")
if debug_mode:
raise DistError(f"Failed to calculate the distance between trees.\n GT:{answer_latex}\n GEN:{test_latex}")
return 0,-1,calc_tree_size(tree_answer),-1
tree_size=calc_tree_size(tree_answer)
distance_number=distance
rel_distance=distance/tree_size
score=score_calc(distance_number,tree_size)
return score,rel_distance,tree_size,distance_number

View File

@ -0,0 +1,160 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#Original Authors: Tim Henderson and Steve Johnson
#Email: tim.tadh@gmail.com, steve@steveasleep.com
#For licensing see the LICENSE file in the top level directory.
# This is a modified version of zss package.
import collections
import numpy as np
from numpy import zeros,ones
class Node(object):
def __init__(self, label, children=None):
self.label = label
self.children = children or list()
@staticmethod
def get_children(node):
return node.children
@staticmethod
def get_label(node):
return node.label
def addkid(self, node, before=False):
if before: self.children.insert(0, node)
else: self.children.append(node)
return self
def get(self, label):
if self.label == label: return self
for c in self.children:
if label in c: return c.get(label)
class AnnotatedTree(object):
def __init__(self, root, get_children):
self.get_children = get_children
self.root = root
self.nodes = list() # a post-order enumeration of the nodes in the tree
self.ids = list() # a matching list of ids
self.lmds = list() # left most descendents of each nodes
self.keyroots = None
# the keyroots in the original paper
stack = list()
pstack = list()
stack.append((root, collections.deque()))
j = 0
while len(stack) > 0:
n, anc = stack.pop()
nid = j
for c in self.get_children(n):
a = collections.deque(anc)
a.appendleft(nid)
stack.append((c, a))
pstack.append(((n, nid), anc))
j += 1
lmds = dict()
keyroots = dict()
i = 0
while len(pstack) > 0:
(n, nid), anc = pstack.pop()
self.nodes.append(n)
self.ids.append(nid)
if not self.get_children(n):
lmd = i
for a in anc:
if a not in lmds: lmds[a] = i
else: break
else:
try: lmd = lmds[nid]
except:
import pdb
pdb.set_trace()
self.lmds.append(lmd)
keyroots[lmd] = i
i += 1
self.keyroots = sorted(keyroots.values())
def ext_distance(A, B, get_children, single_insert_cost,insert_cost,single_remove_cost, remove_cost, update_cost):
'''Computes the extended tree edit distance between trees A and B with extended-zss algorithm
Args:
A(Node): Root node of tree 1
B(Node): Root node of tree 2
get_children(Func): the get_children method of tree
single_insert_cost(Func): cost of inserting single node
insert_cost(Func): cost of inserting a subtree
update_cost(Func): cost of updating A to B
Return:
Distance(float):the tree editing distance
'''
A, B = AnnotatedTree(A, get_children), AnnotatedTree(B, get_children)
size_a = len(A.nodes)
size_b = len(B.nodes)
treedists = zeros((size_a, size_b), float)
fd=1000*ones((size_a+1,size_b+1),float)
operations = [[[] for _ in range(size_b)] for _ in range(size_a)]
def treedist(x, y):
Al = A.lmds
Bl = B.lmds
An = A.nodes
Bn = B.nodes
m = size_a
n = size_b
fd[Al[x]][Bl[y]]=0
for i in range(Al[x], x+1):
node = An[i]
fd[i+1][Bl[y]] = fd[Al[i]][Bl[y]] + remove_cost(node)
for j in range(Bl[y], y+1):
node = Bn[j]
fd[Al[x]][j+1] = fd[Al[x]][Bl[j]] + insert_cost(node)
for i in range(Al[x], x+1):
for j in range(Bl[y], y+1):
node1 = An[i]
node2 = Bn[j]
costs = [fd[i][j+1] + single_remove_cost(node1),
fd[i+1][j] + single_insert_cost(node2),
fd[Al[i]][j+1]+ remove_cost(node1),
fd[i+1][Bl[j]]+ insert_cost(node2)]
m=min(costs)
if Al[x] == Al[i] and Bl[y] == Bl[j]:
treedists[i][j]=min(m,fd[i][j]+update_cost(node1,node2))
fd[i+1][j+1]=treedists[i][j]
else:
fd[i+1][j+1]=min(m,fd[Al[i]][Bl[j]]+treedists[i][j])
for x in A.keyroots:
for y in B.keyroots:
treedist(x, y)
return treedists[-1][-1]

View File

@ -0,0 +1,523 @@
#This file is used to pre-process input latex expressions
#You only need a "master_convert()"
from latex2sympy2_extended import *
from sympy import simplify
def brackets_balanced(s: str) -> bool:
"""
Check if the brackets in a LaTeX string are balanced
Args:
s(str): the input string
Return:
bool: True if the brackets are balanced, False otherwise
"""
stack = []
bracket_pairs = {')': '(', ']': '[', '}': '{'}
for char in s:
if char in bracket_pairs.values():
stack.append(char)
elif char in bracket_pairs:
if not stack or stack[-1] != bracket_pairs[char]:
return False
stack.pop()
return len(stack) == 0
def remove_non_ascii(text):
return text.encode("ascii", errors="ignore").decode()
import re
def extract_bracket_content(s:str,bracket_position:int) -> str:
start_idx=bracket_position
stack = []
content = []
escaped = False
brace_start=start_idx+1
brace_depth = 0
for i in range(brace_start, len(s)):
char = s[i]
if escaped:
content.append(char)
escaped = False
continue
if char == '\\':
escaped = True
content.append(char)
continue
if char == '{':
brace_depth += 1
content.append(char)
elif char == '}':
if brace_depth == 0:
return ''.join(content),i
brace_depth -= 1
content.append(char)
else:
content.append(char)
return None,-1
def find_first_unescaped_brace(s: str) -> int:
escaped = False
for i, c in enumerate(s):
if c == '\\' and not escaped:
escaped = True
continue
if c == '{' and not escaped:
return i
escaped = False
return -1
def extract_command(s: str, brace_pos: int) -> str | None:
"""extract the command name from a bracket"""
i = brace_pos - 1
parameter_mode=False
while i >= 0:
if not parameter_mode and s[i] in ('^','_'):
return s[i]
if not parameter_mode and not s[i] in (' ','\t',']','['):
break
if s[i]==']':
parameter_mode=True
if s[i]=='[' and parameter_mode:
parameter_mode=False
i -= 1
# Start point
if i < 0 or s[i] == '\\':
return None
# Extract command name
command_end = i
i -= 1
while i >= 0 and s[i].isalpha():
i -= 1
if i<-1 or s[i]!='\\':
return None
return s[i+1:command_end+1]
def remove_command(s,command,keep_inside=False):
def remove_command(s, command, keep_inside=False):
"""
Removes all occurrences of a specified LaTeX-style command from a string.
This function searches for a given command in the input string `s` and removes it,
along with its associated content enclosed in curly braces `{}`. If `keep_inside`
is set to `True`, the content inside the braces is preserved, and only the command
itself is removed. The function handles nested braces correctly.
Args:
s (str): The input string from which the command should be removed.
command (str): The LaTeX-style command to be removed (e.g., "\\textbf").
keep_inside (bool, optional): If `True`, preserves the content inside the braces
while removing the command. Defaults to `False`.
Returns:
str: The modified string with the specified command removed.
Examples:
>>> remove_command("This is \\textbf{bold text}.", "\\textbf")
'This is bold text.'
>>> remove_command("This is \\textbf{bold text}.", "\\textbf", keep_inside=True)
'This is bold text.'
>>> remove_command("Nested \\textbf{bold \\textit{italic text}} example.", "\\textbf")
'Nested bold \\textit{italic text} example.'
"""
pos=s.find(command)
if pos<0:
return s
end_index=pos+len(command)
level=0
escaped=False
#print(end_index,s[end_index])
if end_index < len(s) and s[end_index] == "{":
while end_index<len(s):
if s[end_index]=='{':
level+=1
elif s[end_index]=='}':
level-=1
if level==0:
break
end_index+=1
else:
s1="".join([s[0:pos],s[end_index:]])
if keep_inside:
s1="".join([s[0:pos],s[pos+len(command)+1:end_index],s[end_index+1:]])
else:
s1="".join([s[0:pos],s[end_index+1:]])
if command not in s1:
return s1
else:
return remove_command(s1,command,keep_inside)
import re
def convert_latex_fractions(latex_str):
"""
Convert non-standard fraction like \frac\alpha2 to its standard-convertable \frac{\alpha}{2}
We suppoort single letter,number or standard form
"""
pattern = r'\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))'
def replacer(match):
numerator, denominator = match.group(1), match.group(2)
wrap_num = f'{{{numerator}}}' if not (numerator.startswith('{') and numerator.endswith('}')) else numerator
wrap_den = f'{{{denominator}}}' if not (denominator.startswith('{') and denominator.endswith('}')) else denominator
return fr'\frac{wrap_num}{wrap_den}'
return re.sub(pattern, replacer, latex_str)
def get_first_brace_command(s: str) -> str | None:
""" Find the first brace """
brace_pos = find_first_unescaped_brace(s)
if brace_pos == -1:
return None
return extract_command(s, brace_pos)
def remove_overall_brace(s:str) -> str:
"""
Remove the overall {xxx} brace
"""
pos=find_first_unescaped_brace(s)
if pos==-1:
return s,0
command=get_first_brace_command(s)
if not command:
content,final=extract_bracket_content(s,pos)
#print(s[final])
if final==len(s) or not '}' in s[final+1:]:
return content,1
return s,0
def exp_frac(s):
def exp_frac_single(s):
position=s.find("^\\frac")+1
if position == 0:
return s
level=0
cnt=0
idx=position
while idx<len(s):
if s[idx]=='{':
cnt+=1
elif s[idx]=='}':
cnt-=1
if cnt==0:
level+=1
if level==2:
break
idx+=1
s1="".join([s[0:position],'{',s[position:idx],'}',s[idx:]])
return s1
s1=exp_frac_single(s)
cnt=0
while s1 != s and cnt<100:
cnt+=1
s=s1
s1=exp_frac_single(s)
return s
def find_all(s, sub_str, allow_overlap=True):
indexes = []
start = 0
step = 1 if allow_overlap else len(sub_str)
cnt=0
while True and cnt<100:
pos = s.find(sub_str, start)
if pos == -1:
break
indexes.append(pos)
start = pos + step
cnt+=1
return indexes
def bar_inside_vec(s):
indices=find_all(s,"\\vec{")
if not indices:
return s
for i in range(len(indices)):
position=find_all(s,"\\vec{")[i]
idx=position+4
#print(s[idx])
idx2=idx
level=0
while idx2<len(s):
if s[idx2]=='{':
level+=1
if s[idx2]=='}':
level-=1
if level==0:
break
idx2+=1
s1=s[idx+1:idx2]
#print(s1)
s1=remove_command(s1,"\\bar",keep_inside=True)
s2= "".join([s[0:idx+1],s1,s[idx2:]])
s=s2
return s
def vec_lower_idx(input_str):
"""
in the annoying latex2sympy, error may occur when\ vec{a_{b}},we need\\vec{a_b}
Args
input_str (str): Original string
Return
str(str): Converted
"""
pattern = r'\\vec\{([^{}]+)_{([^{}]+)}\}'
replacement = r'\\vec{\1}_{\2}'
return re.sub(pattern, replacement, input_str)
def convert_vec_syntax(text):
"""
Converts LaTeX vector syntax to a standardized form.
This function processes a given text string and ensures that LaTeX vector
notations are consistently formatted. Specifically, it transforms instances
of `\vec xxx` into `\vec{xxx}`. The function handles cases where the vector
notation is applied to single characters, Greek letters, or LaTeX commands.
Args:
text (str): The input string containing LaTeX code to be processed.
Returns:
str: The processed string with standardized vector syntax.
Examples:
>>> convert_vec_syntax(r"\vec x + \vec\alpha + \vec\Gamma")
'\\vec{x} + \\vec{\\alpha} + \\vec{\\Gamma}'
"""
pattern = r'\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)'
replacement = r'\\vec{\2}'
return re.sub(pattern, replacement, text)
def remove_outer_braces(tex_str):
"""
convert {base}_{subscript} to base_{subscript}
Example
{a}_{xyz} a_{xyz}
{\theta}_{0} \theta_{0}
"""
pattern = r'\{(\\(?:[a-zA-Z]+|.)|[^{}])+\}_\{([^}]+)\}'
return re.sub(pattern, r'\1_{\2}', tex_str)
def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str:
"""
Extract the content after the last occurrence of specific mathematical comparison or assignment operators.
:param strip_whitespace: If True, removes leading and trailing whitespace from the extracted content. Defaults to True.
(e.g., '=', '\\approx', '\\ge', '\\le', etc.) within the input string `s`. It then extracts
and returns the content that follows the operator. If no operator is found, the entire string
is returned. Optionally, leading and trailing whitespace can be stripped from the extracted content.
Args:
s (str): The input string to process.
strip_whitespace (bool): Whether to strip leading and trailing whitespace from the extracted content. Defaults to True.
Returns:
str: The content after the last matching operator, or the entire string if no operator is found.
"""
comparison_operators=('=','\\approx','\\ge','\\le','\\geq','\\leq','<','>')
content=s
for sign in comparison_operators:
if sign in s:
rfind_index = s.rfind(sign)
if rfind_index != -1:
content = s[rfind_index + 1:]
if strip_whitespace:
return content.strip()
return content
def first_pre_process(s,extrac_box=True):
"""
Perform the first stage of LaTeX string preprocessing.
if not brackets_balanced(s):
raise ValueError("The input string has unbalanced brackets. Please check the LaTeX expression.")
equality or comparison operator.
Args:
s (str): The input LaTeX string to preprocess.
extrac_box (bool): If True, extracts the content inside a '\\boxed' command. Defaults to True.
Returns:
str: The preprocessed LaTeX string.
"""
#s=remove_non_ascii(s)
s=s.replace('\\{','(')
s=s.replace('\\}',')')
if not brackets_balanced(s):
return s
if extrac_box:
boxed_content=remove_command(s,'\\boxed',keep_inside=True)
else:
boxed_content=s
exist_overall_brace=True
cnt=0
while exist_overall_brace and cnt<10:
boxed_content,exist_overall_brace=remove_overall_brace(boxed_content)
cnt+=1
if '\\quad' in boxed_content:
boxed_content = boxed_content.split('\\quad')[0]
last_equal_content=extract_last_equal_content(boxed_content)
exist_overall_brace=True
cnt=0
while exist_overall_brace and cnt<10:
last_equal_content,exist_overall_brace=remove_overall_brace(last_equal_content)
cnt+=1
return last_equal_content
def second_pre_process(s):
"""
Perform the second stage of LaTeX string preprocessing.
This function removes or modifies specific LaTeX commands and content to standardize
the input string for further processing. It handles commands like '\\text', '\\mathbf',
and '\\mathrm', removes unnecessary content, and applies transformations such as
converting fractions and vector syntax.
Args:
s (str): The input LaTeX string to preprocess.
Returns:
str: The preprocessed LaTeX string.
"""
kill_commands=[
'\\begin',
'\\end'
]
remove_commands=[
'\\text',
'\\mathbf',
'\\mathrm',
'\\pmb',
'\\hat',
'\\overline',
'\\boldsymbol',
]
remove_content=[
'\\,','$',',','`','latex','\\left','\\right','\\text','\\mathrm','\\Bigr','\\Bigl','\n','\\]','\\[',
'\\Big','\\bigl','\\bigr','\\biggl','\\biggr','\\displaystyle','\\boldsymbol','\\infty'
]
replace_content=[
('\\operatorname{asin}','\\asin'),
('\\operatorname{sech}','\\sech'),
('\\operatorname{acos}','\\acos'),
('\\operatorname{sinh}','\\sinh'),
('\\dfrac','\\frac'),
('\\tfrac','\\frac'),
('\\Exp','\\exp'),
('\\times','\\bar{times}'),
('\\partial','\\bar{partial}'),
('\\perp','\\bar{perp}'),
('\\epsilon','\\varepsilon'),
('\\varOmega','\\Omega'),
('I','\\bar{I}'),
('_e','_{e}'),
('e_','\\bar{e}_'),
('E_','\\bar{E}_'),
('\\pm','+'),
('\\mp','-'),
('{+}','{p}'),
("{-}",'{m}'),
("_+",'_p'),
('_-',"_m")
]
for command in kill_commands:
s=remove_command(s,command,keep_inside=False)
for command in remove_commands:
s=remove_command(s,command,keep_inside=True)
for content in remove_content:
s=s.replace(content,'')
for content in replace_content:
s=s.replace(content[0],content[1])
s=convert_latex_fractions(s)
#print(s)
s=bar_inside_vec(s)
s=vec_lower_idx(s)
s=convert_vec_syntax(s)
s=exp_frac(s)
#s=remove_outer_braces(s)
if s and s[-1] == '.':
return s[:-1]
return s
class MyConfig:
interpret_as_mixed_fractions: bool = False
interpret_simple_eq_as_assignment: bool = False
interpret_contains_as_eq: bool = True
lowercase_symbols: bool = False
"""
Args:
interpret_as_mixed_fractions (bool): Whether to interpert 2 \frac{1}{2} as 2/2 or 2 + 1/2
interpret_simple_eq_as_assignment (bool): Whether to interpret simple equations as assignments k=1 -> 1
interpret_contains_as_eq (bool): Whether to interpret contains as equality x \\in {1,2,3} -> x = {1,2,3}
lowercase_symbols (bool): Whether to lowercase all symbols
"""
class MyNormalization:
"""Configuration for latex normalization.
Each field controls a group of related normalizations:
- basic_latex: Basic latex command replacements (mathrm, displaystyle, etc.)
- units: Remove units and their variations
- malformed_operators: Fix malformed operators (sqrt, frac, etc.)
- nits: Small formatting fixes (spaces, dots, etc.)
- boxed: Extract content from boxed environments
- equations: Handle equation splitting and approximations (deprecated)
"""
basic_latex: bool = True
units: bool = False
malformed_operators: bool = True
nits: bool = True
boxed = "all"
equations: bool = False
def master_convert(s):
"""
The only function needed to convert a LaTeX string into a SymPy expression.
Args:
s (str): The input LaTeX string. It should be a valid LaTeX mathematical expression,
such as equations, fractions, or symbols, and must have balanced brackets.
Returns:
Sym (Sympy Expression): A SymPy expression representing the mathematical content of the input string.
The returned object can be used for symbolic computation, simplification,
or evaluation using SymPy's functionality.
Example:
>>> master_convert("\\frac{1}{2} + x")
1/2 + x
"""
preprocessed_stage1=first_pre_process(s)
preprocessed_stage2=second_pre_process(preprocessed_stage1)
Sym=latex2sympy(preprocessed_stage2,normalization_config=MyNormalization(),conversion_config=MyConfig())
return Sym

View File

@ -0,0 +1,46 @@
import sys
import os
sys.path.append(os.path.dirname(__file__))
from EED import EED
from opencompass.openicl import BaseEvaluator
from opencompass.registry import ICL_EVALUATOR
@ICL_EVALUATOR.register_module()
class MathEEDEvaluator(BaseEvaluator):
def score(self, predictions, references):
scores = []
for pred, ref in zip(predictions, references):
score, _, _, _ = EED(ref, pred)
scores.append(score)
return {'accuracy': sum(scores) / len(scores)}
from opencompass.datasets import PhyBenchDataset
phybench_datasets = [
dict(
abbr='phybench-eed',
type=PhyBenchDataset,
path='opencompass/PHYBench',
reader_cfg=dict(
input_columns=['input'],
output_column='target',
),
infer_cfg=dict(
prompt_template=dict(
type='plain',
template='Solve the following physics problem and return only the final result as a clean LaTeX expression. No explanation. No text.\n\nQuestion: {{input}}\nAnswer: '
),
retriever=dict(type='zero_shot'),
inferencer=dict(type='gen', max_out_len=512)
),
eval_cfg=dict(
evaluator=dict(type=MathEEDEvaluator)
)
)
]

View File

@ -170,3 +170,4 @@ from .xcopa import * # noqa: F401, F403
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403
from .phybench import *

View File

@ -0,0 +1,26 @@
import os.path as osp
import json
from datasets import Dataset
from opencompass.datasets.base import BaseDataset
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path
@LOAD_DATASET.register_module()
class PhyBenchDataset(BaseDataset):
@staticmethod
def load(path: str, name: str = None, **kwargs):
path = get_data_path(path)
file_path = osp.join(path, 'PHYBench-fullques_v1.json')
with open(file_path, 'r', encoding='utf-8') as f:
raw_data = json.load(f)
inputs = [item['content'] for item in raw_data]
targets = [item['answer'] for item in raw_data]
return Dataset.from_dict({
'input': inputs,
'target': targets
})