OpenCompass/opencompass/datasets/PHYBench/EED/EED.py
2025-05-14 11:38:52 +00:00

632 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import timeout_decorator
from latex2sympy2_extended import latex2sympy
from sympy import (Add, Float, Function, Integer, Mul, Pow, Rational, Symbol,
expand, simplify)
from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi
from sympy.simplify import posify
from opencompass.datasets.PHYBench.EED.extended_zss import ext_distance
def brackets_balanced(s: str) -> bool:
    """Return True when every (), [] and {} in ``s`` is properly nested.

    Characters other than the six bracket characters are ignored. No escape
    handling is performed here.
    """
    opener_for = {')': '(', ']': '[', '}': '{'}
    openers = set(opener_for.values())
    pending = []
    for ch in s:
        if ch in openers:
            pending.append(ch)
            continue
        expected = opener_for.get(ch)
        if expected is None:
            # Not a bracket at all.
            continue
        # A closer must match the most recently opened bracket.
        if not pending or pending.pop() != expected:
            return False
    return not pending
def remove_non_ascii(text):
    """Return ``text`` with every non-ASCII character dropped."""
    ascii_bytes = text.encode('ascii', errors='ignore')
    return ascii_bytes.decode()
def extract_bracket_content(s: str, bracket_position: int) -> str:
    """Collect the text enclosed by the brace opened at ``bracket_position``.

    Scanning starts at the character after ``bracket_position``. A backslash
    protects the character that follows it, so an escaped brace does not
    change the nesting depth. Both the backslash and the protected character
    are kept in the collected text.

    Returns:
        A ``(content, index)`` tuple, where ``index`` is the position of the
        matching closing brace; ``(None, -1)`` when no matching brace is
        found. (The ``-> str`` annotation does not reflect this tuple.)
    """
    pieces = []
    depth = 0
    skip_next = False
    for idx in range(bracket_position + 1, len(s)):
        ch = s[idx]
        if skip_next:
            # Character protected by the preceding backslash.
            skip_next = False
        elif ch == '\\':
            skip_next = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            if depth == 0:
                return ''.join(pieces), idx
            depth -= 1
        pieces.append(ch)
    return None, -1
def find_first_unescaped_brace(s: str) -> int:
    """Return the index of the first ``{`` not preceded by an escaping backslash.

    A backslash escapes exactly the one character after it (including another
    backslash). Returns -1 when no unescaped ``{`` exists.
    """
    after_backslash = False
    for idx, ch in enumerate(s):
        if after_backslash:
            # This character is escaped, whatever it is.
            after_backslash = False
        elif ch == '\\':
            after_backslash = True
        elif ch == '{':
            return idx
    return -1
def extract_command(s: str, brace_pos: int) -> str | None:
"""extract the command name from a bracket."""
i = brace_pos - 1
parameter_mode = False
while i >= 0:
if not parameter_mode and s[i] in ('^', '_'):
return s[i]
if not parameter_mode and not s[i] in (' ', '\t', ']', '['):
break
if s[i] == ']':
parameter_mode = True
if s[i] == '[' and parameter_mode:
parameter_mode = False
i -= 1
# Start point
if i < 0 or s[i] == '\\':
return None
# Extract command name
command_end = i
i -= 1
while i >= 0 and s[i].isalpha():
i -= 1
if i < -1 or s[i] != '\\':
return None
return s[i + 1:command_end + 1]
def remove_command(s, command, keep_inside=False):
    r"""Remove every occurrence of ``command`` from ``s``.

    When the command is immediately followed by ``{``, its brace group
    (nesting-aware) is removed as well; with ``keep_inside=True`` the text
    inside that group is kept in place. When the command is not followed by
    ``{``, only the command text itself is removed.

    Fixes over the previous implementation:

    * A command without a brace argument used to lose the character right
      after it as well (``\hat\imath`` became ``imath``), because the result
      of the no-brace branch was unconditionally overwritten.
    * Occurrences are processed in a loop instead of by recursion, so inputs
      with very many occurrences cannot hit the recursion limit.
    * An empty ``command`` returns ``s`` unchanged instead of never
      terminating.

    Args:
        s: Input string.
        command: Literal command text to search for, e.g. ``'\\boxed'``.
        keep_inside: Keep the content of the brace group when True.

    Returns:
        The string with all occurrences removed.
    """
    if not command:
        return s
    while True:
        pos = s.find(command)
        if pos < 0:
            return s
        arg_start = pos + len(command)
        if arg_start < len(s) and s[arg_start] == '{':
            # Locate the brace that closes the argument. If it is never
            # closed, end_index ends up at len(s) and the rest is consumed.
            end_index = arg_start
            level = 0
            while end_index < len(s):
                if s[end_index] == '{':
                    level += 1
                elif s[end_index] == '}':
                    level -= 1
                    if level == 0:
                        break
                end_index += 1
            inner = s[arg_start + 1:end_index] if keep_inside else ''
            s = ''.join([s[:pos], inner, s[end_index + 1:]])
        else:
            # Bare command: drop only the command text.
            s = ''.join([s[:pos], s[arg_start:]])
