mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support OlympiadBench Benchmark (#1841)
* Support OlympiadBench Benchmark * Support OlympiadBench Benchmark * Support OlympiadBench Benchmark * update dataset path * Update OlympiadBench * Update OlympiadBench * Update OlympiadBench --------- Co-authored-by: liushz <qq1791167085@163.com>
This commit is contained in:
parent
70f2c963d3
commit
412199f802
36
examples/eval_OlympiadBench.py
Normal file
36
examples/eval_OlympiadBench.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Example evaluation config: runs the `lmdeploy_qwen2_5_7b_instruct` model
# config on the OlympiadBench datasets and reports them with the
# OlympiadBench summarizer.
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_gen_be8b13 import olympiadbench_datasets

    from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model

    from opencompass.configs.summarizers.OlympiadBench import summarizer

# Collect every imported `*_datasets` / `*_model` list into the two flat lists
# the runner consumes. `locals()` must stay the outermost iterable so that it
# is evaluated in the module scope.
datasets = sum([v for k, v in locals().items() if k == 'datasets' or k.endswith('_datasets')], [])
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])

from opencompass.runners import LocalRunner
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask

# Inference stage: partitioned for 8 workers, executed locally.
infer = {
    'partitioner': {'type': NumWorkerPartitioner, 'num_worker': 8},
    'runner': {
        'type': LocalRunner,
        'max_num_workers': 8,
        'task': {'type': OpenICLInferTask},
    },
}

# Evaluation stage: naive partitioning, executed locally.
eval = {
    'partitioner': {'type': NaivePartitioner, 'n': 10},
    'runner': {
        'type': LocalRunner,
        'max_num_workers': 256,
        'task': {'type': OpenICLEvalTask},
    },
}

# Directory that receives predictions, results and summaries.
work_dir = 'outputs/debug/OlympiadBench'
|
@ -0,0 +1,50 @@
|
||||
# Zero-shot generation config for OlympiadBench: one dataset entry is created
# per category listed in `OlympiadBench_categories`.
from mmengine.config import read_base
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2

with read_base():
    from .OlympiadBench_categories import categories

# Prompter settings. NOTE: nothing else in this file references this entry.
olympiadbench_prompter_cfg = {'type': 'OlympiadBenchPrompter'}

# Columns handed to the prompt template, and the column holding the reference.
olympiadbench_reader_cfg = {
    'input_columns': [
        'problem',
        'language',
        'subject',
        'question_type',
        'answer_type',
        'is_multiple_answer',
        'unit',
        'questions',
    ],
    'output_column': 'solution',
}

olympiadbench_datasets = []
for _name in categories:
    # No in-context examples: the template builds the whole prompt and the
    # model generates freely.
    olympiadbench_infer_cfg = {
        'prompt_template': {'type': 'OlympiadBenchTemplate'},
        'retriever': {'type': ZeroRetriever},
        'inferencer': {'type': GenInferencer},
    }

    # Predictions are post-processed before being scored by the evaluator.
    olympiadbench_eval_cfg = {
        'evaluator': {'type': OlympiadBenchEvaluator, 'version': 'v2'},
        'pred_postprocessor': {'type': olympiadbench_postprocess_v2},
    }

    olympiadbench_datasets.append({
        'type': OlympiadBenchDataset,
        'abbr': f'OlympiadBench_{_name}',
        'path': 'opencompass/OlympiadBench',
        'name': _name,
        'reader_cfg': olympiadbench_reader_cfg,
        'infer_cfg': olympiadbench_infer_cfg,
        'eval_cfg': olympiadbench_eval_cfg,
    })

# Remove the loop variable from the module namespace.
del _name
|
@ -0,0 +1,7 @@
|
||||
# OlympiadBench subsets to evaluate. Every name follows the pattern
# OE (open-ended) _ TO (text-only) _ <subject> _ <language code> _ <source>.
categories = [
    'OE_TO_maths_en_COMP',    # open-ended, text-only, maths, en, COMP
    'OE_TO_maths_zh_COMP',    # open-ended, text-only, maths, zh, COMP
    'OE_TO_maths_zh_CEE',     # open-ended, text-only, maths, zh, CEE
    'OE_TO_physics_en_COMP',  # open-ended, text-only, physics, en, COMP
    'OE_TO_physics_zh_CEE',   # open-ended, text-only, physics, zh, CEE
]
|
15
opencompass/configs/summarizers/OlympiadBench.py
Normal file
15
opencompass/configs/summarizers/OlympiadBench.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Summarizer config for OlympiadBench: lists the per-subset rows to report and
# pulls in the summary groups defined under `groups/OlympiadBench.py`.
from mmengine.config import read_base

