diff --git a/opencompass/datasets/PHYBench/EED/EED.py b/opencompass/datasets/PHYBench/EED/EED.py new file mode 100644 index 00000000..7e811a56 --- /dev/null +++ b/opencompass/datasets/PHYBench/EED/EED.py @@ -0,0 +1,636 @@ +import re + +import timeout_decorator +from latex2sympy2_extended import latex2sympy +from sympy import (Add, Float, Function, Integer, Mul, Pow, Rational, Symbol, + expand, simplify) +from sympy.core.numbers import Exp1, Infinity, NegativeInfinity, Pi +from sympy.simplify import posify + +from opencompass.datasets.PHYBench.EED.extended_zss import ext_distance + + +def brackets_balanced(s: str) -> bool: + stack = [] + bracket_pairs = {')': '(', ']': '[', '}': '{'} + + for char in s: + if char in bracket_pairs.values(): + stack.append(char) + elif char in bracket_pairs: + if not stack or stack[-1] != bracket_pairs[char]: + return False + stack.pop() + return len(stack) == 0 + + +def remove_non_ascii(text): + return text.encode('ascii', errors='ignore').decode() + + +def extract_bracket_content(s: str, bracket_position: int) -> str: + start_idx = bracket_position + + content = [] + escaped = False + brace_start = start_idx + 1 + brace_depth = 0 + for i in range(brace_start, len(s)): + char = s[i] + if escaped: + content.append(char) + escaped = False + continue + if char == '\\': + escaped = True + content.append(char) + continue + if char == '{': + brace_depth += 1 + content.append(char) + elif char == '}': + if brace_depth == 0: + return ''.join(content), i + brace_depth -= 1 + content.append(char) + else: + content.append(char) + + return None, -1 + + +def find_first_unescaped_brace(s: str) -> int: + escaped = False + for i, c in enumerate(s): + if c == '\\' and not escaped: + escaped = True + continue + if c == '{' and not escaped: + return i + escaped = False + return -1 + + +def extract_command(s: str, brace_pos: int) -> str | None: + """extract the command name from a bracket.""" + i = brace_pos - 1 + parameter_mode = False + while i >= 0: + if not parameter_mode and s[i] in ('^', '_'): + return s[i] + if not parameter_mode and not s[i] in (' ', '\t', ']', '['): + break + if s[i] == ']': + parameter_mode = True + if s[i] == '[' and parameter_mode: + parameter_mode = False + i -= 1 + + # Start point + if i < 0 or s[i] == '\\': + return None + + # Extract command name + command_end = i + i -= 1 + while i >= 0 and s[i].isalpha(): + i -= 1 + if i < -1 or s[i] != '\\': + return None + return s[i + 1:command_end + 1] + + +def remove_command(s, command, keep_inside=False): + pos = s.find(command) + if pos < 0: + return s + end_index = pos + len(command) + level = 0 + if end_index < len(s) and s[end_index] == '{': + while end_index < len(s): + if s[end_index] == '{': + level += 1 + elif s[end_index] == '}': + level -= 1 + if level == 0: + break + end_index += 1 + else: + s1 = ''.join([s[0:pos], s[end_index:]]) + if keep_inside: + s1 = ''.join( + [s[0:pos], s[pos + len(command) + 1:end_index], s[end_index + 1:]]) + else: + s1 = ''.join([s[0:pos], s[end_index + 1:]]) + + if command not in s1: + return s1 + else: + return remove_command(s1, command, keep_inside) + + +def convert_latex_fractions(latex_str): + """Convert non-standard fraction like \frac\alpha2 to its standard- + convertable \frac{\alpha}{2}. + + We support single letter, number or standard form. + """ + pattern = (r'\\frac((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))' + r'((?:\\[a-zA-Z]+|\d|[a-zA-Z]|{[^{}]*}))') + + def replacer(match): + numerator, denominator = match.group(1), match.group(2) + wrap_num = f'{{{numerator}}}' if not ( + numerator.startswith('{') + and numerator.endswith('}')) else numerator + wrap_den = f'{{{denominator}}}' if not ( + denominator.startswith('{') + and denominator.endswith('}')) else denominator + return fr'\frac{wrap_num}{wrap_den}' + + return re.sub(pattern, replacer, latex_str) + + +def get_first_brace_command(s: str) -> str | None: + """Find the first brace.""" + brace_pos = find_first_unescaped_brace(s) + if brace_pos == -1: + return None + return extract_command(s, brace_pos) + + +def remove_overall_brace(s: str) -> str: + """Remove the overall {xxx} brace.""" + pos = find_first_unescaped_brace(s) + if pos == -1: + return s, 0 + command = get_first_brace_command(s) + if not command: + content, final = extract_bracket_content(s, pos) + if final == len(s) or '}' not in s[final + 1:]: + return content, 1 + return s, 0 + + +def exp_frac(s): + + def exp_frac_single(s): + position = s.find('^\\frac') + 1 + if position == 0: + return s + level = 0 + cnt = 0 + idx = position + while idx < len(s): + if s[idx] == '{': + cnt += 1 + elif s[idx] == '}': + cnt -= 1 + if cnt == 0: + level += 1 + if level == 2: + break + idx += 1 + s1 = ''.join([s[0:position], '{', s[position:idx], '}', s[idx:]]) + return s1 + + s1 = exp_frac_single(s) + cnt = 0 + while s1 != s and cnt < 100: + cnt += 1 + s = s1 + s1 = exp_frac_single(s) + return s + + +def find_all(s, sub_str, allow_overlap=True): + indexes = [] + start = 0 + step = 1 if allow_overlap else len(sub_str) + cnt = 0 + while True and cnt < 100: + pos = s.find(sub_str, start) + if pos == -1: + break + indexes.append(pos) + start = pos + step + cnt += 1 + return indexes + + +def bar_inside_vec(s): + indices = find_all(s, '\\vec{') + if not indices: + return s + for i in range(len(indices)): + position = find_all(s, '\\vec{')[i] + idx = position + 4 + idx2 = idx + level = 0 + while idx2 < len(s): + if s[idx2] == '{': + level += 1 + if s[idx2] == '}': + level -= 1 + if level == 0: + break + idx2 += 1 + + s1 = s[idx + 1:idx2] + + s1 = remove_command(s1, '\\bar', keep_inside=True) + s2 = ''.join([s[0:idx + 1], s1, s[idx2:]]) + s = s2 + return s + + +def vec_lower_idx(input_str): + pattern = r'\\vec\{([^{}]+)_{([^{}]+)}\}' + replacement = r'\\vec{\1}_{\2}' + return re.sub(pattern, replacement, input_str) + + +def convert_vec_syntax(text): + pattern = r'\\vec(\s*)(\\?[a-zA-Zα-ωΑ-Ω]+)' + replacement = r'\\vec{\2}' + return re.sub(pattern, replacement, text) + + +def remove_outer_braces(tex_str): + pattern = r'\{(\\(?:[a-zA-Z]+|.)|[^{}])+\}_\{([^}]+)\}' + return re.sub(pattern, r'\1_{\2}', tex_str) + + +def extract_last_equal_content(s: str, strip_whitespace: bool = True) -> str: + + comparison_operators = ('=', '\\approx', '\\ge', '\\le', '\\geq', '\\leq', + '<', '>') + + content = s + for sign in comparison_operators: + if sign in s: + rfind_index = s.rfind(sign) + if rfind_index != -1: + content = s[rfind_index + 1:] + if strip_whitespace: + return content.strip() + return content + + +def first_pre_process(s, extrac_box=True): + s = s.replace('\\{', '(') + s = s.replace('\\}', ')') + if not brackets_balanced(s): + return s + if extrac_box: + boxed_content = remove_command(s, '\\boxed', keep_inside=True) + else: + boxed_content = s + exist_overall_brace = True + cnt = 0 + while exist_overall_brace and cnt < 10: + boxed_content, exist_overall_brace = remove_overall_brace( + boxed_content) + cnt += 1 + + if '\\quad' in boxed_content: + boxed_content = boxed_content.split('\\quad')[0] + + last_equal_content = extract_last_equal_content(boxed_content) + + exist_overall_brace = True + cnt = 0 + while exist_overall_brace and cnt < 10: + last_equal_content, exist_overall_brace = remove_overall_brace( + last_equal_content) + cnt += 1 + return last_equal_content + + +def second_pre_process(s): + kill_commands = ['\\begin', '\\end'] + remove_commands = [ + '\\text', + '\\mathbf', + '\\mathrm', + '\\pmb', + '\\hat', + '\\overline', + '\\boldsymbol', + ] + + remove_content = [ + '\\,', '$', ',', '`', 'latex', '\\left', '\\right', '\\text', + '\\mathrm', '\\Bigr', '\\Bigl', '\n', '\\]', '\\[', '\\Big', '\\bigl', + '\\bigr', '\\biggl', '\\biggr', '\\displaystyle', '\\boldsymbol', + '\\infty' + ] + replace_content = [ + ('\\operatorname{asin}', '\\asin'), ('\\operatorname{sech}', '\\sech'), + ('\\operatorname{acos}', '\\acos'), ('\\operatorname{sinh}', '\\sinh'), + ('\\dfrac', '\\frac'), ('\\tfrac', '\\frac'), ('\\Exp', '\\exp'), + ('\\times', '\\bar{times}'), ('\\partial', '\\bar{partial}'), + ('\\perp', '\\bar{perp}'), ('\\epsilon', '\\varepsilon'), + ('\\varOmega', '\\Omega'), ('I', '\\bar{I}'), ('_e', '_{e}'), + ('e_', '\\bar{e}_'), ('E_', '\\bar{E}_'), ('\\pm', '+'), ('\\mp', '-'), + ('{+}', '{p}'), ('{-}', '{m}'), ('_+', '_p'), ('_-', '_m') + ] + for command in kill_commands: + s = remove_command(s, command, keep_inside=False) + for command in remove_commands: + s = remove_command(s, command, keep_inside=True) + for content in remove_content: + s = s.replace(content, '') + for content in replace_content: + s = s.replace(content[0], content[1]) + s = convert_latex_fractions(s) + s = bar_inside_vec(s) + s = vec_lower_idx(s) + s = convert_vec_syntax(s) + s = exp_frac(s) + if s and s[-1] == '.': + return s[:-1] + return s + + +class MyConfig: + + interpret_as_mixed_fractions: bool = False + interpret_simple_eq_as_assignment: bool = False + interpret_contains_as_eq: bool = True + lowercase_symbols: bool = False + + +class MyNormalization: + basic_latex: bool = True + units: bool = False + malformed_operators: bool = True + nits: bool = True + boxed = 'all' + equations: bool = False + + +def master_convert(s): + preprocessed_stage1 = first_pre_process(s) + + preprocessed_stage2 = second_pre_process(preprocessed_stage1) + + Sym = latex2sympy(preprocessed_stage2, + normalization_config=MyNormalization(), + conversion_config=MyConfig()) + return Sym + + +# The costs can be modified if you think their values are different +insert_cost = {'number': 1, 'symbol': 1, 'operator': 1, 'function': 1} +delete_cost = {'number': 1, 'symbol': 1, 'operator': 1, 'function': 1} +update_cost = {'number': 1, 'symbol': 1, 'operator': 1, 'function': 1} + +change_type_cost = 1 + +bar_size = 5 +discount_slope = 0.6 + +simplify_time_limit = 30 +equals_time_limit = 10 + + +def update_func(x, y): + + if x.label == y.label: + return 0 + + elif x.label.split('_')[0] == y.label.split('_')[0]: + return update_cost[x.label.split('_')[0]] + return change_type_cost + + +def remove_func(x): + return delete_cost[x.label.split('_')[0]] + + +def remove_tree_func(x): + if not x.children: + return remove_func(x) + s = calc_tree_size(x) + return min(s, discount_slope * (s - bar_size) + bar_size) + + +def insert_func(x): + return insert_cost[x.label.split('_')[0]] + + +def insert_tree_func(x): + return remove_tree_func(x) + + +def calc_tree_size(node): + + total = insert_cost[node.label.split('_')[0]] + + if node.children and node.subtree_size != 0: + + return node.subtree_size + + for child in node.children: + total += calc_tree_size(child) + + node.subtree_size = total + + return total + + +""" +Scoring function from relative distance +""" + + +def score_calc(tree_dist, tree_size): + + if tree_dist == 0.: + return 100 + return max(0, 100 * discount_slope - 100 * tree_dist / tree_size) + + +@timeout_decorator.timeout(30, timeout_exception=TimeoutError) +def simplify_with_timeout(expr): + return simplify(expr) + + +def time_simplify(expr): + try: + result = simplify_with_timeout(expr) + return result + except TimeoutError: + return expr + + +@timeout_decorator.timeout(10, timeout_exception=TimeoutError) +def equal_with_timeout(expr1, expr2): + return expr1.equals(expr2) + + +def time_equal(expr1, expr2): + try: + result = equal_with_timeout(expr1, expr2) + return result + except TimeoutError: + return False + + +def sympy_to_tree(expr): + """Convert the sympy expression to a tree.""" + # Symbols and constants + if_list = [Integer, Pi, Exp1, Float, Rational, Infinity, NegativeInfinity] + for i in if_list: + if isinstance(expr, i): + return TreeNode(label='number_' + str(expr), children=[]) + if isinstance(expr, (Symbol, )): + + return TreeNode(label='symbol_' + str(expr), children=[]) + + # Binary operators + elif isinstance(expr, (Add, Mul, Pow)): + + op_name = type(expr).__name__ + children = [sympy_to_tree(arg) for arg in expr.args] + return TreeNode(label='operator_' + op_name, children=children) + + elif isinstance(expr, (Function)): + # Functions + + func_name = expr.func.__name__ + children = [sympy_to_tree(arg) for arg in expr.args] + return TreeNode(label='function_' + func_name, children=children) + + else: + raise ValueError(f'Unsupported SymPy type: {type(expr)}') + + +class TreeNode: + + def __init__(self, label, children=None, node_type='other'): + self.label = label + self.children = children if children is not None else [] + self.node_type = node_type + self.subtree_size = 0 + + def get_children(self): + return self.children + + def __str__(self): + return self.label + + +def print_tree(node, indent=0): + """Print a tree structure.""" + print(' ' * indent + f'└─ {node.label}') + for child in node.children: + print_tree(child, indent + 1) + + +class LaTeXError(Exception): + + def __init__(self, message='LaTeXError'): + super().__init__(message) + + +class SymPyError(Exception): + + def __init__(self, message='SymPyError'): + super().__init__(message) + + +class TreeError(Exception): + + def __init__(self, message='TreeError'): + super().__init__(message) + + +class DistError(Exception): + + def __init__(self, message='DistanceError'): + super().__init__(message) + + +def EED(answer_latex, test_latex, debug_mode=False): + if not test_latex: + return 0, -1, -1, -1 + if '\\int' in test_latex or '\\int' in answer_latex: + return 0, -1, -1, -1 + if '\\sum' in test_latex or '\\sum' in answer_latex: + return 0, -1, -1, 1 + if answer_latex == test_latex: + return 100, 0.0, -1, 0 + if len(test_latex) > 3 * len(answer_latex): + return 0, -1, -1, -1 + + try: + + answer_exp = master_convert(answer_latex) + test_exp = master_convert(test_latex) + except Exception: + if debug_mode: + raise LaTeXError(f'Fail to convert latex.\n GT:{answer_latex}\n' + f' GEN:{test_latex}') + return 0, -1, -1, -1 + + try: + + answer_exp, rep1 = posify(answer_exp) + + answer_exp = time_simplify(answer_exp) + + test_exp, rep2 = posify(test_exp) + test_exp = time_simplify(test_exp) + + answer_exp = answer_exp.subs(rep1) + test_exp = test_exp.subs(rep2) + + zero_exp = time_simplify(expand(answer_exp - test_exp)) + + if answer_exp == test_exp or zero_exp == 0: + return 100, 0., 0, 0 + + if time_equal(answer_exp, test_exp): + return 100, 0., 0, 0 + + except Exception: + if debug_mode: + raise SymPyError( + f'Failed to simplify the sympy expression. Expressions: ' + f'answer_exp={answer_exp}, test_exp={test_exp}') + return 0, -1, -1, -1 + + try: + tree_answer = sympy_to_tree(answer_exp) + tree_test = sympy_to_tree(test_exp) + + except Exception: + if debug_mode: + raise SymPyError(f'Failed to build the sympy expression tree.\n' + f' GT:{answer_exp}\n GEN:{test_exp}') + return 0, -1, -1, -1 + + try: + distance = ext_distance(tree_test, + tree_answer, + get_children=lambda x: x.get_children(), + single_insert_cost=insert_func, + insert_cost=insert_tree_func, + single_remove_cost=remove_func, + remove_cost=remove_tree_func, + update_cost=update_func) + except Exception: + if debug_mode: + raise DistError( + f'Failed to calculate the distance between trees.\n' + f' GT:{answer_latex}\n GEN:{test_latex}') + return 0, -1, calc_tree_size(tree_answer), -1 + tree_size = calc_tree_size(tree_answer) + distance_number = distance + + rel_distance = distance / tree_size + + score = score_calc(distance_number, tree_size) + return score, rel_distance, tree_size, distance_number diff --git a/opencompass/datasets/PHYBench/EED/extended_zss.py b/opencompass/datasets/PHYBench/EED/extended_zss.py new file mode 100644 index 00000000..da117c2b --- /dev/null +++ b/opencompass/datasets/PHYBench/EED/extended_zss.py @@ -0,0 +1,136 @@ +import collections + +from numpy import ones, zeros + + +class Node(object): + + def __init__(self, label, children=None): + self.label = label + self.children = children or list() + + @staticmethod + def get_children(node): + return node.children + + @staticmethod + def get_label(node): + return node.label + + def addkid(self, node, before=False): + if before: + self.children.insert(0, node) + else: + self.children.append(node) + return self + + def get(self, label): + if self.label == label: + return self + for c in self.children: + if label in c: + return c.get(label) + + +class AnnotatedTree(object): + + def __init__(self, root, get_children): + self.get_children = get_children + + self.root = root + self.nodes = list( + ) # a post-order enumeration of the nodes in the tree + self.ids = list() # a matching list of ids + self.lmds = list() # left most descendents of each nodes + self.keyroots = None + # the keyroots in the original paper + + stack = list() + pstack = list() + stack.append((root, collections.deque())) + j = 0 + while len(stack) > 0: + n, anc = stack.pop() + nid = j + for c in self.get_children(n): + a = collections.deque(anc) + a.appendleft(nid) + stack.append((c, a)) + pstack.append(((n, nid), anc)) + j += 1 + lmds = dict() + keyroots = dict() + i = 0 + while len(pstack) > 0: + (n, nid), anc = pstack.pop() + self.nodes.append(n) + self.ids.append(nid) + if not self.get_children(n): + lmd = i + for a in anc: + if a not in lmds: + lmds[a] = i + else: + break + else: + try: + lmd = lmds[nid] + except KeyError: + import pdb + pdb.set_trace() + self.lmds.append(lmd) + keyroots[lmd] = i + i += 1 + self.keyroots = sorted(keyroots.values()) + + +def ext_distance(A, B, get_children, single_insert_cost, insert_cost, + single_remove_cost, remove_cost, update_cost): + A, B = AnnotatedTree(A, get_children), AnnotatedTree(B, get_children) + size_a = len(A.nodes) + size_b = len(B.nodes) + treedists = zeros((size_a, size_b), float) + fd = 1000 * ones((size_a + 1, size_b + 1), float) + + def treedist(x, y): + Al = A.lmds + Bl = B.lmds + An = A.nodes + Bn = B.nodes + + fd[Al[x]][Bl[y]] = 0 + for i in range(Al[x], x + 1): + node = An[i] + fd[i + 1][Bl[y]] = fd[Al[i]][Bl[y]] + remove_cost(node) + + for j in range(Bl[y], y + 1): + node = Bn[j] + + fd[Al[x]][j + 1] = fd[Al[x]][Bl[j]] + insert_cost(node) + + for i in range(Al[x], x + 1): + for j in range(Bl[y], y + 1): + + node1 = An[i] + node2 = Bn[j] + costs = [ + fd[i][j + 1] + single_remove_cost(node1), + fd[i + 1][j] + single_insert_cost(node2), + fd[Al[i]][j + 1] + remove_cost(node1), + fd[i + 1][Bl[j]] + insert_cost(node2) + ] + min_cost = min(costs) + + if Al[x] == Al[i] and Bl[y] == Bl[j]: + treedists[i][j] = min(min_cost, + fd[i][j] + update_cost(node1, node2)) + fd[i + 1][j + 1] = treedists[i][j] + else: + fd[i + 1][j + 1] = min(min_cost, + fd[Al[i]][Bl[j]] + treedists[i][j]) + + for x in A.keyroots: + for y in B.keyroots: + treedist(x, y) + + return treedists[-1][-1] diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index b00162d1..e96d7c6a 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -113,6 +113,7 @@ from .obqa import * # noqa: F401, F403 from .olymmath import * # noqa: F401, F403 from .OlympiadBench import * # noqa: F401, F403 from .OpenFinData import * # noqa: F401, F403 +from .PHYBench import * # noqa: F401, F403 from .physics import * # noqa: F401, F403 from .piqa import * # noqa: F401, F403 from .py150 import * # noqa: F401, F403