mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
362 lines
12 KiB
Python
362 lines
12 KiB
Python
![]() |
from sympy import *
|
||
|
from sympy.core.function import AppliedUndef
|
||
|
from sympy.core.numbers import Pi, Exp1,I,Infinity,NegativeInfinity
|
||
|
import numpy as np
|
||
|
import timeout_decorator
|
||
|
from extended_zss import ext_distance
|
||
|
from latex_pre_process import *
|
||
|
from sympy.simplify import *
|
||
|
"""
|
||
|
Guide:
|
||
|
You only need to use EED and install the following packages:
|
||
|
- sympy
|
||
|
- numpy
|
||
|
- latex2sympy2_extended
|
||
|
- timeout_decorator
|
||
|
"""
|
||
|
|
||
|
"""
Nodes of an expression tree fall into four main categories:

Constants: such as integers, decimals, or mathematical constants like π and e.
Variables: letters like x, y, z, or specified terms in problems (e.g., ħ, c, G).
Functions: sine, cosine, exponential, logarithm, etc.
Operators: basic binary operations including addition, multiplication, and exponentiation.
"""
|
||
|
# Per-node edit costs, keyed by node category (the part of a node label
# before the first "_"). Change the values to weight categories differently.
insert_cost = {
    "number": 1,
    "symbol": 1,
    "operator": 1,
    "function": 1,
}
delete_cost = {
    "number": 1,
    "symbol": 1,
    "operator": 1,
    "function": 1,
}
update_cost = {
    "number": 1,
    "symbol": 1,
    "operator": 1,
    "function": 1,
}

# Cost of an update between nodes of different categories; can be set higher.
change_type_cost = 1

# Minimum subtree size that triggers the cluster discount.
bar_size = 5

# Slope of the discount applied to the part of a subtree size above bar_size.
discount_slope = 0.6

# Time limit for simplify.
simplify_time_limit = 30

# Time limit for equals.
equals_time_limit = 10
|
||
|
|
||
|
def update_func(x, y):
    """
    Return the cost of relabelling node `x` as node `y`.

    Identical labels cost 0. Labels that share the same category (the text
    before the first "_") cost `update_cost[category]`. Labels of different
    categories cost `change_type_cost`.
    """
    if x.label == y.label:
        return 0

    category_x = x.label.split("_")[0]
    category_y = y.label.split("_")[0]
    if category_x != category_y:
        return change_type_cost
    return update_cost[category_x]
|
||
|
def remove_func(x):
    """Return the cost of deleting the single node `x`, looked up by its label category."""
    category = x.label.split("_")[0]
    return delete_cost[category]
|
||
|
|
||
|
def remove_tree_func(x):
    """
    Return the cost of deleting the whole subtree rooted at `x`.

    A leaf costs the same as a single-node removal. Otherwise the cost is the
    subtree size, discounted by `discount_slope` for the part above
    `bar_size`; the `min` keeps small subtrees at their undiscounted size.
    """
    if x.children:
        size = calc_tree_size(x)
        discounted = discount_slope * (size - bar_size) + bar_size
        return min(size, discounted)
    return remove_func(x)
|
||
|
|
||
|
|
||
|
def insert_func(x):
    """Return the cost of inserting the single node `x`, looked up by its label category."""
    category = x.label.split("_")[0]
    return insert_cost[category]
|
||
|
def insert_tree_func(x):
    """Return the cost of inserting the subtree rooted at `x` (same as removing it)."""
    subtree_cost = remove_tree_func(x)
    return subtree_cost