with read_base():
    from .groups.OlympiadBench import OlympiadBench_summary_groups

summarizer = {
    # One row per evaluated subset, in display order.
    'dataset_abbrs': [
        'OlympiadBench_OE_TO_maths_en_COMP',
        'OlympiadBench_OE_TO_maths_zh_COMP',
        'OlympiadBench_OE_TO_maths_zh_CEE',
        'OlympiadBench_OE_TO_physics_en_COMP',
        'OlympiadBench_OE_TO_physics_zh_CEE',
    ],
    # Flatten every imported `*_summary_groups` list into a single list.
    'summary_groups': sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
}
|
11
opencompass/configs/summarizers/groups/OlympiadBench.py
Normal file
11
opencompass/configs/summarizers/groups/OlympiadBench.py
Normal file
@ -0,0 +1,11 @@
|
||||
# Subsets that are combined into the single 'OlympiadBench' summary group.
# Name pattern: OE (open-ended) _ TO (text-only) _ <subject> _ <language code> _ <source>.
categories = [
    'OE_TO_maths_en_COMP',    # open-ended, text-only, maths, en, COMP
    'OE_TO_maths_zh_COMP',    # open-ended, text-only, maths, zh, COMP
    'OE_TO_maths_zh_CEE',     # open-ended, text-only, maths, zh, CEE
    'OE_TO_physics_en_COMP',  # open-ended, text-only, physics, en, COMP
    'OE_TO_physics_zh_CEE',   # open-ended, text-only, physics, zh, CEE
]

