From faf5cb88565e92f3c4c1173968eaa1d2005a59fe Mon Sep 17 00:00:00 2001 From: sudanl Date: Tue, 21 Jan 2025 14:23:38 +0000 Subject: [PATCH 01/12] Support OlympiadBench Benchmark --- .pre-commit-config.yaml | 1 + configs/eval_OlympiadBench.py | 38 + .../OlympiadBench_0shot_gen_be8b13.py | 52 ++ .../OlympiadBench/OlympiadBench_categories.py | 7 + .../configs/summarizers/OlympiadBench.py | 15 + .../summarizers/groups/OlympiadBench.py | 11 + opencompass/datasets/OlympiadBench.py | 778 ++++++++++++++++++ opencompass/datasets/__init__.py | 1 + 8 files changed, 903 insertions(+) create mode 100644 configs/eval_OlympiadBench.py create mode 100644 opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py create mode 100644 opencompass/configs/datasets/OlympiadBench/OlympiadBench_categories.py create mode 100644 opencompass/configs/summarizers/OlympiadBench.py create mode 100644 opencompass/configs/summarizers/groups/OlympiadBench.py create mode 100644 opencompass/datasets/OlympiadBench.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e5eea9..9e68df25 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -100,6 +100,7 @@ repos: rev: v1.3.1 hooks: - id: docformatter + language: python args: ["--in-place", "--wrap-descriptions", "79"] - repo: local hooks: diff --git a/configs/eval_OlympiadBench.py b/configs/eval_OlympiadBench.py new file mode 100644 index 00000000..d1d62dce --- /dev/null +++ b/configs/eval_OlympiadBench.py @@ -0,0 +1,38 @@ +from mmengine.config import read_base + +with read_base(): + from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_gen_be8b13 import olympiadbench_datasets + + # from opencompass.configs.models.qwen2_5.hf_qwen2_5_7b_instruct import models as hf_qwen2_5_7b_instruct_model + from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model + from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model + + from opencompass.configs.summarizers.OlympiadBench import summarizer + + +datasets = sum([v for k, v in locals().items() if k.endswith('_datasets') or k == 'datasets'], []) +models = sum([v for k, v in locals().items() if k.endswith('_model')], []) + +from opencompass.runners import LocalRunner +from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner +from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask + +infer = dict( + partitioner=dict(type=NumWorkerPartitioner, num_worker=8), + runner=dict( + type=LocalRunner, + max_num_workers=8, + task=dict(type=OpenICLInferTask) + ), +) + +eval = dict( + partitioner=dict(type=NaivePartitioner, n=10), + runner=dict( + type=LocalRunner, + max_num_workers=256, + task=dict(type=OpenICLEvalTask) + ), +) + +work_dir = 'outputs/debug/OlympiadBench' diff --git a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py new file mode 100644 index 00000000..a150ab40 --- /dev/null +++ b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py @@ -0,0 +1,52 @@ +from mmengine.config import read_base +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +# from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess_v2, normalize_final_answer +from opencompass.datasets import OlympiadBenchPrompter, OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2 + + +with read_base(): + from .OlympiadBench_categories import categories + +# Create prompter instance for problems +olympiadbench_prompter_cfg = dict( + type='OlympiadBenchPrompter' +) + +olympiadbench_reader_cfg = dict( + input_columns=[ + 'problem', 'language', 'subject', 'question_type', + 'answer_type', 'is_multiple_answer', 'unit', 'questions' + ], + output_column='solution' +) + +olympiadbench_datasets = [] +for _name in categories: + olympiadbench_infer_cfg = dict( + prompt_template=dict( + type='OlympiadBenchTemplate' + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), + ) + + olympiadbench_eval_cfg = dict( + evaluator=dict(type=OlympiadBenchEvaluator, version='v2'), + pred_postprocessor=dict(type=olympiadbench_postprocess_v2), + ) + + olympiadbench_datasets.append( + dict( + type=OlympiadBenchDataset, + abbr=f'OlympiadBench_{_name}', + path='data/OlympiadBench', + name=_name, + reader_cfg=olympiadbench_reader_cfg, + infer_cfg=olympiadbench_infer_cfg, + eval_cfg=olympiadbench_eval_cfg, + ) + ) + +del _name diff --git a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_categories.py b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_categories.py new file mode 100644 index 00000000..818e5293 --- /dev/null +++ b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_categories.py @@ -0,0 +1,7 @@ +categories = [ + 'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP + 'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP + 'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE + 'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP + 'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE +] diff --git a/opencompass/configs/summarizers/OlympiadBench.py b/opencompass/configs/summarizers/OlympiadBench.py new file mode 100644 index 00000000..baf26ca1 --- /dev/null +++ b/opencompass/configs/summarizers/OlympiadBench.py @@ -0,0 +1,15 @@ +from mmengine.config import read_base + +with read_base(): + from .groups.OlympiadBench import OlympiadBench_summary_groups + +summarizer = dict( + dataset_abbrs=[ + 'OlympiadBench_OE_TO_maths_en_COMP', + 'OlympiadBench_OE_TO_maths_zh_COMP', + 'OlympiadBench_OE_TO_maths_zh_CEE', + 'OlympiadBench_OE_TO_physics_en_COMP', + 'OlympiadBench_OE_TO_physics_zh_CEE' + ], + summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []), +) diff --git a/opencompass/configs/summarizers/groups/OlympiadBench.py b/opencompass/configs/summarizers/groups/OlympiadBench.py new file mode 100644 index 00000000..12fb5807 --- /dev/null +++ b/opencompass/configs/summarizers/groups/OlympiadBench.py @@ -0,0 +1,11 @@ +categories = [ + 'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP + 'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP + 'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE + 'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP + 'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE +] + +OlympiadBench_summary_groups = [ + {'name': 'OlympiadBench', 'subsets': ['OlympiadBench_' + c.replace(' ', '_') for c in categories]}, +] diff --git a/opencompass/datasets/OlympiadBench.py b/opencompass/datasets/OlympiadBench.py new file mode 100644 index 00000000..896a9e80 --- /dev/null +++ b/opencompass/datasets/OlympiadBench.py @@ -0,0 +1,778 @@ +import json +import math +import os +import re +from os import environ +from typing import Dict + +import sympy as sp +from datasets import Dataset, DatasetDict +from sympy import Eq, Pow, simplify, sympify +from sympy.parsing.latex import parse_latex + +from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.registry import (ICL_PROMPT_TEMPLATES, LOAD_DATASET, + TEXT_POSTPROCESSORS) +from opencompass.utils import get_data_path + +from .base import BaseDataset + +# Load Dataset + + +@LOAD_DATASET.register_module() +class OlympiadBenchDataset(BaseDataset): + """Dataset for OlympiadBench. + + Args: + path (str): Path to dataset directory + name (str): Name of specific json file to load (e.g. 'OE_TO_maths_en_COMP') + """ + + @staticmethod + def load(path: str, name: str = None, **kwargs): + """Load dataset. + + Args: + path (str): Path to dataset directory + name (str): Name of specific json file to load + + Returns: + DatasetDict: Dataset with test and train splits + """ + path = get_data_path(path) + dataset = DatasetDict() + raw_data = [] + + if environ.get('DATASET_SOURCE') == 'ModelScope': + from modelscope import MsDataset + + ms_dataset = MsDataset.load(path, split='train') + for item in ms_dataset: + raw_data.append({ + 'problem': + item['question'], + 'solution': + item['final_answer'][0], + 'language': + item['language'], + 'subject': + item['subject'], + 'question_type': + item['question_type'], + 'answer_type': + item['answer_type'], + 'is_multiple_answer': + item['is_multiple_answer'], + 'unit': + item['unit'], + 'error': + item['error'], + 'questions': + item, # may not be used + }) + else: + # Construct file path using name parameter + if name is None: + raise ValueError( + "Must specify 'name' parameter to load specific json file") + + # file_path = os.path.join(path, name, f'{name}.json') + file_path = os.path.join(path, f'{name}.json') + if not os.path.exists(file_path): + raise FileNotFoundError(f'File not found: {file_path}') + + # Load the specified json file + data = json.load(open(file_path, encoding='utf-8')) + for item in data: + raw_data.append({ + 'problem': + item['question'], + 'solution': + item['final_answer'][0], + 'language': + item['language'], + 'subject': + item['subject'], + 'question_type': + item['question_type'], + 'answer_type': + item['answer_type'], + 'is_multiple_answer': + item['is_multiple_answer'], + 'unit': + item['unit'], + 'error': + item['error'], + 'questions': + item, # may not be used + }) + + dataset['test'] = Dataset.from_list(raw_data) + dataset['train'] = Dataset.from_list(raw_data) + return dataset + + +# Construct Prompt + + +def get_single_answer_type_text(answer_type, is_chinese): + if '-' in answer_type: # No need now + answer_type = answer_type[:answer_type.find('-')] + chinese_answer_type_dict = { + 'Numerical': '数值', + 'Expression': '表达式', + 'Equation': '方程', + 'Interval': '区间', + } + english_answer_type_dict = { + 'Numerical': 'a numerical value', + 'Expression': 'an expression', + 'Equation': 'an equation', + 'Interval': 'an interval', + } + + for t in ['Numerical', 'Expression', 'Equation', 'Interval']: + if t in answer_type: + if is_chinese: + return chinese_answer_type_dict[t] + else: + return english_answer_type_dict[t] + raise ValueError(f'Error parsing answer type {answer_type}!') + + +def get_answer_type_text(answer_type, is_chinese, multiple_answer): + if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type): + return '' + + if not multiple_answer: + answer_text = get_single_answer_type_text(answer_type, is_chinese) + if is_chinese: + return f',答案类型为{answer_text}' + else: + return f'The answer of The problem should be {answer_text}. ' + + # Multiple answers case + if ',' not in answer_type: # Same answer type for all answers + answer_text = get_single_answer_type_text(answer_type, is_chinese) + if is_chinese: + return f',题目有多个答案,答案类型均为{answer_text}' + else: + return f'The problem has multiple answers, each of them should be {answer_text}. ' + + # Different answer types + answer_types = answer_type.split(',') + answer_types = [ + get_single_answer_type_text(t, is_chinese) for t in answer_types + ] + if len(set(answer_types)) == 1: + answer_text = answer_types[0] + if is_chinese: + return f',题目有多个答案,答案类型均为{answer_text}' + else: + return f'The problem has multiple answers, each of them should be {answer_text}. ' + else: + if is_chinese: + answer_text = '、'.join(answer_types) + return f',题目有多个答案,答案类型分别为{answer_text}' + else: + answer_text = ', '.join(answer_types) + return f'The problem has multiple answers, with the answers in order being {answer_text}. ' + + +class OlympiadBenchPrompter: + + def __init__(self): + pass + + def make_prompt( + self, + language, + subject, + question_type, + answer_type, + is_multiple_answer, + unit, + ): + self.is_chinese = language == 'Chinese' + self.is_math = subject == 'Math' + self.is_theorem_proving = question_type == 'Theorem proof' + """Generate prompt based on question properties.""" + if self.is_chinese: + subject_content = '数学' if self.is_math else '物理' + if self.is_theorem_proving: + prompt = f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。' + else: + answer_type_text = get_answer_type_text( + answer_type, + is_chinese=True, + multiple_answer=is_multiple_answer, + ) + + if is_multiple_answer: + multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}' + else: + multiple_answer_text = '\\boxed{答案}' + + unit_text = '' + if unit: + multiple_answer_text += '(单位)' + unit_text = ',注意答案的单位不要放在\\boxed{}中' + + prompt = f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是{multiple_answer_text}。"显式给出结果{unit_text}。' + else: + subject_content = 'Math' if self.is_math else 'Physics' + if self.is_theorem_proving: + prompt = f'The following is a theorem proving problem from an International {subject_content} competition. Please use logical reasoning and common theorems to prove the proposition in the problem according to the given requirements. Please use LaTeX format to represent the variables and formulas used in the proof.' + else: + if is_multiple_answer: + multiple_answer_text = ( + '\\boxed{multiple answers connected with commas}') + else: + multiple_answer_text = '\\boxed{answer}' + + unit_text = '' + if unit: + multiple_answer_text += '(unit)' + unit_text = ', note that the unit of the answer should not be included in \\boxed{}' + + answer_type_text = get_answer_type_text( + answer_type, + is_chinese=False, + multiple_answer=is_multiple_answer, + ) + + prompt = f'The following is an open-ended problem from an International {subject_content} competition. {answer_type_text}Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "So the final answer is {multiple_answer_text}." and give the result explicitly{unit_text}.' + + # Add problem statement to the prompt + prompt = prompt + '\n' + '{problem}' + '\n' + + # Add step-by-step reasoning instruction + if self.is_chinese: + prompt += ('\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。') + else: + prompt += '\nPlease reason step by step, and put your final answer within \\boxed{}.' + + return prompt + + +### Evaluate + + +class MathJudger: + + def __init__(self): + self.special_signal_map = { + '\\left': '', + '\\right': '', + '∶': ':', + ',': ',', + '$': '', + '\\approx': '=', + '\\simeq': '=', + '\\sim': '=', + '^\\prime': "'", + '^{\\prime}': "'", + '^\\circ': '', + '%': '', + } + self.pi = parse_latex('\\pi') + self.precision = 1e-8 + + def split_by_comma(self, expr: str): + in_bracket_num = 0 + splitted_expr = [] + start_idx = 0 + for i, char in enumerate(expr): + if char == '(' or char == '[': + in_bracket_num += 1 + elif char == ')' or char == ']': + in_bracket_num -= 1 + elif char == ',' and in_bracket_num == 0: + splitted_expr.append(expr[start_idx:i].strip()) + start_idx = i + 1 + + if start_idx < len(expr): + splitted_expr.append(expr[start_idx:].strip()) + + return splitted_expr + + def trans_plus_minus_sign(self, expr_list: list): + new_expr_list = [] + for expr in expr_list: + if '\\pm' in expr: + new_expr_list.append(expr.replace('\\pm', '+')) + new_expr_list.append(expr.replace('\\pm', '-')) + else: + new_expr_list.append(expr) + + return new_expr_list + + def judge(self, expression1, expression2, precision=1e-8): + # (默认 expression1 为 Ground_Truth) + precision = precision if type(precision) == list else [precision] + + try: + expression1, expression2 = self.preprocess(expression1, + expression2) + except Exception: # 处理具体异常 + return False + if expression1 == expression2: + return True + + # 去除字符串中的中文字符 + expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1) + expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2) + + expression1 = self.split_by_comma(expression1) + expression2 = self.split_by_comma(expression2) + + temp_list1 = self.trans_plus_minus_sign(expression1) + temp_list2 = self.trans_plus_minus_sign(expression2) + + # 设计误差值列表 + if len(precision) <= 1: + precision = precision * len(temp_list1) + + if len(temp_list1) != len(temp_list2): + return False + + # 判断两个列表中的元素是否可以两两配对,并且两两相等 + idx = -1 + while len(temp_list1) != 0: + idx = (idx + 1) % len(temp_list1) + + item1 = temp_list1[idx] + self.precision = precision[idx] + + for item2 in temp_list2: + if self.is_equal(item1, item2): + temp_list1.remove(item1) + temp_list2.remove(item2) + precision.remove(self.precision) + break + else: + return False + + # 如果所有元素都匹配并移除,列表可以配对 + return True + + def is_interval(self, epr): + return epr.startswith(('(', '[')) and epr.endswith((')', ']')) + + def sympy_sub_pi(self, expression_sympy): + return expression_sympy.subs(self.pi, math.pi) + + def is_equal(self, expression1, expression2): + if (expression1 == expression2 and expression1 != '' + and expression2 != ''): + return True + + # 先判断是否是两个区间 + if self.is_interval(expression1) and self.is_interval(expression2): + try: + if self.interval_equal(expression1, expression2): + return True + except Exception: # 处理具体异常 + return False + + # 再判断是否在数值上相等 + try: + if self.numerical_equal(expression1, expression2): + return True + except Exception: # 处理具体异常 + pass + + # 再判断是否是表达式相等 + try: + if self.expression_equal( + expression1, expression2) and not ('=' in expression1 + and '=' in expression2): + return True + except Exception: # 处理具体异常 + pass + + # 再判断是否是等式相等 + try: + if self.equation_equal(expression1, expression2): + return True + except Exception: # 处理具体异常 + pass + + return False + + def numerical_equal( + self, + expression1: str, + expression2: str, + include_percentage: bool = True, + ): + """ + (默认 expression1 为 Ground_Truth) + 函数: 判读两个数值是否在误差允许范围内相等 + 步骤1: 将可能出现的百分号的情况包含进来 + 步骤2: 使用 math.isclose 函数判断是否相等 + """ + reference = float(expression1) + prediction = float(expression2) + + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + + for item in gt_result: + if abs(item - prediction) <= self.precision * 1.01: + return True + return False + + def expression_equal(self, exp1, exp2): + """ + (默认 expression1 为 Ground_Truth) + 函数: 判断两个表达式是否在数学意义上等价 + 步骤1: 提取表达式, 防止有的模型会给出"x=1"而不是"1" + 步骤2: 使用 sympy 库进行等价判断 + """ + + # 只提取等号右边的表达式 + def extract_expression(expression): + if '=' in expression: + expression = expression.split('=')[1] + return expression.strip() + + exp1 = extract_expression(exp1) + exp2 = extract_expression(exp2) + + # 将表达式转换为 sympy 中能够进行处理的格式 + expr1_sym = sympify(parse_latex(exp1)) + expr2_sym = sympify(parse_latex(exp2)) + + if expr1_sym == expr2_sym: + return True + else: + expr1_sym = self.sympy_sub_pi(expr1_sym) + expr2_sym = self.sympy_sub_pi(expr2_sym) + + if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or ( + not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)): + return False + elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol): + try: + if not (self.can_compute_power(expr1_sym) + and self.can_compute_power(expr2_sym)): + print( + f'These two number can not be calculated by current computer for: "{str(expr1_sym)}" and "{str(expr2_sym)}"' + ) + return False + + if (abs(expr1_sym.evalf() - expr2_sym.evalf()) <= + self.precision * 1.01): + return True + else: + return False + except Exception: # 处理具体异常 + return False + else: + try: + simplified_expr = simplify(expr1_sym - expr2_sym) + + num_value = simplified_expr.evalf() + + return abs(num_value) < 1e-3 + except Exception: # 处理具体异常 + return False + + def equation_equal(self, expression1, expression2): + """ + (expression1 is assumed to be Ground_Truth) + Function: Check if two equations are mathematically equivalent + Step 1: Simplify equations to standard form with right side equal to 0 + Step 2: Use sympy library to calculate quotient of left sides, if quotient or its reciprocal is integer, equations are equivalent + """ + + # Convert equations to sympy format with right side moved to left side + def simplify_equation(latex_eq): + # Split left and right sides of equation + lhs, rhs = latex_eq.split('=') + + # Parse LaTeX expressions using parse_latex + lhs_expr = parse_latex(lhs) + rhs_expr = parse_latex(rhs) + + # Create equation object + equation = Eq(lhs_expr, rhs_expr) + + # Simplify equation by moving right side to left + simplified_eq = simplify(equation.lhs - equation.rhs) + + return simplified_eq + + expr1_sym = simplify_equation(expression1) + expr2_sym = simplify_equation(expression2) + + division_result_1 = simplify(expr1_sym / expr2_sym) + division_result_2 = simplify(expr2_sym / expr1_sym) + + # If division result or its reciprocal is non-zero integer, equations are equivalent + if (division_result_1.is_Integer + and division_result_1 != 0) or (division_result_2.is_Integer + and division_result_2 != 0): + return True + else: + return False + + def interval_equal(self, expression1, expression2): + """ + Function: Check if two intervals are mathematically equivalent + Step 1: Simplify interval expressions, remove irrelevant symbols like "\left", "\right", and "x \in" + Step 2: Compare brackets and mathematical expressions in between + """ + + def compare_two_interval(inter1, inter2): + # First compare brackets on both sides + if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]: + return False + + inter1 = inter1.strip('[]()') + inter2 = inter2.strip('[]()') + + # Split interval into left and right parts + items_1 = inter1.split(',') + items_2 = inter2.split(',') + + for item_1, item_2 in zip(items_1, items_2): + if not self.expression_equal(item_1, item_2): + return False + return True + + interval1 = expression1 + interval2 = expression2 + + if interval1 == interval2: + return True + else: + inter_list1 = interval1.split('\\cup') + inter_list2 = interval2.split('\\cup') + + if len(inter_list1) != len(inter_list2): + return False + else: + for inter1, inter2 in zip(inter_list1, inter_list2): + if not compare_two_interval(inter1, inter2): + return False + return True + + def preprocess(self, expression1, expression2): + """Extract and preprocess expressions from model output.""" + + def extract_boxed_content(latex_str): + # Find all \boxed{...} structures + boxed_matches = re.finditer(r'\\boxed{', latex_str) + results = '' + + for match in boxed_matches: + start_index = match.end() + end_index = start_index + stack = 1 + + # Search from after \boxed{ until finding matching closing brace + while stack > 0 and end_index < len(latex_str): + if latex_str[end_index] == '{': + stack += 1 + elif latex_str[end_index] == '}': + stack -= 1 + end_index += 1 + + if stack == 0: + # Extract content inside \boxed{} + content = latex_str[start_index:end_index - 1] + results += content + ',' + else: + raise ValueError('Mismatched braces in LaTeX string.') + + # If no \boxed{} found, extract formulas from last line + if results == '': + last_line_ans = latex_str.strip().split('\n')[-1] + dollar_pattern = r'\$(.*?)\$' + answers = re.findall(dollar_pattern, last_line_ans) + + if answers: + for ans in answers: + results += ans + ',' + else: + results = latex_str + + return results + + def special_symbol_replace(expression): + if '\\in ' in expression: + expression = expression.split('\\in ')[1] + + # Replace special characters that don't affect LaTeX parsing (decorative) + for signal in self.special_signal_map: + expression = expression.replace( + signal, self.special_signal_map[signal]) + + expression = expression.strip('\n$,.:;^_=+`!@#$%^&*~,。') + + pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}' + expression = re.sub(pattern, r'\1', expression) + + return expression + + exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content( + expression2) + exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2) + + return exp1, exp2 + + def can_compute_power(self, expr): + """Check if the power expression can be computed. + + Parameters: + expr (sympy expression): The expression to check. + + Returns: + bool: True if the expression can be computed, False otherwise. + """ + # Check if the expression is a power expression + if isinstance(expr, Pow): + # Extract the base and the exponent + base, exp = expr.as_base_exp() + + # Check if the base and the exponent are numbers + if base.is_number and exp.is_number: + # Set a threshold for the maximum size of the exponent + MAX_EXP = 1000 # This threshold can be adjusted based on the computing environment + + # Check if the exponent is greater than the threshold + if abs(exp.evalf()) > MAX_EXP: + return False + else: + return True + else: + # If the base or the exponent is not a number, we cannot compute the power + return False + else: + # If the expression is not a power expression, return True as it is not the case we are checking for + return True + + +@TEXT_POSTPROCESSORS.register_module('olympiadbench_postprocess_v2') +def olympiadbench_postprocess_v2(text: str, + is_chinese: bool = False, + is_deepseek: bool = False) -> str: + """Extract answer from model output.""" + # deepseekmath has special answering format + if is_deepseek: + if is_chinese: + matches = re.findall('## 解题答案(.*)', text) + else: + matches = re.findall('The answer is: (.*)', text) + else: + if is_chinese: + matches = re.findall('所以最终答案是(.*)', text) + else: + matches = re.findall('So the final answer is (.*)', text) + + # If found matches, take the last one, otherwise return the whole text + if matches: + return matches[-1].strip() + return text + + +class OlympiadBenchEvaluator(BaseEvaluator): + """Evaluator for OlympiadBench dataset.""" + + def __init__(self, version='v1'): + assert version in ['v1', 'v2'] + self.version = version + self.judger = MathJudger() + + def score(self, predictions, references): # Remove questions parameter + """Calculate accuracy score. + + Args: + predictions (list): List of model predictions + references (list): List of ground truth answers + """ + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different length' + } + + correct = 0 + count = 0 + details = [] + + for pred, ref in zip(predictions, references): + detail = {'pred': pred, 'answer': ref, 'correct': False} + count += 1 + + # Get precision/error threshold from reference if available + precision = 1e-8 + if isinstance(ref, dict) and 'error' in ref: + if ',' in ref['error']: + # Multiple precisions for multiple answers + precisions = ref['error'].split(',') + precisions = [float(p) if p else 1e-8 for p in precisions] + precision = precisions + else: + precision = float(ref['error']) + + # Check if answer is correct + try: + if (isinstance(ref, dict) and 'answer_type' in ref + and 'Tuple' in ref['answer_type']): + # Special handling for tuple type answers + is_correct = self.judger.judge(pred, + ref['final_answer'][0], + precision) + else: + is_correct = self.judger.judge(pred, ref, precision) + + if is_correct: + correct += 1 + detail['correct'] = True + except Exception as e: # 处理具体异常 + detail['error'] = str(e) + + details.append(detail) + + result = {'accuracy': 100 * correct / count, 'details': details} + return result + + +@ICL_PROMPT_TEMPLATES.register_module() +class OlympiadBenchTemplate(PromptTemplate): + """Template for OlympiadBench dataset.""" + + def __init__(self): + # Define basic template structure + template = dict(round=[dict(role='HUMAN', prompt='{prompt}')]) + super().__init__(template=template) + self.prompter = OlympiadBenchPrompter() + + def generate_item(self, entry: Dict, *args, **kwargs) -> str: + """Generate prompt for a single item.""" + problem = entry.get('problem', '') + language = entry.get('language', 'English') + subject = entry.get('subject', 'Math') + question_type = entry.get('question_type', '') + answer_type = entry.get('answer_type', '') + is_multiple_answer = entry.get('is_multiple_answer', False) + unit = entry.get('unit', '') + + prompt = self.prompter.make_prompt( + language=language, + subject=subject, + question_type=question_type, + answer_type=answer_type, + is_multiple_answer=is_multiple_answer, + unit=unit, + ) + + new_entry = {'prompt': prompt, 'problem': problem} + + return super().generate_item(new_entry, *args, **kwargs) diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index e061286f..b28f78ed 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -103,6 +103,7 @@ from .natural_question import * # noqa: F401, F403 from .natural_question_cn import * # noqa: F401, F403 from .NPHardEval import * # noqa: F401, F403 from .obqa import * # noqa: F401, F403 +from .OlympiadBench import * # noqa: F401, F403 from .OpenFinData import * # noqa: F401, F403 from .piqa import * # noqa: F401, F403 from .py150 import * # noqa: F401, F403 From e0375c89413d47c6be1914f42b3e2c6f326dd095 Mon Sep 17 00:00:00 2001 From: sudanl Date: Thu, 23 Jan 2025 08:34:38 +0000 Subject: [PATCH 02/12] Support OlympiadBench Benchmark --- .pre-commit-config.yaml | 13 ++-- opencompass/datasets/OlympiadBench.py | 95 +++++++++++++++++---------- 2 files changed, 66 insertions(+), 42 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e68df25..c22875c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: rev: v0.32.0 hooks: - id: yapf + args: ["--style=pep8", "--no-local-style"] exclude: | (?x)^( configs/ | @@ -96,12 +97,12 @@ repos: - mdformat_frontmatter - linkify-it-py exclude: configs/ - - repo: https://github.com/myint/docformatter - rev: v1.3.1 - hooks: - - id: docformatter - language: python - args: ["--in-place", "--wrap-descriptions", "79"] + # - repo: https://github.com/myint/docformatter + # rev: v1.3.1 + # hooks: + # - id: docformatter + # language: system + # args: ["--in-place", "--wrap-descriptions", "79"] - repo: local hooks: - id: update-dataset-suffix diff --git a/opencompass/datasets/OlympiadBench.py b/opencompass/datasets/OlympiadBench.py index 896a9e80..6fda55a5 100644 --- a/opencompass/datasets/OlympiadBench.py +++ b/opencompass/datasets/OlympiadBench.py @@ -27,7 +27,8 @@ class OlympiadBenchDataset(BaseDataset): Args: path (str): Path to dataset directory - name (str): Name of specific json file to load (e.g. 'OE_TO_maths_en_COMP') + name (str): Name of specific json file to load + e.g. 'OE_TO_maths_en_COMP' """ @staticmethod @@ -145,22 +146,21 @@ def get_single_answer_type_text(answer_type, is_chinese): def get_answer_type_text(answer_type, is_chinese, multiple_answer): if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type): return '' - if not multiple_answer: answer_text = get_single_answer_type_text(answer_type, is_chinese) if is_chinese: return f',答案类型为{answer_text}' else: - return f'The answer of The problem should be {answer_text}. ' - + return (f'The answer of The problem should be ' + f'{answer_text}. ') # Multiple answers case if ',' not in answer_type: # Same answer type for all answers answer_text = get_single_answer_type_text(answer_type, is_chinese) if is_chinese: return f',题目有多个答案,答案类型均为{answer_text}' else: - return f'The problem has multiple answers, each of them should be {answer_text}. ' - + return (f'The problem has multiple answers, each of them ' + f'should be {answer_text}. ') # Different answer types answer_types = answer_type.split(',') answer_types = [ @@ -171,14 +171,16 @@ def get_answer_type_text(answer_type, is_chinese, multiple_answer): if is_chinese: return f',题目有多个答案,答案类型均为{answer_text}' else: - return f'The problem has multiple answers, each of them should be {answer_text}. ' + return (f'The problem has multiple answers, each of them ' + f'should be {answer_text}. ') else: if is_chinese: answer_text = '、'.join(answer_types) return f',题目有多个答案,答案类型分别为{answer_text}' else: answer_text = ', '.join(answer_types) - return f'The problem has multiple answers, with the answers in order being {answer_text}. ' + return (f'The problem has multiple answers, ' + f'with the answers in order being {answer_text}. ') class OlympiadBenchPrompter: @@ -202,62 +204,74 @@ class OlympiadBenchPrompter: if self.is_chinese: subject_content = '数学' if self.is_math else '物理' if self.is_theorem_proving: - prompt = f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。' + prompt = (f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,' + f'运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。') else: answer_type_text = get_answer_type_text( answer_type, is_chinese=True, multiple_answer=is_multiple_answer, ) - if is_multiple_answer: multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}' else: multiple_answer_text = '\\boxed{答案}' - unit_text = '' if unit: multiple_answer_text += '(单位)' unit_text = ',注意答案的单位不要放在\\boxed{}中' - - prompt = f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是{multiple_answer_text}。"显式给出结果{unit_text}。' + prompt = (f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。' + f'请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的' + f'变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是' + f'{multiple_answer_text}。"显式给出结果{unit_text}。') else: subject_content = 'Math' if self.is_math else 'Physics' if self.is_theorem_proving: - prompt = f'The following is a theorem proving problem from an International {subject_content} competition. Please use logical reasoning and common theorems to prove the proposition in the problem according to the given requirements. Please use LaTeX format to represent the variables and formulas used in the proof.' + prompt = ( + f'The following is a theorem proving problem from an ' + f'International {subject_content} competition. Please use ' + f'logical reasoning and common theorems to prove the ' + f'proposition in the problem according to the given ' + f'requirements. Please use LaTeX format to represent the ' + f'variables and formulas used in the proof.') else: if is_multiple_answer: multiple_answer_text = ( '\\boxed{multiple answers connected with commas}') else: multiple_answer_text = '\\boxed{answer}' - unit_text = '' if unit: multiple_answer_text += '(unit)' - unit_text = ', note that the unit of the answer should not be included in \\boxed{}' - + unit_text = (', note that the unit of the answer should ' + 'not be included in \\boxed{}') answer_type_text = get_answer_type_text( answer_type, is_chinese=False, multiple_answer=is_multiple_answer, ) - - prompt = f'The following is an open-ended problem from an International {subject_content} competition. {answer_type_text}Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "So the final answer is {multiple_answer_text}." and give the result explicitly{unit_text}.' - + prompt = ( + f'The following is an open-ended problem from an ' + f'International {subject_content} competition. ' + f'{answer_type_text}Please calculate the answer according ' + f'to the given requirements and the information provided. ' + f'Please use LaTeX format to represent the variables and ' + f'formulas used in the solution process and results. ' + f'Please end your solution with "So the final answer is ' + f'{multiple_answer_text}." and give the result explicitly' + f'{unit_text}.') # Add problem statement to the prompt prompt = prompt + '\n' + '{problem}' + '\n' - # Add step-by-step reasoning instruction if self.is_chinese: - prompt += ('\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。') + prompt += '\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。' else: - prompt += '\nPlease reason step by step, and put your final answer within \\boxed{}.' - + prompt += ('\nPlease reason step by step, and put your final ' + 'answer within \\boxed{}.') return prompt -### Evaluate +# Evaluate class MathJudger: @@ -461,9 +475,9 @@ class MathJudger: try: if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)): - print( - f'These two number can not be calculated by current computer for: "{str(expr1_sym)}" and "{str(expr2_sym)}"' - ) + print(f'These two number can not be calculated by ' + f'current computer for: ' + f'"{str(expr1_sym)}" and "{str(expr2_sym)}"') return False if (abs(expr1_sym.evalf() - expr2_sym.evalf()) <= @@ -488,7 +502,8 @@ class MathJudger: (expression1 is assumed to be Ground_Truth) Function: Check if two equations are mathematically equivalent Step 1: Simplify equations to standard form with right side equal to 0 - Step 2: Use sympy library to calculate quotient of left sides, if quotient or its reciprocal is integer, equations are equivalent + Step 2: Use sympy library to calculate quotient of left sides, + if quotient or its reciprocal is integer, equations are equivalent """ # Convert equations to sympy format with right side moved to left side @@ -514,7 +529,8 @@ class MathJudger: division_result_1 = simplify(expr1_sym / expr2_sym) division_result_2 = simplify(expr2_sym / expr1_sym) - # If division result or its reciprocal is non-zero integer, equations are equivalent + # If division result or its reciprocal is + # non-zero integer, equations are equivalent if (division_result_1.is_Integer and division_result_1 != 0) or (division_result_2.is_Integer and division_result_2 != 0): @@ -525,7 +541,9 @@ class MathJudger: def interval_equal(self, expression1, expression2): """ Function: Check if two intervals are mathematically equivalent - Step 1: Simplify interval expressions, remove irrelevant symbols like "\left", "\right", and "x \in" + Step 1: Simplify interval expressions, + remove irrelevant symbols + like "\\left", "\\right", and "x \\in" Step 2: Compare brackets and mathematical expressions in between """ @@ -576,7 +594,8 @@ class MathJudger: end_index = start_index stack = 1 - # Search from after \boxed{ until finding matching closing brace + # Search from after \boxed{ until + # finding matching closing brace while stack > 0 and end_index < len(latex_str): if latex_str[end_index] == '{': stack += 1 @@ -609,7 +628,8 @@ class MathJudger: if '\\in ' in expression: expression = expression.split('\\in ')[1] - # Replace special characters that don't affect LaTeX parsing (decorative) + # Replace special characters that + # don't affect LaTeX parsing (decorative) for signal in self.special_signal_map: expression = expression.replace( signal, self.special_signal_map[signal]) @@ -644,7 +664,8 @@ class MathJudger: # Check if the base and the exponent are numbers if base.is_number and exp.is_number: # Set a threshold for the maximum size of the exponent - MAX_EXP = 1000 # This threshold can be adjusted based on the computing environment + # can be adjusted based on the computing environment + MAX_EXP = 1000 # Check if the exponent is greater than the threshold if abs(exp.evalf()) > MAX_EXP: @@ -652,10 +673,12 @@ class MathJudger: else: return True else: - # If the base or the exponent is not a number, we cannot compute the power + # If the base or the exponent is not a number, + # we cannot compute the power return False else: - # If the expression is not a power expression, return True as it is not the case we are checking for + # If the expression is not a power expression, + # return True as it is not the case we are checking for return True From 318648778546469dde73642ee7402abc42ee4b28 Mon Sep 17 00:00:00 2001 From: sudanl Date: Thu, 23 Jan 2025 08:39:31 +0000 Subject: [PATCH 03/12] Support OlympiadBench Benchmark --- .../datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py index a150ab40..145b6adc 100644 --- a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py +++ b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py @@ -3,7 +3,7 @@ from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_inferencer import GenInferencer # from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess_v2, normalize_final_answer -from opencompass.datasets import OlympiadBenchPrompter, OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2 +from opencompass.datasets import OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2 with read_base(): From 73f1c48e85a6fcd7ee22b91691bcc0d177bd4a92 Mon Sep 17 00:00:00 2001 From: sudanl Date: Thu, 23 Jan 2025 08:58:48 +0000 Subject: [PATCH 04/12] update dataset path --- .../datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py index 145b6adc..09312393 100644 --- a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py +++ b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py @@ -41,7 +41,7 @@ for _name in categories: dict( type=OlympiadBenchDataset, abbr=f'OlympiadBench_{_name}', - path='data/OlympiadBench', + path='opencompass/OlympiadBench', name=_name, reader_cfg=olympiadbench_reader_cfg, infer_cfg=olympiadbench_infer_cfg, From dbf2cb7fdb8fe4eeb0c6e2449b3901c7c2dbb4f4 Mon Sep 17 00:00:00 2001 From: liushz Date: Thu, 23 Jan 2025 15:12:20 +0000 Subject: [PATCH 05/12] Update olmpiadBench --- configs/eval_OlympiadBench.py | 2 +- .../OlympiadBench/OlympiadBench_0shot_gen_be8b13.py | 2 -- opencompass/utils/datasets_info.py | 9 +++++++++ requirements/extra.txt | 2 ++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/configs/eval_OlympiadBench.py b/configs/eval_OlympiadBench.py index d1d62dce..78a9fb24 100644 --- a/configs/eval_OlympiadBench.py +++ b/configs/eval_OlympiadBench.py @@ -5,7 +5,7 @@ with read_base(): # from opencompass.configs.models.qwen2_5.hf_qwen2_5_7b_instruct import models as hf_qwen2_5_7b_instruct_model from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model - from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model + # from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model from opencompass.configs.summarizers.OlympiadBench import summarizer diff --git a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py index 09312393..36e2f37f 100644 --- a/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py +++ b/opencompass/configs/datasets/OlympiadBench/OlympiadBench_0shot_gen_be8b13.py @@ -1,8 +1,6 @@ from mmengine.config import read_base -from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_inferencer import GenInferencer -# from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess_v2, normalize_final_answer from opencompass.datasets import OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2 diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py index 10749c58..050d5983 100644 --- a/opencompass/utils/datasets_info.py +++ b/opencompass/utils/datasets_info.py @@ -398,9 +398,18 @@ DATASETS_MAPPING = { "hf_id": "THUDM/LongBench-v2", "local": "./data/longbenchv2/data.json", }, + "opencompass/OlympiadBench": { + "ms_id": "", + "hf_id": "", + "local": "./data/OlympiadBench", + }, } DATASETS_URL = { + "/OlympiadBench": { + "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/OlympiadBench.zip", + "md5": "97e8b1ae7f6170d94817288a8930ef00", + }, "/longbenchv2":{ "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/longbenchv2.zip", "md5": "09b7e06e6f98c5cca8ad597b3d7b42f0", diff --git a/requirements/extra.txt b/requirements/extra.txt index 96789956..7f04c9d0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,5 +1,7 @@ # Alpaca-eval alpaca-eval==0.6 +# OlympiadBench +antlr4-python3-runtime==4.11 cn2an # Dingo dingo-python==1.1.2 From eb286c3be38d7d7670bd43c4161bdd95b02dafee Mon Sep 17 00:00:00 2001 From: liushz Date: Thu, 23 Jan 2025 15:17:49 +0000 Subject: [PATCH 06/12] Update olmpiadBench --- .pre-commit-config.yaml | 71 ++++++++--------------------------------- 1 file changed, 13 insertions(+), 58 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c22875c5..9f72ae42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,10 +14,6 @@ exclude: | opencompass/datasets/subjective/mtbench101.py| docs/zh_cn/advanced_guides/compassbench_intro.md | docs/zh_cn/advanced_guides/compassbench_v2_0.md | - opencompass/configs/datasets/ | - opencompass/configs/models/| - opencompass/configs/summarizers/ | - opencompass/configs/dataset_collections/ | opencompass/utils/datasets.py | opencompass/utils/datasets_info.py ) @@ -28,8 +24,8 @@ repos: - id: flake8 exclude: | (?x)^( - configs/ | - example_scripts/ + opencompass/configs/| + examples/ ) - repo: https://github.com/PyCQA/isort rev: 5.11.5 @@ -37,18 +33,17 @@ repos: - id: isort exclude: | (?x)^( - configs/ | - example_scripts/ + opencompass/configs/| + examples/ ) - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.32.0 hooks: - id: yapf - args: ["--style=pep8", "--no-local-style"] exclude: | (?x)^( - configs/ | - example_scripts/ + opencompass/configs/| + examples/ ) - repo: https://github.com/codespell-project/codespell rev: v2.2.1 @@ -58,9 +53,8 @@ repos: (?x)^( .*\.jsonl| .*\.md.template| - configs/ | opencompass/configs/ | - example_scripts/ + examples/ ) - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 @@ -70,7 +64,6 @@ repos: (?x)^( dicts/| projects/.*?/dicts/| - configs/.*?/.*\.txt ) - id: check-yaml - id: end-of-file-fixer @@ -78,7 +71,6 @@ repos: (?x)^( dicts/| projects/.*?/dicts/| - configs/.*?/.*\.txt ) - id: requirements-txt-fixer - id: double-quote-string-fixer @@ -97,12 +89,11 @@ repos: - mdformat_frontmatter - linkify-it-py exclude: configs/ - # - repo: https://github.com/myint/docformatter - # rev: v1.3.1 - # hooks: - # - id: docformatter - # language: system - # args: ["--in-place", "--wrap-descriptions", "79"] + - repo: https://github.com/myint/docformatter + rev: v1.3.1 + hooks: + - id: docformatter + args: ["--in-place", "--wrap-descriptions", "79"] - repo: local hooks: - id: update-dataset-suffix @@ -111,7 +102,7 @@ repos: language: script pass_filenames: true require_serial: true - files: ^configs/datasets + files: ^opencompass/configs/datasets - repo: local hooks: - id: update-dataset-suffix-pacakge @@ -124,42 +115,6 @@ repos: args: - --root_folder - opencompass/configs/datasets - - repo: local - hooks: - - id: compare-configs-datasets - name: compare configs datasets - entry: ./tools/compare_configs.py - language: script - pass_filenames: false - # require_serial: true - args: - - configs/datasets - - opencompass/configs/datasets - - repo: local - hooks: - - id: compare-configs-models - name: compare configs models - entry: ./tools/compare_configs.py - language: script - pass_filenames: false - # require_serial: true - args: - - configs/models - - opencompass/configs/models - - --ignore - - llama - - repo: local - hooks: - - id: compare-configs-summarizers - name: compare configs summarizers - entry: ./tools/compare_configs.py - language: script - pass_filenames: false - # require_serial: true - args: - - configs/summarizers - - opencompass/configs/summarizers - # - repo: https://github.com/open-mmlab/pre-commit-hooks # rev: v0.2.0 # Use the ref you want to point at # hooks: From 55fe911ea2cbee4312179d017fe9efdaa1f196dc Mon Sep 17 00:00:00 2001 From: liushz Date: Thu, 23 Jan 2025 15:47:53 +0000 Subject: [PATCH 07/12] Update olmpiadBench --- examples/eval_OlympiadBench.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/eval_OlympiadBench.py b/examples/eval_OlympiadBench.py index 78a9fb24..090a4cfa 100644 --- a/examples/eval_OlympiadBench.py +++ b/examples/eval_OlympiadBench.py @@ -3,9 +3,7 @@ from mmengine.config import read_base with read_base(): from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_gen_be8b13 import olympiadbench_datasets - # from opencompass.configs.models.qwen2_5.hf_qwen2_5_7b_instruct import models as hf_qwen2_5_7b_instruct_model from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model - # from opencompass.configs.models.hf_llama.lmdeploy_llama3_8b_instruct import models as lmdeploy_llama3_8b_instruct_model from opencompass.configs.summarizers.OlympiadBench import summarizer From 34cc0a5f5ff672958ec92fda7e9145716d0246e0 Mon Sep 17 00:00:00 2001 From: liushz Date: Fri, 28 Feb 2025 07:55:17 +0000 Subject: [PATCH 08/12] Add HLE dataset --- opencompass/configs/datasets/HLE/hle_gen.py | 4 + .../datasets/HLE/hle_llmjudge_gen_6ff468.py | 91 +++++++++++++++++++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/hle.py | 17 ++++ 4 files changed, 113 insertions(+) create mode 100644 opencompass/configs/datasets/HLE/hle_gen.py create mode 100644 opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py create mode 100644 opencompass/datasets/hle.py diff --git a/opencompass/configs/datasets/HLE/hle_gen.py b/opencompass/configs/datasets/HLE/hle_gen.py new file mode 100644 index 00000000..a4ff86b4 --- /dev/null +++ b/opencompass/configs/datasets/HLE/hle_gen.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .hle_llmjudge_gen_63a000 import hle_datasets # noqa: F401, F403 diff --git a/opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py b/opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py new file mode 100644 index 00000000..bb6f40bf --- /dev/null +++ b/opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py @@ -0,0 +1,91 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.evaluator import GenericLLMEvaluator +from opencompass.datasets import generic_llmjudge_postprocess +from opencompass.datasets import HLEDataset + +# ----------------------------- Detailed Config ----------------------------- + +math_reader_cfg = dict(input_columns=['problem'], output_column='answer') + +math_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt='{problem}\nRemember to put your final answer within \\boxed{}.'), + ] + ), + ), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer), +) + +GRADER_TEMPLATE = """ + Please as a grading expert, judge whether the final answers given by the candidates below are consistent with the standard answers, that is, whether the candidates answered correctly. + + Here are some evaluation criteria: + 1. Please refer to the given standard answer. You don't need to re-generate the answer to the question because the standard answer has been given. You only need to judge whether the candidate's answer is consistent with the standard answer according to the form of the question. Don't try to answer the original question. You can assume that the standard answer is definitely correct. + 2. Because the candidate's answer may be different from the standard answer in the form of expression, before making a judgment, please understand the question and the standard answer first, and then judge whether the candidate's answer is correct, but be careful not to try to answer the original question. + 3. Some answers may contain multiple items, such as multiple-choice questions, multiple-select questions, fill-in-the-blank questions, etc. As long as the answer is the same as the standard answer, it is enough. For multiple-select questions and multiple-blank fill-in-the-blank questions, the candidate needs to answer all the corresponding options or blanks correctly to be considered correct. + 4. Some answers may be expressed in different ways, such as some answers may be a mathematical expression, some answers may be a textual description, as long as the meaning expressed is the same. And some formulas are expressed in different ways, but they are equivalent and correct. + 5. If the prediction is given with \\boxed{}, please ignore the \\boxed{} and only judge whether the candidate's answer is consistent with the standard answer. + + Please judge whether the following answers are consistent with the standard answer based on the above criteria. Grade the predicted answer of this new question as one of: + A: CORRECT + B: INCORRECT + Just return the letters "A" or "B", with no text around it. + + Here is your task. Simply reply with either CORRECT, INCORRECT. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer. + + + : \n{problem}\n\n\n + : \n{answer}\n\n\n + : \n{prediction}\n\n\n + + Judging the correctness of candidates' answers: +""".strip() + +# Evaluation configuration +math_eval_cfg = dict( + evaluator=dict( + type=GenericLLMEvaluator, + prompt_template=dict( + type=PromptTemplate, + template=dict( + begin=[ + dict( + role='SYSTEM', + fallback_role='HUMAN', + prompt="You are a helpful assistant who evaluates the correctness and quality of models' outputs.") + ], + round=[ + dict( + role='HUMAN', + prompt = GRADER_TEMPLATE + ), + ]), + ), + dataset_cfg=dict( + type=HLEDataset, + path='cais/hle', + reader_cfg=math_reader_cfg, + ), + judge_cfg=dict(), + dict_postprocessor=dict(type=generic_llmjudge_postprocess), + ), + pred_role='BOT', +) + + +hle_datasets = [ + dict( + type=HLEDataset, + abbr='hle_llmjudge', + path='cais/hle', + reader_cfg=math_reader_cfg, + infer_cfg=math_infer_cfg, + eval_cfg=math_eval_cfg, + ) +] diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index b28f78ed..4052c630 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -57,6 +57,7 @@ from .gpqa import * # noqa: F401, F403 from .gsm8k import * # noqa: F401, F403 from .gsm_hard import * # noqa: F401, F403 from .hellaswag import * # noqa: F401, F403 +from .hle import * # noqa: F401, F403 from .huggingface import * # noqa: F401, F403 from .humaneval import * # noqa: F401, F403 from .humaneval_multi import * # noqa: F401, F403 diff --git a/opencompass/datasets/hle.py b/opencompass/datasets/hle.py new file mode 100644 index 00000000..2d7cf74b --- /dev/null +++ b/opencompass/datasets/hle.py @@ -0,0 +1,17 @@ +from datasets import load_dataset + +from opencompass.registry import LOAD_DATASET + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class HLEDataset(BaseDataset): + + @staticmethod + def load(path: str): + dataset = load_dataset(path) + dataset['test'] = dataset['test'].filter(lambda x: x['image'] == '') + dataset['test'] = dataset['test'].rename_column('question', 'problem') + dataset['train'] = dataset['test'] + return dataset From 5a2462a26faf56c0990513309afb7676aa50a8d8 Mon Sep 17 00:00:00 2001 From: liushz Date: Mon, 3 Mar 2025 10:52:30 +0000 Subject: [PATCH 09/12] Add HLE dataset --- dataset-index.yml | 5 +++++ opencompass/configs/datasets/HLE/hle_gen.py | 3 ++- ...le_llmjudge_gen_6ff468.py => hle_llmverify_gen_6ff468.py} | 0 3 files changed, 7 insertions(+), 1 deletion(-) rename opencompass/configs/datasets/HLE/{hle_llmjudge_gen_6ff468.py => hle_llmverify_gen_6ff468.py} (100%) diff --git a/dataset-index.yml b/dataset-index.yml index 9fbde8bd..c764c369 100644 --- a/dataset-index.yml +++ b/dataset-index.yml @@ -399,6 +399,11 @@ category: Math paper: https://proceedings.mlr.press/v202/gao23f/gao23f.pdf configpath: opencompass/configs/datasets/gsm_hard +- hellaswag: + name: HLE + category: Reasoning + paper: https://lastexam.ai/paper + configpath: opencompass/configs/datasets/HLE - hellaswag: name: HellaSwag category: Reasoning diff --git a/opencompass/configs/datasets/HLE/hle_gen.py b/opencompass/configs/datasets/HLE/hle_gen.py index a4ff86b4..598f1dde 100644 --- a/opencompass/configs/datasets/HLE/hle_gen.py +++ b/opencompass/configs/datasets/HLE/hle_gen.py @@ -1,4 +1,5 @@ from mmengine.config import read_base with read_base(): - from .hle_llmjudge_gen_63a000 import hle_datasets # noqa: F401, F403 + # Default use LLM as a judge + from .hle_llmverify_gen_6ff468 import hle_datasets # noqa: F401, F403 diff --git a/opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py b/opencompass/configs/datasets/HLE/hle_llmverify_gen_6ff468.py similarity index 100% rename from opencompass/configs/datasets/HLE/hle_llmjudge_gen_6ff468.py rename to opencompass/configs/datasets/HLE/hle_llmverify_gen_6ff468.py From a69d02f7468e070bfc6b2ca988bedb356f07fbbd Mon Sep 17 00:00:00 2001 From: liushz Date: Mon, 3 Mar 2025 10:54:32 +0000 Subject: [PATCH 10/12] Add HLE dataset --- dataset-index.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dataset-index.yml b/dataset-index.yml index c764c369..b8ec7041 100644 --- a/dataset-index.yml +++ b/dataset-index.yml @@ -399,8 +399,8 @@ category: Math paper: https://proceedings.mlr.press/v202/gao23f/gao23f.pdf configpath: opencompass/configs/datasets/gsm_hard -- hellaswag: - name: HLE +- hle: + name: HLE(Humanity's Last Exam) category: Reasoning paper: https://lastexam.ai/paper configpath: opencompass/configs/datasets/HLE From 520bf5867de138875ccc1b38b71277ee12e259cf Mon Sep 17 00:00:00 2001 From: liushz Date: Wed, 12 Mar 2025 10:23:46 +0000 Subject: [PATCH 11/12] Add AIME2025 oss info --- opencompass/utils/datasets_info.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/opencompass/utils/datasets_info.py b/opencompass/utils/datasets_info.py index 25c877c6..5f055cc0 100644 --- a/opencompass/utils/datasets_info.py +++ b/opencompass/utils/datasets_info.py @@ -309,6 +309,11 @@ DATASETS_MAPPING = { "hf_id": "", "local": "./data/aime.jsonl", }, + "opencompass/aime2025": { + "ms_id": "", + "hf_id": "", + "local": "./data/aime2025/aime2025.jsonl", + }, "opencompass/cmo_fib": { "ms_id": "", "hf_id": "", @@ -652,11 +657,16 @@ DATASETS_URL = { "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/test_generation.zip", "md5": "918a6ea2b1eee6f2b1314db3c21cb4c7", }, - "/aime": { + "/aime2024": { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/aime.zip", "md5": "fbe2d0577fc210962a549f8cea1a00c8", }, + "/aime2025": { + "url": + "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/aime2025.zip", + "md5": "aa18cd5d2e2de246c5397f5eb1e61004", + }, "/cmo": { "url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/cmo.zip", From 5cbd8b7324992ee5e86a5ae7546fbe45069d1ae7 Mon Sep 17 00:00:00 2001 From: liushz Date: Tue, 25 Mar 2025 02:54:26 +0000 Subject: [PATCH 12/12] Fix torch dtype error --- opencompass/models/huggingface_above_v4_33.py | 56 ++++++++++++++----- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/opencompass/models/huggingface_above_v4_33.py b/opencompass/models/huggingface_above_v4_33.py index 5cd38b4a..8e9bfe94 100644 --- a/opencompass/models/huggingface_above_v4_33.py +++ b/opencompass/models/huggingface_above_v4_33.py @@ -124,20 +124,48 @@ def _get_meta_template(meta_template): return APITemplateParser(meta_template or default_meta_template) -def _set_model_kwargs_torch_dtype(model_kwargs): +def _set_model_kwargs_torch_dtype(model_kwargs, path=None): import torch - if 'torch_dtype' not in model_kwargs: - torch_dtype = torch.float16 + from transformers import AutoConfig + + # If torch_dtype already exists and is not a string, return directly + if 'torch_dtype' in model_kwargs and not isinstance(model_kwargs['torch_dtype'], str): + return model_kwargs + + # Mapping from string to torch data types + dtype_map = { + 'torch.float16': torch.float16, 'float16': torch.float16, + 'torch.bfloat16': torch.bfloat16, 'bfloat16': torch.bfloat16, + 'torch.float': torch.float, 'float': torch.float, + 'torch.float32': torch.float32, 'float32': torch.float32, + 'auto': 'auto', 'None': None + } + + # 1. Priority: Use torch_dtype from model_kwargs if available + if 'torch_dtype' in model_kwargs: + torch_dtype = dtype_map.get(model_kwargs['torch_dtype'], torch.float16) + + # 2. Secondary: Try to read from model config + elif path is not None: + try: + config = AutoConfig.from_pretrained(path) + if hasattr(config, 'torch_dtype'): + config_dtype = config.torch_dtype + if isinstance(config_dtype, str): + torch_dtype = dtype_map.get(config_dtype, torch.float16) + else: + torch_dtype = config_dtype + else: + torch_dtype = torch.float16 + except Exception: + torch_dtype = torch.float16 + + # 3. Default: Use float16 as fallback else: - torch_dtype = { - 'torch.float16': torch.float16, - 'torch.bfloat16': torch.bfloat16, - 'torch.float': torch.float, - 'auto': 'auto', - 'None': None, - }.get(model_kwargs['torch_dtype']) - if torch_dtype is not None: - model_kwargs['torch_dtype'] = torch_dtype + torch_dtype = torch.float16 + + # Update model_kwargs with the resolved torch_dtype + model_kwargs['torch_dtype'] = torch_dtype return model_kwargs @@ -218,12 +246,12 @@ class HuggingFacewithChatTemplate(BaseModel): raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.') def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()): - from transformers import AutoModel, AutoModelForCausalLM + from transformers import AutoConfig, AutoModel, AutoModelForCausalLM DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True) model_kwargs = DEFAULT_MODEL_KWARGS model_kwargs.update(kwargs) - model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs) + model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs, path) self.logger.debug(f'using model_kwargs: {model_kwargs}') if is_npu_available(): model_kwargs['device_map'] = 'npu'