|
||
|
|
||
|
|
||
|
|
||
|
def calc_tree_size(node):
    """
    Return the size of the subtree rooted at `node`.

    The size is the total insertion cost of the subtree: the node's own cost
    from the global `insert_cost` table (keyed by the label text before the
    first "_") plus the sizes of all its children, computed recursively.

    Args:
        node (Node): Root of the subtree. Must expose `label`, `children`
            and a writable `subtree_size` attribute.

    Returns:
        The total insertion cost of the node and all its descendants.

    Notes:
        - The result is stored in `node.subtree_size`. A non-leaf node with a
          non-zero `subtree_size` returns that cached value directly.
        - Leaves are always recomputed; their size is just their own cost.
    """
    own_cost = insert_cost[node.label.split("_")[0]]

    # A subtree_size of 0 means "not computed yet"; only inner nodes are cached.
    if node.children and node.subtree_size != 0:
        return node.subtree_size

    size = own_cost
    for child in node.children:
        size += calc_tree_size(child)

    node.subtree_size = size
    return size
|
||
|
"""
|
||
|
Scoring function from relative distance
|
||
|
"""
|
||
|
def score_calc(tree_dist, tree_size):
    """
    Turn a tree distance into a score.

    A zero distance scores 100. Otherwise the score is
    `100*discount_slope - 100*tree_dist/tree_size`, floored at 0.
    """
    if tree_dist == 0.:
        return 100

    raw_score = 100*discount_slope - 100*tree_dist/tree_size
    return max(0, raw_score)
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
# Use the module-level setting rather than a hard-coded 30, so that changing
# `simplify_time_limit` actually changes the timeout.
@timeout_decorator.timeout(simplify_time_limit, timeout_exception=TimeoutError)
def simplify_with_timeout(expr):
    """
    Run `simplify(expr)` under a time limit.

    Args:
        expr: The expression to simplify.

    Returns:
        The result of `simplify(expr)`.

    Raises:
        TimeoutError: If simplification runs longer than `simplify_time_limit`.
    """
    return simplify(expr)
|
||
|
def time_simplify(expr):
    """
    Simplify `expr`, falling back to the unsimplified input on timeout.

    Only `TimeoutError` is handled here; any other exception raised during
    simplification propagates to the caller.
    """
    try:
        simplified = simplify_with_timeout(expr)
    except TimeoutError:
        # Simplification took too long: keep the expression as it was.
        simplified = expr
    return simplified
|
||
|
|
||
|
# Use the module-level setting rather than a hard-coded 10, so that changing
# `equals_time_limit` actually changes the timeout.
@timeout_decorator.timeout(equals_time_limit, timeout_exception=TimeoutError)
def equal_with_timeout(expr1, expr2):
    """
    Run `expr1.equals(expr2)` under a time limit.

    Args:
        expr1: Left-hand expression; must provide an `equals` method.
        expr2: Right-hand expression.

    Returns:
        Whatever `expr1.equals(expr2)` returns.

    Raises:
        TimeoutError: If the comparison runs longer than `equals_time_limit`.
    """
    return expr1.equals(expr2)
|
||
|
def time_equal(expr1, expr2):
    """
    Compare two expressions, treating a timeout as "not equal".

    Returns the result of the time-limited `equals` check, or False when it
    times out. Only `TimeoutError` is handled; other exceptions propagate.
    """
    try:
        outcome = equal_with_timeout(expr1, expr2)
    except TimeoutError:
        # Could not decide within the time limit: report the pair as unequal.
        outcome = False
    return outcome