def convert_latex_fractions(latex_str):
    r"""Normalise shorthand fractions so both arguments are braced.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac{a}b`` becomes
    ``\frac{a}{b}``. Each argument may be a backslash command, a single digit,
    a single letter, or a brace group without nested braces; anything else is
    left untouched.
    """
    pattern = (r'\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))'
               r'((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))')

    def _braced(argument):
        # Leave an argument alone when it is already a brace group.
        if argument.startswith('{') and argument.endswith('}'):
            return argument
        return '{' + argument + '}'

    def _rewrite(match):
        numerator, denominator = match.groups()
        return '\\frac' + _braced(numerator) + _braced(denominator)

    return re.sub(pattern, _rewrite, latex_str)
def get_first_brace_command(s: str) -> str | None:
    """Return the command attached to the first unescaped brace of ``s``.

    Returns None when ``s`` has no unescaped ``{``; otherwise returns whatever
    ``extract_command`` reports for that brace.
    """
    brace_pos = find_first_unescaped_brace(s)
    return None if brace_pos == -1 else extract_command(s, brace_pos)
def remove_overall_brace(s: str) -> tuple[str, int]:
    """Strip a command-less brace group when nothing closes after it.

    The first unescaped ``{`` is examined. If it does not belong to a command
    (per ``get_first_brace_command``) and no ``}`` appears after its matching
    closing brace, the content of that group is returned.

    Changes from the previous implementation:

    * If the brace is never closed, ``(s, 0)`` is returned. Previously this
      could yield ``(None, 1)``, which made the next call fail on ``None``.
    * The dead test ``final == len(s)`` is gone: ``final`` is the index of a
      character of ``s`` and can never equal ``len(s)``.
    * The return annotation now matches the tuple actually returned.

    Returns:
        ``(text, flag)`` where ``flag`` is 1 when a brace group was stripped
        and 0 when ``s`` is returned unchanged.
    """
    pos = find_first_unescaped_brace(s)
    if pos == -1:
        return s, 0
    if get_first_brace_command(s):
        # The brace is an argument of a command; keep it.
        return s, 0
    content, final = extract_bracket_content(s, pos)
    if final == -1:
        # No matching closing brace: nothing can be stripped safely.
        return s, 0
    if '}' not in s[final + 1:]:
        return content, 1
    return s, 0
def exp_frac(s):
    r"""Wrap fraction exponents in braces: ``x^\frac{a}{b}`` to ``x^{\frac{a}{b}}``.

    One occurrence of ``^\frac`` is rewritten per pass. Passes repeat until
    the string stops changing, with at most 100 rewrites applied.
    """

    def _wrap_first(text):
        caret = text.find('^\\frac')
        if caret == -1:
            return text
        start = caret + 1  # index of the backslash of \frac
        depth = 0
        closed_groups = 0
        end = start
        # Advance to the brace that closes the second top-level group.
        while end < len(text):
            ch = text[end]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    closed_groups += 1
                    if closed_groups == 2:
                        break
            end += 1
        return ''.join(
            [text[:start], '{', text[start:end], '}', text[end:]])

    for _ in range(100):
        wrapped = _wrap_first(s)
        if wrapped == s:
            break
        s = wrapped
    return s