# Each subset is referenced by its dataset abbreviation, i.e. the category
# name prefixed with 'OlympiadBench_' (spaces, if any, become underscores).
OlympiadBench_summary_groups = [
    dict(
        name='OlympiadBench',
        subsets=[f'OlympiadBench_{c.replace(" ", "_")}' for c in categories],
    ),
]
|
801
opencompass/datasets/OlympiadBench.py
Normal file
801
opencompass/datasets/OlympiadBench.py
Normal file
@ -0,0 +1,801 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from os import environ
|
||||
from typing import Dict
|
||||
|
||||
import sympy as sp
|
||||
from datasets import Dataset, DatasetDict
|
||||
from sympy import Eq, Pow, simplify, sympify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.registry import (ICL_PROMPT_TEMPLATES, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
from opencompass.utils import get_data_path
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
# Load Dataset
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
class OlympiadBenchDataset(BaseDataset):
    """Dataset for OlympiadBench.

    Args:
        path (str): Path to dataset directory
        name (str): Name of specific json file to load
            e.g. 'OE_TO_maths_en_COMP'
    """

    @staticmethod
    def _to_record(item):
        """Map one raw OlympiadBench item onto the columns used downstream.

        Only the first entry of ``final_answer`` is kept as the reference
        solution; the untouched raw item is stored under ``questions``.
        """
        return {
            'problem': item['question'],
            'solution': item['final_answer'][0],
            'language': item['language'],
            'subject': item['subject'],
            'question_type': item['question_type'],
            'answer_type': item['answer_type'],
            'is_multiple_answer': item['is_multiple_answer'],
            'unit': item['unit'],
            'error': item['error'],
            'questions': item,  # may not be used
        }

    @staticmethod
    def load(path: str, name: str = None, **kwargs):
        """Load dataset.

        Args:
            path (str): Path to dataset directory
            name (str): Name of specific json file to load. Required for the
                local-file source; not consulted when the ``DATASET_SOURCE``
                environment variable is ``'ModelScope'``.

        Returns:
            DatasetDict: Dataset with ``test`` and ``train`` splits, both
            built from the same records.

        Raises:
            ValueError: If ``name`` is missing for the local-file source.
            FileNotFoundError: If ``<path>/<name>.json`` does not exist.
        """
        path = get_data_path(path)

        if environ.get('DATASET_SOURCE') == 'ModelScope':
            from modelscope import MsDataset

            items = MsDataset.load(path, split='train')
        else:
            # Construct file path using name parameter
            if name is None:
                raise ValueError(
                    "Must specify 'name' parameter to load specific json file")

            file_path = os.path.join(path, f'{name}.json')
            if not os.path.exists(file_path):
                raise FileNotFoundError(f'File not found: {file_path}')

            # Load the specified json file; the context manager guarantees
            # the handle is closed even if parsing fails.
            with open(file_path, encoding='utf-8') as f:
                items = json.load(f)

        raw_data = [OlympiadBenchDataset._to_record(item) for item in items]

        dataset = DatasetDict()
        dataset['test'] = Dataset.from_list(raw_data)
        dataset['train'] = Dataset.from_list(raw_data)
        return dataset
|
||||
|
||||
|
||||
# Construct Prompt
|
||||
|
||||
|
||||
def get_single_answer_type_text(answer_type, is_chinese):
    """Return the prompt wording for one answer type.

    Anything from the first '-' onwards is ignored. The first of
    'Numerical', 'Expression', 'Equation', 'Interval' that occurs in the
    remaining text selects the wording, in Chinese when ``is_chinese`` is
    true and in English otherwise.

    Raises:
        ValueError: If none of the known type names occurs in ``answer_type``.
    """
    # Keep only the part before the first '-' (no-op when there is none).
    answer_type = answer_type.split('-', 1)[0]

    # (keyword, Chinese wording, English wording), checked in this order.
    wordings = (
        ('Numerical', '数值', 'a numerical value'),
        ('Expression', '表达式', 'an expression'),
        ('Equation', '方程', 'an equation'),
        ('Interval', '区间', 'an interval'),
    )
    for keyword, zh_text, en_text in wordings:
        if keyword in answer_type:
            return zh_text if is_chinese else en_text

    raise ValueError(f'Error parsing answer type {answer_type}!')
|
||||
|
||||
|
||||
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
    """Build the prompt fragment that announces the expected answer type(s).

    Returns an empty string for answer types containing
    'Need_human_evaluate' or 'Tuple'. Otherwise the fragment states the
    single answer type, or, when ``multiple_answer`` is true, that there are
    several answers sharing one type or having the listed types in order.
    ``answer_type`` may be a comma-separated list in the multi-answer case.
    """
    if 'Need_human_evaluate' in answer_type or 'Tuple' in answer_type:
        return ''

    if not multiple_answer:
        single_text = get_single_answer_type_text(answer_type, is_chinese)
        if is_chinese:
            return f',答案类型为{single_text}'
        return (f'The answer of The problem should be '
                f'{single_text}. ')

    # Multiple answers: resolve the wording of every listed type.
    if ',' in answer_type:
        type_texts = [
            get_single_answer_type_text(part, is_chinese)
            for part in answer_type.split(',')
        ]
    else:
        type_texts = [get_single_answer_type_text(answer_type, is_chinese)]

    # All answers share one type.
    if len(set(type_texts)) == 1:
        shared_text = type_texts[0]
        if is_chinese:
            return f',题目有多个答案,答案类型均为{shared_text}'
        return (f'The problem has multiple answers, each of them '
                f'should be {shared_text}. ')

    # Answers have different types: list them in order.
    if is_chinese:
        joined_text = '、'.join(type_texts)
        return f',题目有多个答案,答案类型分别为{joined_text}'
    joined_text = ', '.join(type_texts)
    return (f'The problem has multiple answers, '
            f'with the answers in order being {joined_text}. ')
|
||||
|
||||
|
||||
class OlympiadBenchPrompter:
    """Builds the instruction prompt that precedes an OlympiadBench problem."""

    def __init__(self):
        pass

    def make_prompt(
        self,
        language,
        subject,
        question_type,
        answer_type,
        is_multiple_answer,
        unit,
    ):
        """Generate prompt based on question properties.

        Args:
            language: 'Chinese' selects the Chinese prompt; anything else
                selects the English one.
            subject: 'Math' selects the maths wording; anything else is
                treated as physics.
            question_type: 'Theorem proof' selects the proof instructions;
                anything else selects the open-ended instructions.
            answer_type: Answer type description passed on to
                ``get_answer_type_text`` (open-ended questions only).
            is_multiple_answer: Whether several answers are expected.
            unit: Truthy when the answer carries a unit, which must then be
                written outside ``\\boxed{}``.

        Returns:
            str: The prompt text, containing a literal ``{problem}``
            placeholder to be filled in by the template.

        The three flags derived from the arguments are also stored on the
        instance (``is_chinese``, ``is_math``, ``is_theorem_proving``).
        """
        self.is_chinese = language == 'Chinese'
        self.is_math = subject == 'Math'
        self.is_theorem_proving = question_type == 'Theorem proof'

        if self.is_chinese:
            subject_content = '数学' if self.is_math else '物理'
            if self.is_theorem_proving:
                prompt = (f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,'
                          f'运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。')
            else:
                answer_type_text = get_answer_type_text(
                    answer_type,
                    is_chinese=True,
                    multiple_answer=is_multiple_answer,
                )
                if is_multiple_answer:
                    multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}'
                else:
                    multiple_answer_text = '\\boxed{答案}'
                unit_text = ''
                if unit:
                    # The unit goes after the box, never inside it.
                    multiple_answer_text += '(单位)'
                    unit_text = ',注意答案的单位不要放在\\boxed{}中'
                prompt = (f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。'
                          f'请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的'
                          f'变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是'
                          f'{multiple_answer_text}。"显式给出结果{unit_text}。')
        else:
            subject_content = 'Math' if self.is_math else 'Physics'
            if self.is_theorem_proving:
                prompt = (
                    f'The following is a theorem proving problem from an '
                    f'International {subject_content} competition. Please use '
                    f'logical reasoning and common theorems to prove the '
                    f'proposition in the problem according to the given '
                    f'requirements. Please use LaTeX format to represent the '
                    f'variables and formulas used in the proof.')
            else:
                if is_multiple_answer:
                    multiple_answer_text = (
                        '\\boxed{multiple answers connected with commas}')
                else:
                    multiple_answer_text = '\\boxed{answer}'
                unit_text = ''
                if unit:
                    # The unit goes after the box, never inside it.
                    multiple_answer_text += '(unit)'
                    unit_text = (', note that the unit of the answer should '
                                 'not be included in \\boxed{}')
                answer_type_text = get_answer_type_text(
                    answer_type,
                    is_chinese=False,
                    multiple_answer=is_multiple_answer,
                )
                prompt = (
                    f'The following is an open-ended problem from an '
                    f'International {subject_content} competition. '
                    f'{answer_type_text}Please calculate the answer according '
                    f'to the given requirements and the information provided. '
                    f'Please use LaTeX format to represent the variables and '
                    f'formulas used in the solution process and results. '
                    f'Please end your solution with "So the final answer is '
                    f'{multiple_answer_text}." and give the result explicitly'
                    f'{unit_text}.')

        # Add problem statement to the prompt. '{problem}' stays a literal
        # placeholder here; it is not substituted by this method.
        prompt = prompt + '\n' + '{problem}' + '\n'

        # Add step-by-step reasoning instruction
        if self.is_chinese:
            prompt += '\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。'
        else:
            prompt += ('\nPlease reason step by step, and put your final '
                       'answer within \\boxed{}.')
        return prompt