|
||
|
|
||
|
|
||
|
def sympy_to_tree(expr):
    """
    Recursively convert a SymPy expression into a tree of `TreeNode` objects.

    Every node label has the form "<category>_<name>":

    - "number_" + str(expr) for Integer, Pi, Exp1, Float, Rational,
      Infinity and NegativeInfinity (leaf);
    - "symbol_" + str(expr) for Symbol (leaf);
    - "operator_" + class name for Add, Mul and Pow;
    - "function_" + expr.func.__name__ for any `Function` instance.

    Operator and function nodes get one child per entry of `expr.args`,
    converted recursively and kept in the same order.

    Args:
        expr (sympy.Basic): The SymPy expression to convert.

    Returns:
        TreeNode: The root node of the resulting tree.

    Raises:
        ValueError: If `expr`, or any sub-expression, is of an unsupported
            type. The offending type and expression are printed first.

    Example:
        >>> from sympy import symbols
        >>> x, y = symbols('x y')
        >>> print(sympy_to_tree(x + y))
        operator_Add
    """
    number_types = (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)

    # Leaves: numeric constants and symbols.
    if isinstance(expr, number_types):
        return TreeNode(label="number_" + str(expr), children=[])

    if isinstance(expr, Symbol):
        return TreeNode(label="symbol_" + str(expr), children=[])

    # Inner nodes: operators are labelled by class name.
    if isinstance(expr, (Add, Mul, Pow)):
        operator_name = type(expr).__name__
        subtrees = [sympy_to_tree(arg) for arg in expr.args]
        return TreeNode(label="operator_" + operator_name, children=subtrees)

    # Inner nodes: functions are labelled by the name of the applied function.
    if isinstance(expr, Function):
        function_name = expr.func.__name__
        subtrees = [sympy_to_tree(arg) for arg in expr.args]
        return TreeNode(label="function_" + function_name, children=subtrees)

    print(f"Unsupported Sympy type: {type(expr).__name__}, Expression: {expr}")
    raise ValueError(f"Unsupported SymPy type: {type(expr)}")
|
||
|
|
||
|
class TreeNode:
    """
    A labelled node of an expression tree.

    Attributes are documented inline in `__init__`.
    """

    def __init__(self, label, children=None, node_type='other'):
        # Cached subtree size; 0 means it has not been computed yet.
        self.subtree_size = 0
        self.node_type = node_type
        # A fresh list per node when no children are given, so nodes never
        # share one default list.
        self.children = [] if children is None else children
        self.label = label

    def get_children(self):
        """Return the list of child nodes."""
        return self.children

    def __str__(self):
        """Return the node's label."""
        return self.label
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
def print_tree(node, indent=0):
    """
    Print the tree rooted at `node`, one label per line.

    Each line is prefixed by `indent` spaces; children are printed after
    their parent, in order, with the indent increased by one.
    """
    padding = ' ' * indent
    print(padding + f'└─ {node.label}')

    child_indent = indent + 1
    for subtree in node.children:
        print_tree(subtree, child_indent)
|
||
|
|
||
|
|
||
|
|
||
|
import timeout_decorator
|
||
|
|
||
|
class LaTeXError(Exception):
    """Error for a failed LaTeX-to-expression conversion; message defaults to "LaTeXError"."""

    def __init__(self, message="LaTeXError"):
        super().__init__(message)
|
||
|
class SymPyError(Exception):
    """Error for a failed symbolic simplification or tree construction; message defaults to "SymPyError"."""

    def __init__(self, message="SymPyError"):
        super().__init__(message)
|
||
|
|
||
|
|
||
|
class TreeError(Exception):
    """Error type for expression-tree failures; message defaults to "TreeError"."""

    def __init__(self, message="TreeError"):
        super().__init__(message)
|
||
|
|
||
|
|
||
|
class DistError(Exception):
    """Error for a failed tree-distance calculation; message defaults to "DistanceError"."""

    def __init__(self, message="DistanceError"):
        super().__init__(message)