def find_all(s, sub_str, allow_overlap=True):
    """Return the start indexes of ``sub_str`` in ``s`` (at most 100 of them).

    With ``allow_overlap=True`` the search resumes one character after each
    hit; otherwise it resumes after the whole match.
    """
    stride = 1 if allow_overlap else len(sub_str)
    hits = []
    cursor = 0
    while len(hits) < 100:
        cursor = s.find(sub_str, cursor)
        if cursor == -1:
            break
        hits.append(cursor)
        cursor += stride
    return hits
def bar_inside_vec(s):
    r"""Drop ``\bar`` commands found inside ``\vec{...}`` groups.

    For each ``\vec{`` occurrence the matching closing brace is located and
    the text between the braces is passed through
    ``remove_command(..., '\\bar', keep_inside=True)``.
    """
    marker = '\\vec{'
    occurrence_count = len(find_all(s, marker))
    for ordinal in range(occurrence_count):
        # Positions shift as text is removed, so look them up again each time.
        open_brace = find_all(s, marker)[ordinal] + 4
        close_brace = open_brace
        depth = 0
        while close_brace < len(s):
            ch = s[close_brace]
            if ch == '{':
                depth += 1
            if ch == '}':
                depth -= 1
                if depth == 0:
                    break
            close_brace += 1
        inner = remove_command(s[open_brace + 1:close_brace], '\\bar',
                               keep_inside=True)
        s = ''.join([s[:open_brace + 1], inner, s[close_brace:]])
    return s
def vec_lower_idx(input_str):
    r"""Move a subscript out of a vector: ``\vec{a_{b}}`` to ``\vec{a}_{b}``.

    Only brace-free vector bodies and subscripts are matched.
    """
    subscripted_vec = re.compile(r'\\vec\{([^{}]+)_{([^{}]+)}\}')
    return subscripted_vec.sub(r'\\vec{\1}_{\2}', input_str)
def convert_vec_syntax(text):
    r"""Brace the argument of an unbraced ``\vec``: ``\vec a`` to ``\vec{a}``.

    The argument is a run of letters (Latin, or the Greek ranges listed in the
    pattern), optionally preceded by a backslash. Whitespace between ``\vec``
    and the argument is dropped.
    """
    unbraced_vec = re.compile(r'\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)')
    return unbraced_vec.sub(r'\\vec{\2}', text)
def remove_outer_braces(tex_str):
    r"""Rewrite ``{body}_{sub}`` as ``body_{sub}``.

    ``body`` is a non-empty run of backslash commands, escaped characters and
    non-brace characters; ``sub`` is any text up to the next ``}``.

    The body used to be captured by a repeated group
    (``(token)+``), which only remembers its last repetition: ``{ab}_{c}``
    came out as ``b_{c}``. The whole run is now captured, so no characters of
    the body are lost.
    """
    pattern = r'\{((?:\\(?:[a-zA-Z]+|.)|[^{}])+)\}_\{([^}]+)\}'
    return re.sub(pattern, r'\1_{\2}', tex_str)
# Comparison operators recognised by extract_last_equal_content. Longer
# command names come first, and the lookahead keeps "\le" from matching the
# start of longer commands such as "\left".
_COMPARISON_OPERATOR_RE = re.compile(
    r'\\(?:approx|geq|leq|ge|le)(?![a-zA-Z])|[=<>]')


def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str:
    r"""Return the text after the rightmost comparison operator in ``s``.

    Recognised operators: ``=``, ``<``, ``>``, ``\approx``, ``\ge``, ``\le``,
    ``\geq``, ``\leq``. When none is present, ``s`` itself is returned.

    Fixes over the previous implementation:

    * The whole operator is skipped. Previously only one character was
      skipped, so ``x \approx 3`` produced ``approx 3``.
    * ``\le`` / ``\ge`` no longer match inside longer commands such as
      ``\left``, which used to turn ``\left(a\right)`` into ``left(a\right)``.
    * The operator that occurs last in the string wins, rather than the one
      listed last in the operator tuple.

    Args:
        s: Input expression.
        strip_whitespace: Strip leading/trailing whitespace from the result.
    """
    last_match = None
    for last_match in _COMPARISON_OPERATOR_RE.finditer(s):
        pass
    content = s if last_match is None else s[last_match.end():]
    return content.strip() if strip_whitespace else content