|
||||
|
||||
|
||||
# Evaluate
|
||||
|
||||
|
||||
class MathJudger:
    """Rule-based checker deciding whether two answer strings are equivalent.

    Both strings are normalised (``\\boxed{}`` extraction, decorative LaTeX
    removed), split on top-level commas and compared item by item. Two items
    match when they are equal as strings, intervals, numbers, expressions or
    equations (tried in that order).
    """

    # Decorative markup replaced before comparison. Order matters:
    # '\\simeq' has to be handled before its prefix '\\sim'.
    # NOTE(review): the original table mapped ',' to itself, which is a
    # no-op; the fullwidth comma (U+FF0C) is presumably what was intended.
    _SPECIAL_SIGNAL_MAP = {
        '\\left': '',
        '\\right': '',
        '∶': ':',
        '\uff0c': ',',
        '$': '',
        '\\approx': '=',
        '\\simeq': '=',
        '\\sim': '=',
        '^\\prime': "'",
        '^{\\prime}': "'",
        '^\\circ': '',
        '%': '',
    }

    def __init__(self):
        self.special_signal_map = dict(self._SPECIAL_SIGNAL_MAP)
        # Parsed symbol used by `sympy_sub_pi` to substitute a float for pi.
        self.pi = parse_latex('\\pi')
        # Absolute tolerance of the item currently being compared.
        self.precision = 1e-8

    def split_by_comma(self, expr: str):
        """Split ``expr`` on commas that are not inside ``()`` or ``[]``.

        Every piece is stripped of surrounding whitespace. A trailing comma
        does not produce a trailing empty piece.
        """
        in_bracket_num = 0
        splitted_expr = []
        start_idx = 0
        for i, char in enumerate(expr):
            if char in '([':
                in_bracket_num += 1
            elif char in ')]':
                in_bracket_num -= 1
            elif char == ',' and in_bracket_num == 0:
                splitted_expr.append(expr[start_idx:i].strip())
                start_idx = i + 1

        if start_idx < len(expr):
            splitted_expr.append(expr[start_idx:].strip())

        return splitted_expr

    def trans_plus_minus_sign(self, expr_list: list):
        """Expand every item containing ``\\pm`` into a '+' and a '-' item."""
        new_expr_list = []
        for expr in expr_list:
            if '\\pm' in expr:
                new_expr_list.append(expr.replace('\\pm', '+'))
                new_expr_list.append(expr.replace('\\pm', '-'))
            else:
                new_expr_list.append(expr)

        return new_expr_list

    def judge(self, expression1, expression2, precision=1e-8):
        """Decide whether two answers are equivalent.

        Args:
            expression1: Ground-truth answer (by convention).
            expression2: Answer to be checked.
            precision: Absolute tolerance, either a single number applied to
                every answer or a list with one tolerance per
                comma-separated answer of ``expression1``. The list is never
                modified.

        Returns:
            bool: True when every item of ``expression1`` can be paired with
            a distinct, equivalent item of ``expression2``; False otherwise,
            including when preprocessing fails.
        """
        # Work on a private copy so the caller's list is left untouched.
        tolerances = (list(precision)
                      if isinstance(precision, list) else [precision])

        try:
            expression1, expression2 = self.preprocess(expression1,
                                                       expression2)
        except Exception:  # best effort: unparsable input counts as wrong
            return False
        if expression1 == expression2:
            return True

        # Remove Chinese characters from both strings.
        expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
        expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)

        reference_items = self.split_by_comma(expression1)
        candidate_items = self.trans_plus_minus_sign(
            self.split_by_comma(expression2))

        # Build one tolerance per reference answer: a single value applies
        # to all answers, a short list is padded with the default.
        if len(tolerances) <= 1:
            tolerances = (tolerances or [1e-8]) * len(reference_items)
        elif len(tolerances) < len(reference_items):
            tolerances += [1e-8] * (len(reference_items) - len(tolerances))

        # Expand '\pm' on the reference side, duplicating the tolerance so
        # both variants keep the tolerance of the answer they came from.
        expanded_items = []
        expanded_tolerances = []
        for item, tolerance in zip(reference_items, tolerances):
            variants = self.trans_plus_minus_sign([item])
            expanded_items.extend(variants)
            expanded_tolerances.extend([tolerance] * len(variants))

        if len(expanded_items) != len(candidate_items):
            return False

        # Greedily pair every reference item with the first still-unused
        # candidate that is equivalent to it.
        for item1, tolerance in zip(expanded_items, expanded_tolerances):
            self.precision = tolerance
            for position, item2 in enumerate(candidate_items):
                if self.is_equal(item1, item2):
                    del candidate_items[position]
                    break
            else:
                return False

        # Every item found a partner.
        return True

    def is_interval(self, epr):
        """Return True if ``epr`` starts with ( or [ and ends with ) or ]."""
        return epr.startswith(('(', '[')) and epr.endswith((')', ']'))

    def sympy_sub_pi(self, expression_sympy):
        """Replace the symbol pi by its float value in a sympy expression."""
        return expression_sympy.subs(self.pi, math.pi)

    def is_equal(self, expression1, expression2):
        """Compare two single items with increasingly expensive checks.

        Order: identical non-empty strings, intervals, numbers, expressions,
        equations. Errors raised by a check are treated as "not equal by
        this check", except that an error while comparing two intervals
        ends the comparison with False.
        """
        if expression1 == expression2 and expression1 != '':
            return True

        # First check whether both are intervals.
        if self.is_interval(expression1) and self.is_interval(expression2):
            try:
                if self.interval_equal(expression1, expression2):
                    return True
            except Exception:  # unparsable interval: give up on this pair
                return False

        # Then check numerical equality.
        try:
            if self.numerical_equal(expression1, expression2):
                return True
        except Exception:  # not plain numbers; try the next check
            pass

        # Then check expression equality. Skipped when both sides contain
        # '=', because that case is left to the equation check below.
        if not ('=' in expression1 and '=' in expression2):
            try:
                if self.expression_equal(expression1, expression2):
                    return True
            except Exception:  # not parsable as expressions
                pass

        # Finally check equation equality.
        try:
            if self.equation_equal(expression1, expression2):
                return True
        except Exception:  # not parsable as equations
            pass

        return False

    def numerical_equal(
        self,
        expression1: str,
        expression2: str,
        include_percentage: bool = True,
    ):
        """Check whether two numbers are equal within ``self.precision``.

        ``expression1`` is taken as the ground truth. With
        ``include_percentage`` the ground truth is also accepted when scaled
        by 1/100 or by 100, to cover percentage notation. The comparison
        uses an absolute tolerance of ``self.precision * 1.01``.

        Raises:
            ValueError: If either string cannot be converted by ``float``.
        """
        reference = float(expression1)
        prediction = float(expression2)

        if include_percentage:
            gt_result = [reference / 100, reference, reference * 100]
        else:
            gt_result = [reference]

        return any(
            abs(item - prediction) <= self.precision * 1.01
            for item in gt_result)

    def expression_equal(self, exp1, exp2):
        """Check whether two expressions are mathematically equivalent.

        ``exp1`` is taken as the ground truth.
        Step 1: Extract the expression, since a model may answer "x=1"
                instead of "1".
        Step 2: Compare with sympy, numerically when neither side has free
                symbols and symbolically otherwise.
        """

        def extract_expression(expression):
            # Keep the part directly after the first '='.
            if '=' in expression:
                expression = expression.split('=')[1]
            return expression.strip()

        exp1 = extract_expression(exp1)
        exp2 = extract_expression(exp2)

        # Convert the expressions into a form sympy can work with.
        expr1_sym = sympify(parse_latex(exp1))
        expr2_sym = sympify(parse_latex(exp2))

        if expr1_sym == expr2_sym:
            return True

        expr1_sym = self.sympy_sub_pi(expr1_sym)
        expr2_sym = self.sympy_sub_pi(expr2_sym)

        has_symbol_1 = expr1_sym.has(sp.Symbol)
        has_symbol_2 = expr2_sym.has(sp.Symbol)

        # A symbolic expression never equals a purely numeric one.
        if has_symbol_1 != has_symbol_2:
            return False

        if not has_symbol_1:
            try:
                if not (self.can_compute_power(expr1_sym)
                        and self.can_compute_power(expr2_sym)):
                    print(f'These two number can not be calculated by '
                          f'current computer for: '
                          f'"{str(expr1_sym)}" and "{str(expr2_sym)}"')
                    return False

                return bool(
                    abs(expr1_sym.evalf() - expr2_sym.evalf()) <=
                    self.precision * 1.01)
            except Exception:  # evaluation failed: treat as not equal
                return False

        try:
            simplified_expr = simplify(expr1_sym - expr2_sym)

            num_value = simplified_expr.evalf()

            return abs(num_value) < 1e-3
        except Exception:  # simplification failed: treat as not equal
            return False

    def equation_equal(self, expression1, expression2):
        """Check whether two equations are mathematically equivalent.

        ``expression1`` is taken as the ground truth.
        Step 1: Bring both equations into the form ``lhs - rhs``.
        Step 2: Divide the two forms by each other; if either quotient is a
                non-zero integer the equations are considered equivalent.

        Raises:
            ValueError: If an input does not contain exactly one '='.
        """

        # Convert an equation to sympy with the right side moved to the left.
        def simplify_equation(latex_eq):
            # Split left and right sides of equation
            lhs, rhs = latex_eq.split('=')

            # Parse LaTeX expressions using parse_latex
            lhs_expr = parse_latex(lhs)
            rhs_expr = parse_latex(rhs)

            # Create equation object
            equation = Eq(lhs_expr, rhs_expr)

            # Simplify equation by moving right side to left
            simplified_eq = simplify(equation.lhs - equation.rhs)

            return simplified_eq

        expr1_sym = simplify_equation(expression1)
        expr2_sym = simplify_equation(expression2)

        division_result_1 = simplify(expr1_sym / expr2_sym)
        division_result_2 = simplify(expr2_sym / expr1_sym)

        # If division result or its reciprocal is a non-zero integer,
        # the equations are equivalent.
        return bool(
            (division_result_1.is_Integer and division_result_1 != 0)
            or (division_result_2.is_Integer and division_result_2 != 0))

    def interval_equal(self, expression1, expression2):
        """Check whether two intervals (or unions of them) are equivalent.

        The strings are split on ``\\cup``; the pieces are compared pairwise
        in order. Two pieces match when their brackets are identical, they
        have the same number of comma-separated endpoints and every pair of
        endpoints passes ``expression_equal``.
        """

        def compare_two_interval(inter1, inter2):
            # Whitespace around '\cup' must not take part in the comparison.
            inter1, inter2 = inter1.strip(), inter2.strip()
            if not inter1 or not inter2:
                return False

            # First compare brackets on both sides
            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
                return False

            inter1 = inter1.strip('[]()')
            inter2 = inter2.strip('[]()')

            # Split interval into left and right parts
            items_1 = inter1.split(',')
            items_2 = inter2.split(',')

            # A different number of endpoints can never be equivalent.
            if len(items_1) != len(items_2):
                return False

            for item_1, item_2 in zip(items_1, items_2):
                if not self.expression_equal(item_1, item_2):
                    return False
            return True

        if expression1 == expression2:
            return True

        inter_list1 = expression1.split('\\cup')
        inter_list2 = expression2.split('\\cup')

        if len(inter_list1) != len(inter_list2):
            return False

        for inter1, inter2 in zip(inter_list1, inter_list2):
            if not compare_two_interval(inter1, inter2):
                return False
        return True

    def preprocess(self, expression1, expression2):
        """Extract and preprocess expressions from model output.

        Returns:
            tuple: The two normalised answer strings.

        Raises:
            ValueError: If a ``\\boxed{`` has no matching closing brace.
        """

        def extract_boxed_content(latex_str):
            # Find all \boxed{...} structures
            boxed_matches = re.finditer(r'\\boxed{', latex_str)
            results = ''

            for match in boxed_matches:
                start_index = match.end()
                end_index = start_index
                stack = 1

                # Search from after \boxed{ until
                # finding matching closing brace
                while stack > 0 and end_index < len(latex_str):
                    if latex_str[end_index] == '{':
                        stack += 1
                    elif latex_str[end_index] == '}':
                        stack -= 1
                    end_index += 1

                if stack == 0:
                    # Extract content inside \boxed{}
                    content = latex_str[start_index:end_index - 1]
                    results += content + ','
                else:
                    raise ValueError('Mismatched braces in LaTeX string.')

            # If no \boxed{} found, extract formulas from last line
            if results == '':
                last_line_ans = latex_str.strip().split('\n')[-1]
                dollar_pattern = r'\$(.*?)\$'
                answers = re.findall(dollar_pattern, last_line_ans)

                if answers:
                    for ans in answers:
                        results += ans + ','
                else:
                    # Nothing recognisable: fall back to the whole text.
                    results = latex_str

            return results

        def special_symbol_replace(expression):
            # Drop a leading "x \in " so only the set/interval remains.
            if '\\in ' in expression:
                expression = expression.split('\\in ')[1]

            # Replace special characters that
            # don't affect LaTeX parsing (decorative)
            for signal in self.special_signal_map:
                expression = expression.replace(
                    signal, self.special_signal_map[signal])

            # Trim punctuation left over at either end (including the
            # separator appended by extract_boxed_content).
            expression = expression.strip('\n$,.:;^_=+`!@#$%^&*~\uff0c。')

            # Unwrap \mathrm{...} / \mathbf{...}, keeping their content.
            pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
            expression = re.sub(pattern, r'\1', expression)

            return expression

        exp1 = special_symbol_replace(extract_boxed_content(expression1))
        exp2 = special_symbol_replace(extract_boxed_content(expression2))

        return exp1, exp2

    def can_compute_power(self, expr):
        """Check if the power expression can be computed.

        Parameters:
            expr (sympy expression): The expression to check.

        Returns:
            bool: True if ``expr`` is not a power, or is a power of two
            numbers whose exponent does not exceed 1000 in absolute value;
            False otherwise. Only the top-level node is inspected.
        """
        # Anything that is not a power is outside the case checked here.
        if not isinstance(expr, Pow):
            return True

        # Extract the base and the exponent
        base, exp = expr.as_base_exp()

        # If the base or the exponent is not a number,
        # we cannot compute the power
        if not (base.is_number and exp.is_number):
            return False

        # Threshold for the maximum size of the exponent;
        # can be adjusted based on the computing environment
        MAX_EXP = 1000

        return not abs(exp.evalf()) > MAX_EXP
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('olympiadbench_postprocess_v2')
def olympiadbench_postprocess_v2(text: str,
                                 is_chinese: bool = False,
                                 is_deepseek: bool = False) -> str:
    """Extract answer from model output.

    Looks for the answer marker that matches the (deepseek, chinese)
    combination and returns what follows its last occurrence on that line,
    stripped of surrounding whitespace. When the marker is absent the text
    is returned unchanged.
    """
    # Answer markers keyed by (is_deepseek, is_chinese); deepseekmath has a
    # special answering format.
    marker_patterns = {
        (True, True): '## 解题答案(.*)',
        (True, False): 'The answer is: (.*)',
        (False, True): '所以最终答案是(.*)',
        (False, False): 'So the final answer is (.*)',
    }
    pattern = marker_patterns[(bool(is_deepseek), bool(is_chinese))]
    found = re.findall(pattern, text)

    # Prefer the last match; otherwise hand back the whole text.
    return found[-1].strip() if found else text