|
||
|
|
||
|
def EED(answer_latex, test_latex, debug_mode=False):
    """
    Score how closely a test LaTeX expression matches a reference answer.

    Both inputs are converted to SymPy expressions, simplified, and compared
    symbolically. If they are not found equal, both are turned into
    expression trees and the tree distance between them is converted into a
    score.

    Args:
        answer_latex: The LaTeX string of the reference answer.
        test_latex: The LaTeX string of the expression being tested.
        debug_mode: If True, processing failures raise the exceptions listed
            below instead of returning the failure tuple.

    Returns:
        tuple: `(score, relative_distance, answer_tree_size, distance)`

        - score: The similarity score (0 to 100).
        - relative_distance: `distance / answer_tree_size`.
        - answer_tree_size: The size of the answer's expression tree.
        - distance: The raw distance between the two expression trees.

        Special cases:

        - `(100, 0.0, -1, 0)` when the two LaTeX strings are identical.
        - `(100, 0.0, 0, 0)` when the expressions are symbolically equal.
        - `(0, -1, -1, -1)` when either input is empty, either input
          contains "\\int" or "\\sum", the test string is more than three
          times as long as the answer string, or conversion, simplification
          or tree construction fails.
        - `(0, -1, answer_tree_size, -1)` when the distance calculation fails.

    Raises:
        LaTeXError: LaTeX conversion failed (only if `debug_mode` is True).
        SymPyError: Simplification or tree construction failed (only if
            `debug_mode` is True).
        DistError: Distance calculation failed (only if `debug_mode` is True).
    """
    failure = (0, -1, -1, -1)

    # Cheap rejections and shortcuts that need no symbolic work.
    if not test_latex or not answer_latex:
        return failure
    if '\\int' in test_latex or '\\int' in answer_latex:
        return failure
    if '\\sum' in test_latex or '\\sum' in answer_latex:
        return failure
    if answer_latex == test_latex:
        return 100, 0.0, -1, 0
    if len(test_latex) > 3 * len(answer_latex):
        return failure

    try:
        answer_exp = master_convert(answer_latex)
        test_exp = master_convert(test_latex)
    except Exception as err:
        print("Failed to convert input latex to sympy expression,please check it")
        if debug_mode:
            raise LaTeXError(f"Fail to convert latex.\n GT:{answer_latex}\n GEN:{test_latex}") from err
        return failure

    try:
        # posify replaces symbols by positive dummies so simplify can do
        # more; the returned mappings restore the original symbols afterwards.
        answer_exp, rep1 = posify(answer_exp)
        answer_exp = time_simplify(answer_exp)

        test_exp, rep2 = posify(test_exp)
        test_exp = time_simplify(test_exp)

        answer_exp = answer_exp.subs(rep1)
        test_exp = test_exp.subs(rep2)

        zero_exp = time_simplify(expand(answer_exp - test_exp))

        if answer_exp == test_exp or zero_exp == 0:
            return 100, 0., 0, 0

        if time_equal(answer_exp, test_exp):
            return 100, 0., 0, 0
    except Exception as err:
        print("Something happened during simplification,returning zero")
        if debug_mode:
            raise SymPyError(f"Failed to simplify the sympy expression. Expressions: answer_exp={answer_exp}, test_exp={test_exp}") from err
        return failure

    try:
        tree_answer = sympy_to_tree(answer_exp)
        tree_test = sympy_to_tree(test_exp)
    except Exception as err:
        print("Failed to build expression tree,returning zero")
        if debug_mode:
            raise SymPyError(f"Failed to build the sympy expression tree.\n GT:{answer_exp}\n GEN:{test_exp}") from err
        return failure

    # The distance is computed exactly once, inside the try, so that any
    # failure is reported through the failure tuple / DistError.
    try:
        distance = ext_distance(
            tree_test,
            tree_answer,
            get_children=lambda x: x.get_children(),
            single_insert_cost=insert_func,
            insert_cost=insert_tree_func,
            single_remove_cost=remove_func,
            remove_cost=remove_tree_func,
            update_cost=update_func,
        )
    except Exception as err:
        print("Failed to calculate distance")
        if debug_mode:
            raise DistError(f"Failed to calculate the distance between trees.\n GT:{answer_latex}\n GEN:{test_latex}") from err
        return 0, -1, calc_tree_size(tree_answer), -1

    tree_size = calc_tree_size(tree_answer)
    rel_distance = distance / tree_size
    score = score_calc(distance, tree_size)

    return score, rel_distance, tree_size, distance
|