def first_pre_process(s, extrac_box=True):
    r"""First clean-up stage: isolate the final expression of an answer.

    Steps, in order:

    1. Escaped braces ``\{`` / ``\}`` become parentheses.
    2. If the brackets are unbalanced, the string is returned as is.
    3. ``\boxed`` wrappers are removed (content kept) when ``extrac_box``.
    4. Overall braces are stripped (up to 10 rounds).
    5. Everything from the first ``\quad`` on is discarded.
    6. Only the text after the last comparison operator is kept.
    7. Overall braces are stripped again (up to 10 rounds).
    """

    def _strip_overall_braces(text):
        # Repeat until remove_overall_brace reports nothing was stripped.
        for _ in range(10):
            text, stripped = remove_overall_brace(text)
            if not stripped:
                break
        return text

    s = s.replace('\\{', '(').replace('\\}', ')')
    if not brackets_balanced(s):
        return s
    boxed_content = s
    if extrac_box:
        boxed_content = remove_command(s, '\\boxed', keep_inside=True)
    boxed_content = _strip_overall_braces(boxed_content)
    boxed_content = boxed_content.partition('\\quad')[0]
    last_equal_content = extract_last_equal_content(boxed_content)
    return _strip_overall_braces(last_equal_content)
def second_pre_process(s):
    """Second clean-up stage: rewrite LaTeX into the form handed to the parser.

    Applied in this order:

    1. environment commands are removed together with their brace argument;
    2. styling commands are removed, keeping their brace content;
    3. a fixed list of fragments is deleted;
    4. a fixed list of textual substitutions is applied;
    5. fraction, vector and exponent normalisers run;
    6. a single trailing ``.`` is dropped.
    """
    environment_commands = ('\\begin', '\\end')
    wrapper_commands = (
        '\\text',
        '\\mathbf',
        '\\mathrm',
        '\\pmb',
        '\\hat',
        '\\overline',
        '\\boldsymbol',
    )
    dropped_fragments = (
        '\\,', '$', ',', '`', 'latex', '\\left', '\\right', '\\text',
        '\\mathrm', '\\Bigr', '\\Bigl', '\n', '\\]', '\\[', '\\Big', '\\bigl',
        '\\bigr', '\\biggl', '\\biggr', '\\displaystyle', '\\boldsymbol',
        '\\infty',
    )
    substitutions = (
        ('\\operatorname{asin}', '\\asin'),
        ('\\operatorname{sech}', '\\sech'),
        ('\\operatorname{acos}', '\\acos'),
        ('\\operatorname{sinh}', '\\sinh'),
        ('\\dfrac', '\\frac'),
        ('\\tfrac', '\\frac'),
        ('\\Exp', '\\exp'),
        ('\\times', '\\bar{times}'),
        ('\\partial', '\\bar{partial}'),
        ('\\perp', '\\bar{perp}'),
        ('\\epsilon', '\\varepsilon'),
        ('\\varOmega', '\\Omega'),
        ('I', '\\bar{I}'),
        ('_e', '_{e}'),
        ('e_', '\\bar{e}_'),
        ('E_', '\\bar{E}_'),
        ('\\pm', '+'),
        ('\\mp', '-'),
        ('{+}', '{p}'),
        ('{-}', '{m}'),
        ('_+', '_p'),
        ('_-', '_m'),
    )
    for command in environment_commands:
        s = remove_command(s, command, keep_inside=False)
    for command in wrapper_commands:
        s = remove_command(s, command, keep_inside=True)
    for fragment in dropped_fragments:
        s = s.replace(fragment, '')
    for old, new in substitutions:
        s = s.replace(old, new)
    # Order matters: each normaliser sees the output of the previous one.
    normalisers = (convert_latex_fractions, bar_inside_vec, vec_lower_idx,
                   convert_vec_syntax, exp_frac)
    for normalise in normalisers:
        s = normalise(s)
    return s[:-1] if s.endswith('.') else s
class MyConfig:
    """Option holder passed to ``latex2sympy`` as ``conversion_config``.

    NOTE(review): the meaning of each flag is defined by
    latex2sympy2_extended, not by this file; confirm against that library.
    """

    interpret_as_mixed_fractions: bool = False
    interpret_simple_eq_as_assignment: bool = False
    interpret_contains_as_eq: bool = True
    lowercase_symbols: bool = False