|
||||
|
||||
|
||||
class OlympiadBenchEvaluator(BaseEvaluator):
    """Evaluator for OlympiadBench dataset."""

    def __init__(self, version='v1'):
        assert version in ['v1', 'v2']
        super().__init__()
        self.version = version
        self.judger = MathJudger()

    @staticmethod
    def _resolve_precision(ref):
        """Return the tolerance(s) to use for one reference.

        A dict reference may carry an ``error`` entry: a single number, or a
        comma-separated list with one tolerance per answer. Empty entries
        and references without ``error`` fall back to 1e-8.
        """
        default = 1e-8
        if not (isinstance(ref, dict) and 'error' in ref):
            return default

        raw_error = ref['error']
        if ',' in raw_error:
            # Multiple precisions for multiple answers
            return [float(p) if p else default for p in raw_error.split(',')]
        return float(raw_error) if raw_error else default

    def score(self, predictions, references):
        """Calculate accuracy score.

        Args:
            predictions (list): List of model predictions
            references (list): List of ground truth answers

        Returns:
            dict: ``accuracy`` (percentage, 0.0 for empty input) and
            ``details`` (one dict per sample with ``pred``, ``answer``,
            ``correct`` and, if judging failed, ``error``). If the two lists
            differ in length, a dict with only an ``error`` message.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different length'
            }

        correct = 0
        details = []

        for pred, ref in zip(predictions, references):
            detail = {'pred': pred, 'answer': ref, 'correct': False}

            # A failure on one sample (malformed tolerance, judging error)
            # is recorded on that sample instead of aborting the run.
            try:
                precision = self._resolve_precision(ref)

                if (isinstance(ref, dict) and 'answer_type' in ref
                        and 'Tuple' in ref['answer_type']):
                    # Special handling for tuple type answers
                    ground_truth = ref['final_answer'][0]
                else:
                    ground_truth = ref

                # The judger expects the ground truth as first argument.
                if self.judger.judge(ground_truth, pred, precision):
                    correct += 1
                    detail['correct'] = True
            except Exception as e:
                detail['error'] = str(e)

            details.append(detail)

        count = len(details)
        # Avoid dividing by zero when there is nothing to score.
        accuracy = 100 * correct / count if count else 0.0
        return {'accuracy': accuracy, 'details': details}
|
||||
|
||||
|
||||
@ICL_PROMPT_TEMPLATES.register_module()
class OlympiadBenchTemplate(PromptTemplate):
    """Template for OlympiadBench dataset."""

    def __init__(self):
        # A single HUMAN turn whose text is the generated prompt.
        super().__init__(
            template=dict(round=[dict(role='HUMAN', prompt='{prompt}')]))
        self.prompter = OlympiadBenchPrompter()

    def generate_item(self, entry: Dict, *args, **kwargs) -> str:
        """Generate prompt for a single item.

        The question properties are read from ``entry`` (with defaults for
        missing keys), turned into an instruction prompt by the prompter,
        and passed on to the parent template together with the problem text.
        """
        # Prompter argument -> value used when the entry lacks the key.
        fallbacks = (
            ('language', 'English'),
            ('subject', 'Math'),
            ('question_type', ''),
            ('answer_type', ''),
            ('is_multiple_answer', False),
            ('unit', ''),
        )
        prompter_kwargs = {
            field: entry.get(field, fallback)
            for field, fallback in fallbacks
        }
        rendered_prompt = self.prompter.make_prompt(**prompter_kwargs)

        template_entry = {
            'prompt': rendered_prompt,
            'problem': entry.get('problem', ''),
        }
        return super().generate_item(template_entry, *args, **kwargs)
|
@ -103,6 +103,7 @@ from .natural_question import * # noqa: F401, F403
|
||||
from .natural_question_cn import * # noqa: F401, F403
|
||||
from .NPHardEval import * # noqa: F401, F403
|
||||
from .obqa import * # noqa: F401, F403
|
||||
from .OlympiadBench import * # noqa: F401, F403
|
||||
from .OpenFinData import * # noqa: F401, F403
|
||||
from .piqa import * # noqa: F401, F403
|
||||
from .py150 import * # noqa: F401, F403
|
||||
|
@ -398,9 +398,18 @@ DATASETS_MAPPING = {
|
||||
"hf_id": "THUDM/LongBench-v2",
|
||||
"local": "./data/longbenchv2/data.json",
|
||||
},
|
||||
"opencompass/OlympiadBench": {
|
||||
"ms_id": "",
|
||||
"hf_id": "",
|
||||
"local": "./data/OlympiadBench",
|
||||
},
|
||||
}
|
||||
|
||||
DATASETS_URL = {
|
||||
"/OlympiadBench": {
|
||||
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/OlympiadBench.zip",
|
||||
"md5": "97e8b1ae7f6170d94817288a8930ef00",
|
||||
},
|
||||
"/longbenchv2":{
|
||||
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/longbenchv2.zip",
|
||||
"md5": "09b7e06e6f98c5cca8ad597b3d7b42f0",
|
||||
|
@ -1,5 +1,7 @@
|
||||
# Alpaca-eval
|
||||
alpaca-eval==0.6
|
||||
# OlympiadBench
|
||||
antlr4-python3-runtime==4.11
|
||||
cn2an
|
||||
# Dingo
|
||||
dingo-python==1.1.2
|
||||
|
Loading…
Reference in New Issue
Block a user