mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
add phybench
This commit is contained in:
parent
c3779ebfc1
commit
811161c2d1
@ -1113,3 +1113,9 @@
|
||||
paper: https://arxiv.org/pdf/2203.14371
|
||||
configpath: opencompass/configs/datasets/medmcqa/medmcqa_gen.py
|
||||
configpath_llmjudge: opencompass/configs/datasets/medmcqa/medmcqa_llmjudge_gen.py
|
||||
- phybench:
|
||||
name: PHYBench
|
||||
category: Science /Physics
|
||||
paper: https://arxiv.org/abs/2504.16074
|
||||
configpath: opencompass/configs/datasets/PHYBench/phybench_gen.py
|
||||
configpath_llmjudge: ''
|
361
opencompass/configs/datasets/PHYBench/EED.py
Normal file
361
opencompass/configs/datasets/PHYBench/EED.py
Normal file
@ -0,0 +1,361 @@
|
||||
from sympy import *
|
||||
from sympy.core.function import AppliedUndef
|
||||
from sympy.core.numbers import Pi, Exp1,I,Infinity,NegativeInfinity
|
||||
import numpy as np
|
||||
import timeout_decorator
|
||||
from extended_zss import ext_distance
|
||||
from latex_pre_process import *
|
||||
from sympy.simplify import *
|
||||
"""
|
||||
Guide:
|
||||
You only need to use EED and install the following packages:
|
||||
- sympy
|
||||
- numpy
|
||||
- latex2sympy2_extended
|
||||
- timeout_decorator
|
||||
"""
|
||||
|
||||
"""
|
||||
There are four main categories:
|
||||
|
||||
Constants: such as integers, decimals, or mathematical constants like π and e.
|
||||
Variables: letters like x, y, z, or specified terms in problems (e.g., ħ, c, G).
|
||||
Functions: sine, cosine, exponential, logarithm, etc.
|
||||
Operators: basic binary operations including addition, multiplication, and exponentiation.
|
||||
"""
|
||||
# The costs can be modified if you think their values are different
|
||||
insert_cost={"number":1,"symbol":1,"operator":1,"function":1}
|
||||
delete_cost={"number":1,"symbol":1,"operator":1,"function":1}
|
||||
update_cost={"number":1,"symbol":1,"operator":1,"function":1}
|
||||
|
||||
change_type_cost=1 #the cost of an update between different types,can be set to higher
|
||||
|
||||
bar_size=5 # the minimum size of triggering cluster discount
|
||||
discount_slope=0.6 #discount
|
||||
|
||||
simplify_time_limit=30 #set the time limit of simplify
|
||||
equals_time_limit=10 #set the time limit of equals
|
||||
|
||||
def update_func(x,y):
|
||||
|
||||
if x.label==y.label:
|
||||
return 0
|
||||
|
||||
elif x.label.split("_")[0]==y.label.split("_")[0]:
|
||||
return update_cost[x.label.split("_")[0]]
|
||||
return change_type_cost
|
||||
def remove_func(x):
|
||||
return delete_cost[x.label.split("_")[0]]
|
||||
|
||||
def remove_tree_func(x):
|
||||
if not x.children:
|
||||
return remove_func(x)
|
||||
s=calc_tree_size(x)
|
||||
return min(s,discount_slope*(s-bar_size)+bar_size)
|
||||
|
||||
|
||||
def insert_func(x):
|
||||
return insert_cost[x.label.split("_")[0]]
|
||||
def insert_tree_func(x):
|
||||
return remove_tree_func(x)
|
||||
|
||||
|
||||
|
||||
def calc_tree_size(node):
|
||||
"""
|
||||
Calculate the size of a subtree based on its total insertion cost.
|
||||
The function computes the size of a subtree by summing up the insertion
|
||||
costs of the current node and all its descendant nodes. If the subtree
|
||||
size has already been calculated and stored in `node.subtree_size`, it
|
||||
returns the cached value to avoid redundant computation.
|
||||
Args:
|
||||
node (Node): The root node of the subtree for which the size is to
|
||||
be calculated
|
||||
Returns:
|
||||
int: The total size of the subtree, calculated as the sum of the
|
||||
insertion costs of the current node and all its descendants.
|
||||
Notes:
|
||||
- The `insert_cost` dictionary is assumed to be globally defined
|
||||
and maps node labels to their respective insertion costs.
|
||||
- The function modifies the `subtree_size` attribute of the input
|
||||
node to store the calculated subtree size for future use.
|
||||
"""
|
||||
"""The size of a subtree equals to its total insertion cost"""
|
||||
|
||||
total = insert_cost[node.label.split("_")[0]]
|
||||
|
||||
if node.children and node.subtree_size !=0:
|
||||
|
||||
return node.subtree_size
|
||||
|
||||
for child in node.children:
|
||||
total += calc_tree_size(child)
|
||||
|
||||
node.subtree_size=total
|
||||
|
||||
return total
|
||||
"""
|
||||
Scoring function from relative distance
|
||||
"""
|
||||
def score_calc(tree_dist,tree_size):
|
||||
|
||||
if tree_dist==0.:
|
||||
return 100
|
||||
return max(0,100*discount_slope-100*tree_dist/tree_size)
|
||||
|
||||
|
||||
|
||||
|
||||
@timeout_decorator.timeout(30, timeout_exception=TimeoutError)
|
||||
def simplify_with_timeout(expr):
|
||||
return simplify(expr)
|
||||
def time_simplify(expr):
|
||||
try:
|
||||
result=simplify_with_timeout(expr)
|
||||
return result
|
||||
except TimeoutError:
|
||||
return expr
|
||||
|
||||
@timeout_decorator.timeout(10, timeout_exception=TimeoutError)
|
||||
def equal_with_timeout(expr1,expr2):
|
||||
return expr1.equals(expr2)
|
||||
def time_equal(expr1,expr2):
|
||||
try:
|
||||
result=equal_with_timeout(expr1,expr2)
|
||||
return result
|
||||
except TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
def sympy_to_tree(expr):
|
||||
"""
|
||||
Convert a SymPy expression into a tree structure.
|
||||
This function takes a SymPy expression and recursively converts it into a tree
|
||||
representation using `TreeNode` objects. Each node in the tree is labeled based
|
||||
on the type of the SymPy expression (e.g., number, symbol, operator, or function),
|
||||
and its children represent the arguments of the expression.
|
||||
Args:
|
||||
expr (sympy.Basic): The SymPy expression to be converted.
|
||||
Returns:
|
||||
TreeNode: The root node of the tree representation of the SymPy expression.
|
||||
Raises:
|
||||
ValueError: If the SymPy expression contains an unsupported type.
|
||||
Supported Types:
|
||||
- Numbers: Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity
|
||||
- Symbols: Symbol
|
||||
- Binary Operators: Add, Mul, Pow
|
||||
- Functions: Any subclass of `sympy.Function`
|
||||
Example:
|
||||
>>> from sympy import symbols, sin, pi
|
||||
>>> x, y = symbols('x y')
|
||||
>>> expr = x + y * sin(pi)
|
||||
>>> tree = sympy_to_tree(expr)
|
||||
>>> print(tree)
|
||||
"""
|
||||
#print(expr)
|
||||
|
||||
"""Convert the sympy expression to a tree"""
|
||||
# Symbols and constants
|
||||
if isinstance(expr, (Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity)):
|
||||
return TreeNode(label="number_"+str(expr), children=[])
|
||||
elif isinstance(expr, (Symbol,)):
|
||||
|
||||
return TreeNode(label="symbol_"+str(expr),children=[])
|
||||
|
||||
|
||||
# Binary operators
|
||||
elif isinstance(expr, (Add, Mul, Pow)):
|
||||
|
||||
op_name = type(expr).__name__
|
||||
children = [sympy_to_tree(arg) for arg in expr.args]
|
||||
return TreeNode(label="operator_"+op_name, children=children)
|
||||
|
||||
|
||||
elif isinstance(expr, (Function)):
|
||||
# Functions
|
||||
|
||||
func_name = expr.func.__name__
|
||||
children = [sympy_to_tree(arg) for arg in expr.args]
|
||||
return TreeNode(label="function_"+func_name, children=children)
|
||||
|
||||
else:
|
||||
#print(expr)
|
||||
print(f"Unsupported Sympy type: {type(expr).__name__}, Expression: {expr}")
|
||||
raise ValueError(f"Unsupported SymPy type: {type(expr)}")
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self, label, children=None,node_type='other'):
|
||||
self.label = label
|
||||
self.children = children if children is not None else []
|
||||
self.node_type=node_type
|
||||
self.subtree_size=0
|
||||
def get_children(self):
|
||||
return self.children
|
||||
|
||||
def __str__(self):
|
||||
return self.label
|
||||
|
||||
|
||||
|
||||
|
||||
def print_tree(node, indent=0):
|
||||
"""Print a tree structure"""
|
||||
print(' ' * indent + f'└─ {node.label}')
|
||||
for child in node.children:
|
||||
print_tree(child, indent + 1)
|
||||
|
||||
|
||||
|
||||
import timeout_decorator
|
||||
|
||||
class LaTeXError(Exception):
|
||||
def __init__(self, message="LaTeXError"):
|
||||
super().__init__(message)
|
||||
class SymPyError(Exception):
|
||||
def __init__(self, message="SymPyError"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class TreeError(Exception):
|
||||
def __init__(self, message="TreeError"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DistError(Exception):
|
||||
def __init__(self, message="DistanceError"):
|
||||
super().__init__(message)
|
||||
|
||||
def EED(answer_latex,test_latex,debug_mode=False):
|
||||
"""
|
||||
Computes the similarity score and distance metrics between two LaTeX expressions.
|
||||
This function evaluates the equivalence of two mathematical expressions represented
|
||||
in LaTeX format. It uses symbolic computation and tree-based distance metrics to
|
||||
calculate a similarity score and other related metrics.
|
||||
|
||||
tuple: A tuple containing the following elements:
|
||||
- score (float): The similarity score between the two expressions (0 to 100).
|
||||
- relative_distance (float): The normalized distance between the two expressions.
|
||||
- answer_tree_size (int): The size of the expression tree for the answer.
|
||||
- distance (float): The raw distance between the two expression trees.
|
||||
Notes:
|
||||
- If either input contains unsupported LaTeX constructs (e.g., integrals or sums),
|
||||
the function returns default values indicating failure.
|
||||
- If the test expression is significantly longer than the answer expression,
|
||||
the function assumes they are not equivalent.
|
||||
- The function uses symbolic simplification and tree-based distance metrics to
|
||||
evaluate equivalence.
|
||||
- In case of errors during processing, the function returns default values unless
|
||||
`debug_mode` is enabled, in which case it raises specific exceptions.
|
||||
Exceptions:
|
||||
- LaTeXError: Raised when LaTeX conversion to symbolic expressions fails (if `debug_mode` is True).
|
||||
- SymPyError: Raised when symbolic simplification or tree construction fails (if `debug_mode` is True).
|
||||
- DistError: Raised when distance calculation fails (if `debug_mode` is True).
|
||||
Args:
|
||||
answer_latex: the latex expression of answer expression
|
||||
test_latex: the latex expression of test expression
|
||||
debug_mode: whether it raise errors or just skip it
|
||||
Returns:
|
||||
tuple: A tuple containing the following elements:
|
||||
- score (float): The similarity score between the two expressions (0 to 100).
|
||||
- relative_distance (float): The normalized distance between the two expressions.
|
||||
- answer_tree_size (int): The size of the expression tree for the answer.
|
||||
- distance (float): The raw distance between the two expression trees.
|
||||
"""
|
||||
|
||||
if not test_latex:
|
||||
return 0,-1,-1,-1
|
||||
if '\\int' in test_latex or '\\int' in answer_latex:
|
||||
return 0,-1,-1,-1
|
||||
if '\\sum' in test_latex or '\\sum' in answer_latex:
|
||||
return 0,-1,-1,1
|
||||
if answer_latex==test_latex:
|
||||
return 100,0.0,-1,0
|
||||
if len(test_latex)>3*len(answer_latex):
|
||||
return 0,-1,-1,-1
|
||||
|
||||
try:
|
||||
|
||||
answer_exp=master_convert(answer_latex)
|
||||
test_exp=master_convert(test_latex)
|
||||
except:
|
||||
print(f"Failed to convert input latex to sympy expression,please check it")
|
||||
if debug_mode:
|
||||
raise LaTeXError(f"Fail to convert latex.\n GT:{answer_latex}\n GEN:{test_latex}")
|
||||
return 0,-1,-1,-1
|
||||
|
||||
try:
|
||||
|
||||
answer_exp,rep1=posify(answer_exp)
|
||||
|
||||
answer_exp=time_simplify(answer_exp)
|
||||
|
||||
|
||||
test_exp,rep2=posify(test_exp)
|
||||
test_exp=time_simplify(test_exp)
|
||||
|
||||
|
||||
|
||||
answer_exp=answer_exp.subs(rep1)
|
||||
test_exp=test_exp.subs(rep2)
|
||||
|
||||
zero_exp=time_simplify(expand(answer_exp-test_exp))
|
||||
|
||||
|
||||
if answer_exp==test_exp or zero_exp==0:
|
||||
return 100,0.,0,0
|
||||
|
||||
if time_equal(answer_exp,test_exp):
|
||||
return 100,0.,0,0
|
||||
|
||||
except:
|
||||
print("Something happened during simplification,returning zero")
|
||||
if debug_mode:
|
||||
raise SymPyError(f"Failed to simplify the sympy expression. Expressions: answer_exp={answer_exp}, test_exp={test_exp}")
|
||||
return 0,-1,-1,-1
|
||||
|
||||
try:
|
||||
tree_answer=sympy_to_tree(answer_exp)
|
||||
tree_test=sympy_to_tree(test_exp)
|
||||
|
||||
except:
|
||||
|
||||
print("Failed to build expression tree,returning zero")
|
||||
if debug_mode:
|
||||
raise SymPyError(f"Failed to build the sympy expression tree.\n GT:{answer_exp}\n GEN:{test_exp}")
|
||||
return 0,-1,-1,-1
|
||||
|
||||
distance=ext_distance(
|
||||
tree_test,
|
||||
tree_answer,
|
||||
get_children=lambda x:x.get_children(),
|
||||
single_insert_cost=insert_func,
|
||||
insert_cost=insert_tree_func,
|
||||
single_remove_cost=remove_func,
|
||||
remove_cost=remove_tree_func,
|
||||
update_cost=update_func)
|
||||
try:
|
||||
|
||||
|
||||
distance=ext_distance(
|
||||
tree_test,
|
||||
tree_answer,
|
||||
get_children=lambda x:x.get_children(),
|
||||
single_insert_cost=insert_func,
|
||||
insert_cost=insert_tree_func,
|
||||
single_remove_cost=remove_func,
|
||||
remove_cost=remove_tree_func,
|
||||
update_cost=update_func
|
||||
)
|
||||
except:
|
||||
print("Failed to calculate distance")
|
||||
if debug_mode:
|
||||
raise DistError(f"Failed to calculate the distance between trees.\n GT:{answer_latex}\n GEN:{test_latex}")
|
||||
return 0,-1,calc_tree_size(tree_answer),-1
|
||||
tree_size=calc_tree_size(tree_answer)
|
||||
distance_number=distance
|
||||
|
||||
rel_distance=distance/tree_size
|
||||
|
||||
score=score_calc(distance_number,tree_size)
|
||||
|
||||
return score,rel_distance,tree_size,distance_number
|
160
opencompass/configs/datasets/PHYBench/extended_zss.py
Normal file
160
opencompass/configs/datasets/PHYBench/extended_zss.py
Normal file
@ -0,0 +1,160 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#Original Authors: Tim Henderson and Steve Johnson
|
||||
#Email: tim.tadh@gmail.com, steve@steveasleep.com
|
||||
#For licensing see the LICENSE file in the top level directory.
|
||||
|
||||
# This is a modified version of zss package.
|
||||
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
from numpy import zeros,ones
|
||||
|
||||
class Node(object):
|
||||
|
||||
|
||||
def __init__(self, label, children=None):
|
||||
self.label = label
|
||||
self.children = children or list()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_children(node):
|
||||
return node.children
|
||||
|
||||
@staticmethod
|
||||
def get_label(node):
|
||||
return node.label
|
||||
|
||||
def addkid(self, node, before=False):
|
||||
|
||||
if before: self.children.insert(0, node)
|
||||
else: self.children.append(node)
|
||||
return self
|
||||
|
||||
def get(self, label):
|
||||
|
||||
if self.label == label: return self
|
||||
for c in self.children:
|
||||
if label in c: return c.get(label)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class AnnotatedTree(object):
|
||||
|
||||
def __init__(self, root, get_children):
|
||||
self.get_children = get_children
|
||||
|
||||
self.root = root
|
||||
self.nodes = list() # a post-order enumeration of the nodes in the tree
|
||||
self.ids = list() # a matching list of ids
|
||||
self.lmds = list() # left most descendents of each nodes
|
||||
self.keyroots = None
|
||||
# the keyroots in the original paper
|
||||
|
||||
|
||||
stack = list()
|
||||
pstack = list()
|
||||
stack.append((root, collections.deque()))
|
||||
j = 0
|
||||
while len(stack) > 0:
|
||||
n, anc = stack.pop()
|
||||
nid = j
|
||||
for c in self.get_children(n):
|
||||
a = collections.deque(anc)
|
||||
a.appendleft(nid)
|
||||
stack.append((c, a))
|
||||
pstack.append(((n, nid), anc))
|
||||
j += 1
|
||||
lmds = dict()
|
||||
keyroots = dict()
|
||||
i = 0
|
||||
while len(pstack) > 0:
|
||||
(n, nid), anc = pstack.pop()
|
||||
self.nodes.append(n)
|
||||
self.ids.append(nid)
|
||||
if not self.get_children(n):
|
||||
lmd = i
|
||||
for a in anc:
|
||||
if a not in lmds: lmds[a] = i
|
||||
else: break
|
||||
else:
|
||||
try: lmd = lmds[nid]
|
||||
except:
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
self.lmds.append(lmd)
|
||||
keyroots[lmd] = i
|
||||
i += 1
|
||||
self.keyroots = sorted(keyroots.values())
|
||||
|
||||
|
||||
def ext_distance(A, B, get_children, single_insert_cost,insert_cost,single_remove_cost, remove_cost, update_cost):
|
||||
'''Computes the extended tree edit distance between trees A and B with extended-zss algorithm
|
||||
Args:
|
||||
A(Node): Root node of tree 1
|
||||
B(Node): Root node of tree 2
|
||||
get_children(Func): the get_children method of tree
|
||||
single_insert_cost(Func): cost of inserting single node
|
||||
insert_cost(Func): cost of inserting a subtree
|
||||
update_cost(Func): cost of updating A to B
|
||||
|
||||
|
||||
Return:
|
||||
Distance(float):the tree editing distance
|
||||
'''
|
||||
A, B = AnnotatedTree(A, get_children), AnnotatedTree(B, get_children)
|
||||
size_a = len(A.nodes)
|
||||
size_b = len(B.nodes)
|
||||
treedists = zeros((size_a, size_b), float)
|
||||
fd=1000*ones((size_a+1,size_b+1),float)
|
||||
operations = [[[] for _ in range(size_b)] for _ in range(size_a)]
|
||||
|
||||
|
||||
def treedist(x, y):
|
||||
Al = A.lmds
|
||||
Bl = B.lmds
|
||||
An = A.nodes
|
||||
Bn = B.nodes
|
||||
|
||||
m = size_a
|
||||
n = size_b
|
||||
|
||||
fd[Al[x]][Bl[y]]=0
|
||||
for i in range(Al[x], x+1):
|
||||
node = An[i]
|
||||
fd[i+1][Bl[y]] = fd[Al[i]][Bl[y]] + remove_cost(node)
|
||||
|
||||
for j in range(Bl[y], y+1):
|
||||
node = Bn[j]
|
||||
|
||||
fd[Al[x]][j+1] = fd[Al[x]][Bl[j]] + insert_cost(node)
|
||||
|
||||
for i in range(Al[x], x+1):
|
||||
for j in range(Bl[y], y+1):
|
||||
|
||||
node1 = An[i]
|
||||
node2 = Bn[j]
|
||||
costs = [fd[i][j+1] + single_remove_cost(node1),
|
||||
fd[i+1][j] + single_insert_cost(node2),
|
||||
fd[Al[i]][j+1]+ remove_cost(node1),
|
||||
fd[i+1][Bl[j]]+ insert_cost(node2)]
|
||||
m=min(costs)
|
||||
|
||||
if Al[x] == Al[i] and Bl[y] == Bl[j]:
|
||||
treedists[i][j]=min(m,fd[i][j]+update_cost(node1,node2))
|
||||
fd[i+1][j+1]=treedists[i][j]
|
||||
else:
|
||||
fd[i+1][j+1]=min(m,fd[Al[i]][Bl[j]]+treedists[i][j])
|
||||
|
||||
|
||||
for x in A.keyroots:
|
||||
for y in B.keyroots:
|
||||
treedist(x, y)
|
||||
|
||||
return treedists[-1][-1]
|
||||
|
523
opencompass/configs/datasets/PHYBench/latex_pre_process.py
Normal file
523
opencompass/configs/datasets/PHYBench/latex_pre_process.py
Normal file
@ -0,0 +1,523 @@
|
||||
#This file is used to pre-process input latex expressions
|
||||
#You only need a "master_convert()"
|
||||
from latex2sympy2_extended import *
|
||||
from sympy import simplify
|
||||
|
||||
|
||||
|
||||
def brackets_balanced(s: str) -> bool:
|
||||
"""
|
||||
Check if the brackets in a LaTeX string are balanced
|
||||
Args:
|
||||
s(str): the input string
|
||||
Return:
|
||||
bool: True if the brackets are balanced, False otherwise
|
||||
"""
|
||||
stack = []
|
||||
bracket_pairs = {')': '(', ']': '[', '}': '{'}
|
||||
|
||||
for char in s:
|
||||
if char in bracket_pairs.values():
|
||||
stack.append(char)
|
||||
elif char in bracket_pairs:
|
||||
if not stack or stack[-1] != bracket_pairs[char]:
|
||||
return False
|
||||
stack.pop()
|
||||
return len(stack) == 0
|
||||
|
||||
|
||||
|
||||
def remove_non_ascii(text):
|
||||
return text.encode("ascii", errors="ignore").decode()
|
||||
|
||||
import re
|
||||
def extract_bracket_content(s:str,bracket_position:int) -> str:
|
||||
start_idx=bracket_position
|
||||
|
||||
stack = []
|
||||
content = []
|
||||
escaped = False
|
||||
brace_start=start_idx+1
|
||||
brace_depth = 0
|
||||
for i in range(brace_start, len(s)):
|
||||
char = s[i]
|
||||
if escaped:
|
||||
content.append(char)
|
||||
escaped = False
|
||||
continue
|
||||
if char == '\\':
|
||||
escaped = True
|
||||
content.append(char)
|
||||
continue
|
||||
if char == '{':
|
||||
brace_depth += 1
|
||||
content.append(char)
|
||||
elif char == '}':
|
||||
if brace_depth == 0:
|
||||
return ''.join(content),i
|
||||
brace_depth -= 1
|
||||
content.append(char)
|
||||
else:
|
||||
content.append(char)
|
||||
|
||||
return None,-1
|
||||
def find_first_unescaped_brace(s: str) -> int:
|
||||
escaped = False
|
||||
for i, c in enumerate(s):
|
||||
if c == '\\' and not escaped:
|
||||
escaped = True
|
||||
continue
|
||||
if c == '{' and not escaped:
|
||||
return i
|
||||
escaped = False
|
||||
return -1
|
||||
|
||||
def extract_command(s: str, brace_pos: int) -> str | None:
|
||||
"""extract the command name from a bracket"""
|
||||
i = brace_pos - 1
|
||||
parameter_mode=False
|
||||
while i >= 0:
|
||||
if not parameter_mode and s[i] in ('^','_'):
|
||||
return s[i]
|
||||
if not parameter_mode and not s[i] in (' ','\t',']','['):
|
||||
break
|
||||
if s[i]==']':
|
||||
parameter_mode=True
|
||||
if s[i]=='[' and parameter_mode:
|
||||
parameter_mode=False
|
||||
i -= 1
|
||||
|
||||
# Start point
|
||||
if i < 0 or s[i] == '\\':
|
||||
return None
|
||||
|
||||
# Extract command name
|
||||
command_end = i
|
||||
i -= 1
|
||||
while i >= 0 and s[i].isalpha():
|
||||
i -= 1
|
||||
if i<-1 or s[i]!='\\':
|
||||
return None
|
||||
return s[i+1:command_end+1]
|
||||
|
||||
|
||||
def remove_command(s,command,keep_inside=False):
|
||||
def remove_command(s, command, keep_inside=False):
|
||||
"""
|
||||
Removes all occurrences of a specified LaTeX-style command from a string.
|
||||
|
||||
This function searches for a given command in the input string `s` and removes it,
|
||||
along with its associated content enclosed in curly braces `{}`. If `keep_inside`
|
||||
is set to `True`, the content inside the braces is preserved, and only the command
|
||||
itself is removed. The function handles nested braces correctly.
|
||||
|
||||
Args:
|
||||
s (str): The input string from which the command should be removed.
|
||||
command (str): The LaTeX-style command to be removed (e.g., "\\textbf").
|
||||
keep_inside (bool, optional): If `True`, preserves the content inside the braces
|
||||
while removing the command. Defaults to `False`.
|
||||
|
||||
Returns:
|
||||
str: The modified string with the specified command removed.
|
||||
|
||||
Examples:
|
||||
>>> remove_command("This is \\textbf{bold text}.", "\\textbf")
|
||||
'This is bold text.'
|
||||
|
||||
>>> remove_command("This is \\textbf{bold text}.", "\\textbf", keep_inside=True)
|
||||
'This is bold text.'
|
||||
|
||||
>>> remove_command("Nested \\textbf{bold \\textit{italic text}} example.", "\\textbf")
|
||||
'Nested bold \\textit{italic text} example.'
|
||||
"""
|
||||
pos=s.find(command)
|
||||
if pos<0:
|
||||
return s
|
||||
end_index=pos+len(command)
|
||||
level=0
|
||||
escaped=False
|
||||
#print(end_index,s[end_index])
|
||||
if end_index < len(s) and s[end_index] == "{":
|
||||
while end_index<len(s):
|
||||
|
||||
if s[end_index]=='{':
|
||||
level+=1
|
||||
elif s[end_index]=='}':
|
||||
level-=1
|
||||
if level==0:
|
||||
break
|
||||
end_index+=1
|
||||
else:
|
||||
s1="".join([s[0:pos],s[end_index:]])
|
||||
if keep_inside:
|
||||
s1="".join([s[0:pos],s[pos+len(command)+1:end_index],s[end_index+1:]])
|
||||
else:
|
||||
s1="".join([s[0:pos],s[end_index+1:]])
|
||||
|
||||
if command not in s1:
|
||||
return s1
|
||||
else:
|
||||
return remove_command(s1,command,keep_inside)
|
||||
|
||||
import re
|
||||
|
||||
def convert_latex_fractions(latex_str):
|
||||
"""
|
||||
Convert non-standard fraction like \frac\alpha2 to its standard-convertable \frac{\alpha}{2}
|
||||
We suppoort single letter,number or standard form
|
||||
"""
|
||||
pattern = r'\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))'
|
||||
|
||||
def replacer(match):
|
||||
numerator, denominator = match.group(1), match.group(2)
|
||||
wrap_num = f'{{{numerator}}}' if not (numerator.startswith('{') and numerator.endswith('}')) else numerator
|
||||
wrap_den = f'{{{denominator}}}' if not (denominator.startswith('{') and denominator.endswith('}')) else denominator
|
||||
return fr'\frac{wrap_num}{wrap_den}'
|
||||
|
||||
return re.sub(pattern, replacer, latex_str)
|
||||
|
||||
|
||||
def get_first_brace_command(s: str) -> str | None:
|
||||
""" Find the first brace """
|
||||
brace_pos = find_first_unescaped_brace(s)
|
||||
if brace_pos == -1:
|
||||
return None
|
||||
return extract_command(s, brace_pos)
|
||||
def remove_overall_brace(s:str) -> str:
|
||||
"""
|
||||
Remove the overall {xxx} brace
|
||||
"""
|
||||
pos=find_first_unescaped_brace(s)
|
||||
if pos==-1:
|
||||
return s,0
|
||||
command=get_first_brace_command(s)
|
||||
if not command:
|
||||
|
||||
content,final=extract_bracket_content(s,pos)
|
||||
#print(s[final])
|
||||
if final==len(s) or not '}' in s[final+1:]:
|
||||
return content,1
|
||||
return s,0
|
||||
|
||||
|
||||
def exp_frac(s):
|
||||
|
||||
def exp_frac_single(s):
|
||||
position=s.find("^\\frac")+1
|
||||
if position == 0:
|
||||
return s
|
||||
level=0
|
||||
cnt=0
|
||||
idx=position
|
||||
while idx<len(s):
|
||||
if s[idx]=='{':
|
||||
cnt+=1
|
||||
elif s[idx]=='}':
|
||||
cnt-=1
|
||||
if cnt==0:
|
||||
level+=1
|
||||
if level==2:
|
||||
break
|
||||
idx+=1
|
||||
s1="".join([s[0:position],'{',s[position:idx],'}',s[idx:]])
|
||||
return s1
|
||||
s1=exp_frac_single(s)
|
||||
cnt=0
|
||||
while s1 != s and cnt<100:
|
||||
cnt+=1
|
||||
s=s1
|
||||
s1=exp_frac_single(s)
|
||||
return s
|
||||
|
||||
|
||||
|
||||
def find_all(s, sub_str, allow_overlap=True):
|
||||
indexes = []
|
||||
start = 0
|
||||
step = 1 if allow_overlap else len(sub_str)
|
||||
cnt=0
|
||||
while True and cnt<100:
|
||||
pos = s.find(sub_str, start)
|
||||
if pos == -1:
|
||||
break
|
||||
indexes.append(pos)
|
||||
start = pos + step
|
||||
cnt+=1
|
||||
return indexes
|
||||
|
||||
def bar_inside_vec(s):
|
||||
indices=find_all(s,"\\vec{")
|
||||
if not indices:
|
||||
return s
|
||||
for i in range(len(indices)):
|
||||
position=find_all(s,"\\vec{")[i]
|
||||
idx=position+4
|
||||
#print(s[idx])
|
||||
idx2=idx
|
||||
level=0
|
||||
while idx2<len(s):
|
||||
if s[idx2]=='{':
|
||||
level+=1
|
||||
if s[idx2]=='}':
|
||||
level-=1
|
||||
if level==0:
|
||||
break
|
||||
idx2+=1
|
||||
|
||||
s1=s[idx+1:idx2]
|
||||
#print(s1)
|
||||
|
||||
s1=remove_command(s1,"\\bar",keep_inside=True)
|
||||
s2= "".join([s[0:idx+1],s1,s[idx2:]])
|
||||
s=s2
|
||||
return s
|
||||
def vec_lower_idx(input_str):
|
||||
"""
|
||||
in the annoying latex2sympy, error may occur when\ vec{a_{b}},we need\\vec{a_b}
|
||||
Args:
|
||||
input_str (str): Original string
|
||||
|
||||
Return:
|
||||
str(str): Converted
|
||||
"""
|
||||
pattern = r'\\vec\{([^{}]+)_{([^{}]+)}\}'
|
||||
replacement = r'\\vec{\1}_{\2}'
|
||||
return re.sub(pattern, replacement, input_str)
|
||||
def convert_vec_syntax(text):
|
||||
"""
|
||||
Converts LaTeX vector syntax to a standardized form.
|
||||
|
||||
This function processes a given text string and ensures that LaTeX vector
|
||||
notations are consistently formatted. Specifically, it transforms instances
|
||||
of `\vec xxx` into `\vec{xxx}`. The function handles cases where the vector
|
||||
notation is applied to single characters, Greek letters, or LaTeX commands.
|
||||
|
||||
Args:
|
||||
text (str): The input string containing LaTeX code to be processed.
|
||||
|
||||
Returns:
|
||||
str: The processed string with standardized vector syntax.
|
||||
|
||||
Examples:
|
||||
>>> convert_vec_syntax(r"\vec x + \vec\alpha + \vec\Gamma")
|
||||
'\\vec{x} + \\vec{\\alpha} + \\vec{\\Gamma}'
|
||||
"""
|
||||
|
||||
pattern = r'\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)'
|
||||
replacement = r'\\vec{\2}'
|
||||
return re.sub(pattern, replacement, text)
|
||||
|
||||
def remove_outer_braces(tex_str):
|
||||
"""
|
||||
convert {base}_{subscript} to base_{subscript}
|
||||
Example:
|
||||
{a}_{xyz} → a_{xyz}
|
||||
{\theta}_{0} → \theta_{0}
|
||||
"""
|
||||
pattern = r'\{(\\(?:[a-zA-Z]+|.)|[^{}])+\}_\{([^}]+)\}'
|
||||
return re.sub(pattern, r'\1_{\2}', tex_str)
|
||||
|
||||
def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str:
|
||||
"""
|
||||
Extract the content after the last occurrence of specific mathematical comparison or assignment operators.
|
||||
|
||||
:param strip_whitespace: If True, removes leading and trailing whitespace from the extracted content. Defaults to True.
|
||||
(e.g., '=', '\\approx', '\\ge', '\\le', etc.) within the input string `s`. It then extracts
|
||||
and returns the content that follows the operator. If no operator is found, the entire string
|
||||
is returned. Optionally, leading and trailing whitespace can be stripped from the extracted content.
|
||||
|
||||
Args:
|
||||
s (str): The input string to process.
|
||||
strip_whitespace (bool): Whether to strip leading and trailing whitespace from the extracted content. Defaults to True.
|
||||
|
||||
Returns:
|
||||
str: The content after the last matching operator, or the entire string if no operator is found.
|
||||
"""
|
||||
comparison_operators=('=','\\approx','\\ge','\\le','\\geq','\\leq','<','>')
|
||||
|
||||
content=s
|
||||
for sign in comparison_operators:
|
||||
if sign in s:
|
||||
rfind_index = s.rfind(sign)
|
||||
if rfind_index != -1:
|
||||
content = s[rfind_index + 1:]
|
||||
if strip_whitespace:
|
||||
return content.strip()
|
||||
return content
|
||||
|
||||
|
||||
def first_pre_process(s,extrac_box=True):
|
||||
"""
|
||||
Perform the first stage of LaTeX string preprocessing.
|
||||
|
||||
if not brackets_balanced(s):
|
||||
raise ValueError("The input string has unbalanced brackets. Please check the LaTeX expression.")
|
||||
equality or comparison operator.
|
||||
|
||||
Args:
|
||||
s (str): The input LaTeX string to preprocess.
|
||||
extrac_box (bool): If True, extracts the content inside a '\\boxed' command. Defaults to True.
|
||||
|
||||
Returns:
|
||||
str: The preprocessed LaTeX string.
|
||||
"""
|
||||
#s=remove_non_ascii(s)
|
||||
s=s.replace('\\{','(')
|
||||
s=s.replace('\\}',')')
|
||||
if not brackets_balanced(s):
|
||||
return s
|
||||
if extrac_box:
|
||||
boxed_content=remove_command(s,'\\boxed',keep_inside=True)
|
||||
else:
|
||||
boxed_content=s
|
||||
exist_overall_brace=True
|
||||
cnt=0
|
||||
while exist_overall_brace and cnt<10:
|
||||
boxed_content,exist_overall_brace=remove_overall_brace(boxed_content)
|
||||
cnt+=1
|
||||
|
||||
if '\\quad' in boxed_content:
|
||||
boxed_content = boxed_content.split('\\quad')[0]
|
||||
|
||||
last_equal_content=extract_last_equal_content(boxed_content)
|
||||
|
||||
exist_overall_brace=True
|
||||
cnt=0
|
||||
while exist_overall_brace and cnt<10:
|
||||
last_equal_content,exist_overall_brace=remove_overall_brace(last_equal_content)
|
||||
cnt+=1
|
||||
return last_equal_content
|
||||
def second_pre_process(s):
|
||||
"""
|
||||
Perform the second stage of LaTeX string preprocessing.
|
||||
|
||||
This function removes or modifies specific LaTeX commands and content to standardize
|
||||
the input string for further processing. It handles commands like '\\text', '\\mathbf',
|
||||
and '\\mathrm', removes unnecessary content, and applies transformations such as
|
||||
converting fractions and vector syntax.
|
||||
|
||||
Args:
|
||||
s (str): The input LaTeX string to preprocess.
|
||||
|
||||
Returns:
|
||||
str: The preprocessed LaTeX string.
|
||||
"""
|
||||
|
||||
|
||||
kill_commands=[
|
||||
'\\begin',
|
||||
'\\end'
|
||||
]
|
||||
remove_commands=[
|
||||
'\\text',
|
||||
'\\mathbf',
|
||||
'\\mathrm',
|
||||
'\\pmb',
|
||||
'\\hat',
|
||||
'\\overline',
|
||||
'\\boldsymbol',
|
||||
]
|
||||
|
||||
|
||||
remove_content=[
|
||||
'\\,','$',',','`','latex','\\left','\\right','\\text','\\mathrm','\\Bigr','\\Bigl','\n','\\]','\\[',
|
||||
'\\Big','\\bigl','\\bigr','\\biggl','\\biggr','\\displaystyle','\\boldsymbol','\\infty'
|
||||
]
|
||||
replace_content=[
|
||||
('\\operatorname{asin}','\\asin'),
|
||||
('\\operatorname{sech}','\\sech'),
|
||||
('\\operatorname{acos}','\\acos'),
|
||||
('\\operatorname{sinh}','\\sinh'),
|
||||
('\\dfrac','\\frac'),
|
||||
('\\tfrac','\\frac'),
|
||||
('\\Exp','\\exp'),
|
||||
('\\times','\\bar{times}'),
|
||||
('\\partial','\\bar{partial}'),
|
||||
('\\perp','\\bar{perp}'),
|
||||
('\\epsilon','\\varepsilon'),
|
||||
('\\varOmega','\\Omega'),
|
||||
('I','\\bar{I}'),
|
||||
('_e','_{e}'),
|
||||
('e_','\\bar{e}_'),
|
||||
('E_','\\bar{E}_'),
|
||||
('\\pm','+'),
|
||||
('\\mp','-'),
|
||||
('{+}','{p}'),
|
||||
("{-}",'{m}'),
|
||||
("_+",'_p'),
|
||||
('_-',"_m")
|
||||
]
|
||||
for command in kill_commands:
|
||||
s=remove_command(s,command,keep_inside=False)
|
||||
for command in remove_commands:
|
||||
s=remove_command(s,command,keep_inside=True)
|
||||
for content in remove_content:
|
||||
s=s.replace(content,'')
|
||||
for content in replace_content:
|
||||
s=s.replace(content[0],content[1])
|
||||
s=convert_latex_fractions(s)
|
||||
#print(s)
|
||||
s=bar_inside_vec(s)
|
||||
s=vec_lower_idx(s)
|
||||
s=convert_vec_syntax(s)
|
||||
s=exp_frac(s)
|
||||
#s=remove_outer_braces(s)
|
||||
if s and s[-1] == '.':
|
||||
return s[:-1]
|
||||
return s
|
||||
|
||||
|
||||
class MyConfig:
|
||||
|
||||
interpret_as_mixed_fractions: bool = False
|
||||
interpret_simple_eq_as_assignment: bool = False
|
||||
interpret_contains_as_eq: bool = True
|
||||
lowercase_symbols: bool = False
|
||||
"""
|
||||
Args:
|
||||
interpret_as_mixed_fractions (bool): Whether to interpert 2 \frac{1}{2} as 2/2 or 2 + 1/2
|
||||
interpret_simple_eq_as_assignment (bool): Whether to interpret simple equations as assignments k=1 -> 1
|
||||
interpret_contains_as_eq (bool): Whether to interpret contains as equality x \\in {1,2,3} -> x = {1,2,3}
|
||||
lowercase_symbols (bool): Whether to lowercase all symbols
|
||||
"""
|
||||
class MyNormalization:
|
||||
"""Configuration for latex normalization.
|
||||
|
||||
Each field controls a group of related normalizations:
|
||||
- basic_latex: Basic latex command replacements (mathrm, displaystyle, etc.)
|
||||
- units: Remove units and their variations
|
||||
- malformed_operators: Fix malformed operators (sqrt, frac, etc.)
|
||||
- nits: Small formatting fixes (spaces, dots, etc.)
|
||||
- boxed: Extract content from boxed environments
|
||||
- equations: Handle equation splitting and approximations (deprecated)
|
||||
"""
|
||||
basic_latex: bool = True
|
||||
units: bool = False
|
||||
malformed_operators: bool = True
|
||||
nits: bool = True
|
||||
boxed = "all"
|
||||
equations: bool = False
|
||||
|
||||
def master_convert(s):
|
||||
"""
|
||||
The only function needed to convert a LaTeX string into a SymPy expression.
|
||||
|
||||
Args:
|
||||
s (str): The input LaTeX string. It should be a valid LaTeX mathematical expression,
|
||||
such as equations, fractions, or symbols, and must have balanced brackets.
|
||||
|
||||
Returns:
|
||||
Sym (Sympy Expression): A SymPy expression representing the mathematical content of the input string.
|
||||
The returned object can be used for symbolic computation, simplification,
|
||||
or evaluation using SymPy's functionality.
|
||||
|
||||
Example:
|
||||
>>> master_convert("\\frac{1}{2} + x")
|
||||
1/2 + x
|
||||
"""
|
||||
preprocessed_stage1=first_pre_process(s)
|
||||
|
||||
preprocessed_stage2=second_pre_process(preprocessed_stage1)
|
||||
|
||||
Sym=latex2sympy(preprocessed_stage2,normalization_config=MyNormalization(),conversion_config=MyConfig())
|
||||
return Sym
|
46
opencompass/configs/datasets/PHYBench/phybench_gen.py
Normal file
46
opencompass/configs/datasets/PHYBench/phybench_gen.py
Normal file
@ -0,0 +1,46 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from EED import EED
|
||||
from opencompass.openicl import BaseEvaluator
|
||||
from opencompass.registry import ICL_EVALUATOR
|
||||
|
||||
|
||||
@ICL_EVALUATOR.register_module()
|
||||
class MathEEDEvaluator(BaseEvaluator):
|
||||
def score(self, predictions, references):
|
||||
scores = []
|
||||
for pred, ref in zip(predictions, references):
|
||||
score, _, _, _ = EED(ref, pred)
|
||||
scores.append(score)
|
||||
return {'accuracy': sum(scores) / len(scores)}
|
||||
|
||||
|
||||
|
||||
from opencompass.datasets import PhyBenchDataset
|
||||
|
||||
phybench_datasets = [
|
||||
dict(
|
||||
abbr='phybench-eed',
|
||||
type=PhyBenchDataset,
|
||||
path='opencompass/PHYBench',
|
||||
reader_cfg=dict(
|
||||
input_columns=['input'],
|
||||
output_column='target',
|
||||
),
|
||||
infer_cfg=dict(
|
||||
prompt_template=dict(
|
||||
type='plain',
|
||||
template='Solve the following physics problem and return only the final result as a clean LaTeX expression. No explanation. No text.\n\nQuestion: {{input}}\nAnswer: '
|
||||
),
|
||||
retriever=dict(type='zero_shot'),
|
||||
inferencer=dict(type='gen', max_out_len=512)
|
||||
),
|
||||
eval_cfg=dict(
|
||||
evaluator=dict(type=MathEEDEvaluator)
|
||||
)
|
||||
)
|
||||
]
|
@ -170,3 +170,4 @@ from .xcopa import * # noqa: F401, F403
|
||||
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
|
||||
from .xlsum import * # noqa: F401, F403
|
||||
from .xsum import * # noqa: F401, F403
|
||||
from .phybench import *
|
26
opencompass/datasets/phybench.py
Normal file
26
opencompass/datasets/phybench.py
Normal file
@ -0,0 +1,26 @@
|
||||
import os.path as osp
|
||||
import json
|
||||
from datasets import Dataset
|
||||
from opencompass.datasets.base import BaseDataset
|
||||
from opencompass.registry import LOAD_DATASET
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
@LOAD_DATASET.register_module()
|
||||
class PhyBenchDataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def load(path: str, name: str = None, **kwargs):
|
||||
path = get_data_path(path)
|
||||
|
||||
file_path = osp.join(path, 'PHYBench-fullques_v1.json')
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
raw_data = json.load(f)
|
||||
|
||||
inputs = [item['content'] for item in raw_data]
|
||||
targets = [item['answer'] for item in raw_data]
|
||||
|
||||
return Dataset.from_dict({
|
||||
'input': inputs,
|
||||
'target': targets
|
||||
})
|
Loading…
Reference in New Issue
Block a user