class MyNormalization:
    """Option holder passed to ``latex2sympy`` as ``normalization_config``.

    NOTE(review): the meaning of each flag is defined by
    latex2sympy2_extended, not by this file; confirm against that library.
    """

    basic_latex: bool = True
    units: bool = False
    malformed_operators: bool = True
    nits: bool = True
    boxed = 'all'
    equations: bool = False
def master_convert(s):
    """Convert a LaTeX answer string into a SymPy expression.

    The string goes through ``first_pre_process`` and ``second_pre_process``
    and is then parsed by ``latex2sympy`` using ``MyNormalization`` and
    ``MyConfig``. Any exception raised along the way propagates to the caller.
    """
    cleaned = second_pre_process(first_pre_process(s))
    return latex2sympy(
        cleaned,
        normalization_config=MyNormalization(),
        conversion_config=MyConfig(),
    )
# Edit costs per node kind. The kind is the part of a node label before the
# first underscore. Adjust the values if different weights are wanted.
insert_cost = {
    'number': 1,
    'symbol': 1,
    'operator': 1,
    'function': 1,
}
delete_cost = {
    'number': 1,
    'symbol': 1,
    'operator': 1,
    'function': 1,
}
update_cost = {
    'number': 1,
    'symbol': 1,
    'operator': 1,
    'function': 1,
}
# Cost of relabelling a node to a different kind.
change_type_cost = 1
# Parameters of the discounted subtree cost used by remove_tree_func.
bar_size = 5
discount_slope = 0.6
# Time limits, in seconds.
simplify_time_limit = 30
equals_time_limit = 10
def update_func(x, y):
    """Cost of relabelling node ``x`` as node ``y``.

    Identical labels cost 0. Labels of the same kind (same text before the
    first underscore) cost ``update_cost[kind]``. Otherwise the cost is
    ``change_type_cost``.
    """
    if x.label == y.label:
        return 0
    kind_x = x.label.split('_')[0]
    kind_y = y.label.split('_')[0]
    return update_cost[kind_x] if kind_x == kind_y else change_type_cost
def remove_func(x):
    """Cost of deleting the single node ``x``, looked up by its kind."""
    kind = x.label.split('_')[0]
    return delete_cost[kind]
def remove_tree_func(x):
    """Cost of deleting the whole subtree rooted at ``x``.

    A childless node costs the same as ``remove_func``. Otherwise the cost is
    the smaller of the subtree size and
    ``discount_slope * (size - bar_size) + bar_size``.
    """
    if x.children:
        size = calc_tree_size(x)
        discounted = discount_slope * (size - bar_size) + bar_size
        return min(size, discounted)
    return remove_func(x)
def insert_func(x):
    """Cost of inserting the single node ``x``, looked up by its kind."""
    kind = x.label.split('_')[0]
    return insert_cost[kind]
def insert_tree_func(x):
    """Cost of inserting the subtree rooted at ``x``.

    Inserting is priced exactly like removing the same subtree.
    """
    subtree_cost = remove_tree_func(x)
    return subtree_cost
def calc_tree_size(node):
    """Return the total insertion cost of the subtree rooted at ``node``.

    The result is stored in ``node.subtree_size``. For a node with children,
    a non-zero stored value is returned directly instead of being recomputed.
    """
    own_cost = insert_cost[node.label.split('_')[0]]
    if node.children and node.subtree_size != 0:
        return node.subtree_size
    # Start the sum at own_cost so the additions happen in the same order.
    size = sum((calc_tree_size(child) for child in node.children), own_cost)
    node.subtree_size = size
    return size
"""
Scoring function from relative distance
"""
def score_calc(tree_dist, tree_size):
    """Turn a tree distance into a score.

    A distance of zero scores 100. Otherwise the score is
    ``100 * discount_slope - 100 * tree_dist / tree_size``, floored at 0.
    """
    if tree_dist == 0.:
        return 100
    raw_score = 100 * discount_slope - 100 * tree_dist / tree_size
    return max(0, raw_score)
# Use the module-level limit instead of a duplicated literal, so editing
# simplify_time_limit actually changes the timeout (its value is 30).
@timeout_decorator.timeout(simplify_time_limit, timeout_exception=TimeoutError)
def simplify_with_timeout(expr):
    """Run ``sympy.simplify`` on ``expr`` under a time limit.

    Raises:
        TimeoutError: when simplification exceeds ``simplify_time_limit``
            seconds.
    """
    return simplify(expr)
def time_simplify(expr):
    """Simplify ``expr``; on timeout, return ``expr`` unchanged.

    Only ``TimeoutError`` is handled; other exceptions propagate.
    """
    try:
        return simplify_with_timeout(expr)
    except TimeoutError:
        return expr
# Use the module-level limit instead of a duplicated literal, so editing
# equals_time_limit actually changes the timeout (its value is 10).
@timeout_decorator.timeout(equals_time_limit, timeout_exception=TimeoutError)
def equal_with_timeout(expr1, expr2):
    """Return ``expr1.equals(expr2)`` under a time limit.

    Raises:
        TimeoutError: when the comparison exceeds ``equals_time_limit``
            seconds.
    """
    return expr1.equals(expr2)
def time_equal(expr1, expr2):
    """Compare two expressions; a timeout counts as "not equal".

    Returns whatever ``equal_with_timeout`` returns, or False when it raises
    ``TimeoutError``. Other exceptions propagate.
    """
    try:
        return equal_with_timeout(expr1, expr2)
    except TimeoutError:
        return False
def sympy_to_tree(expr):
    """Convert a SymPy expression into a ``TreeNode`` tree.

    Labels are ``number_<value>``, ``symbol_<name>``, ``operator_<Add|Mul|Pow>``
    or ``function_<name>``; operator and function nodes get one child per
    argument.

    Raises:
        ValueError: for any other SymPy type.
    """
    numeric_types = (Integer, Pi, Exp1, Float, Rational, Infinity,
                     NegativeInfinity)
    # Leaves: numeric constants and symbols.
    if isinstance(expr, numeric_types):
        return TreeNode(label='number_' + str(expr), children=[])
    if isinstance(expr, Symbol):
        return TreeNode(label='symbol_' + str(expr), children=[])
    # Inner nodes: operators and function applications.
    if isinstance(expr, (Add, Mul, Pow)):
        label = 'operator_' + type(expr).__name__
    elif isinstance(expr, Function):
        label = 'function_' + expr.func.__name__
    else:
        raise ValueError(f'Unsupported SymPy type: {type(expr)}')
    children = [sympy_to_tree(arg) for arg in expr.args]
    return TreeNode(label=label, children=children)
class TreeNode:
    """Labelled tree node used for the tree edit distance.

    Attributes are ``label``, ``children`` (a list), ``node_type`` and
    ``subtree_size`` (0 until ``calc_tree_size`` stores a value).
    """

    def __init__(self, label, children=None, node_type='other'):
        # A fresh list per node avoids sharing one default list.
        self.children = [] if children is None else children
        self.label = label
        self.node_type = node_type
        self.subtree_size = 0

    def get_children(self):
        """Return the list of child nodes."""
        return self.children

    def __str__(self):
        return self.label
def print_tree(node, indent=0):
    """Print the tree rooted at ``node``, one label per line, parents first.

    Each line is prefixed with ``indent`` spaces, growing by one per level.
    """
    pending = [(node, indent)]
    while pending:
        current, depth = pending.pop()
        print(' ' * depth + f'└─ {current.label}')
        # Push children reversed so they are printed in their original order.
        for child in reversed(list(current.children)):
            pending.append((child, depth + 1))
class LaTeXError(Exception):
    """Raised by ``EED`` in debug mode when LaTeX conversion fails."""

    def __init__(self, message='LaTeXError'):
        Exception.__init__(self, message)
class SymPyError(Exception):
    """Raised by ``EED`` in debug mode when simplification or tree building fails."""

    def __init__(self, message='SymPyError'):
        Exception.__init__(self, message)
class TreeError(Exception):
    """Error type for tree construction problems.

    NOTE(review): nothing in the visible code raises this class.
    """

    def __init__(self, message='TreeError'):
        Exception.__init__(self, message)
class DistError(Exception):
    """Raised by ``EED`` in debug mode when the tree distance cannot be computed."""

    def __init__(self, message='DistanceError'):
        Exception.__init__(self, message)
def EED(answer_latex, test_latex, debug_mode=False):
    r"""Score a generated LaTeX answer against the reference answer.

    Both strings are converted to SymPy expressions, checked for symbolic
    equality and, if they differ, compared as expression trees with
    ``ext_distance``.

    Args:
        answer_latex: Reference (ground-truth) LaTeX string.
        test_latex: Generated LaTeX string to be scored.
        debug_mode: When True, failures in a stage raise ``LaTeXError``,
            ``SymPyError`` or ``DistError`` instead of producing a zero score.

    Returns:
        A tuple ``(score, rel_distance, tree_size, distance_number)``.
        ``-1`` is used for values that were not computed. ``score`` is 100
        for an exact or symbolic match, 0 for rejected or failed inputs, and
        the value of ``score_calc`` otherwise.
    """
    # --- Cheap rejections / shortcuts on the raw strings -------------------
    if not test_latex:
        return 0, -1, -1, -1
    # Inputs containing \int score 0. This check runs before the
    # exact-string comparison below, so identical strings that contain \int
    # also score 0.
    if '\\int' in test_latex or '\\int' in answer_latex:
        return 0, -1, -1, -1
    # NOTE(review): the last element is 1 here but -1 in the other rejection
    # branches -- possibly a typo; confirm with callers before changing.
    if '\\sum' in test_latex or '\\sum' in answer_latex:
        return 0, -1, -1, 1
    if answer_latex == test_latex:
        return 100, 0.0, -1, 0
    # Reject generated answers more than three times as long as the reference.
    if len(test_latex) > 3 * len(answer_latex):
        return 0, -1, -1, -1
    # --- Stage 1: LaTeX -> SymPy -------------------------------------------
    try:
        answer_exp = master_convert(answer_latex)
        test_exp = master_convert(test_latex)
    except Exception:
        if debug_mode:
            raise LaTeXError(f'Fail to convert latex.\n GT:{answer_latex}\n'
                             f' GEN:{test_latex}')
        return 0, -1, -1, -1
    # --- Stage 2: simplification and equality checks -----------------------
    try:
        # posify returns the rewritten expression plus a substitution map;
        # the map is applied after simplification to restore the originals.
        answer_exp, rep1 = posify(answer_exp)
        answer_exp = time_simplify(answer_exp)
        test_exp, rep2 = posify(test_exp)
        test_exp = time_simplify(test_exp)
        answer_exp = answer_exp.subs(rep1)
        test_exp = test_exp.subs(rep2)
        zero_exp = time_simplify(expand(answer_exp - test_exp))
        if answer_exp == test_exp or zero_exp == 0:
            return 100, 0., 0, 0
        if time_equal(answer_exp, test_exp):
            return 100, 0., 0, 0
    except Exception:
        if debug_mode:
            raise SymPyError(
                f'Failed to simplify the sympy expression. Expressions: '
                f'answer_exp={answer_exp}, test_exp={test_exp}')
        return 0, -1, -1, -1
    # --- Stage 3: SymPy -> trees -------------------------------------------
    try:
        tree_answer = sympy_to_tree(answer_exp)
        tree_test = sympy_to_tree(test_exp)
    except Exception:
        # NOTE(review): this raises SymPyError although a TreeError class
        # exists in this module; confirm which one callers expect.
        if debug_mode:
            raise SymPyError(f'Failed to build the sympy expression tree.\n'
                             f' GT:{answer_exp}\n GEN:{test_exp}')
        return 0, -1, -1, -1
    # --- Stage 4: tree edit distance ---------------------------------------
    try:
        distance = ext_distance(tree_test,
                                tree_answer,
                                get_children=lambda x: x.get_children(),
                                single_insert_cost=insert_func,
                                insert_cost=insert_tree_func,
                                single_remove_cost=remove_func,
                                remove_cost=remove_tree_func,
                                update_cost=update_func)
    except Exception:
        if debug_mode:
            raise DistError(
                f'Failed to calculate the distance between trees.\n'
                f' GT:{answer_latex}\n GEN:{test_latex}')
        return 0, -1, calc_tree_size(tree_answer), -1
    # --- Final score, relative to the size of the reference tree ------------
    tree_size = calc_tree_size(tree_answer)
    distance_number = distance
    rel_distance = distance / tree_size
    score = score_calc(distance_number, tree_size)
    return score, rel_distance, tree_size, distance_number