mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature] Support OlympiadBench Benchmark (#1841)
* Support OlympiadBench Benchmark * Support OlympiadBench Benchmark * Support OlympiadBench Benchmark * update dataset path * Update olmpiadBench * Update olmpiadBench * Update olmpiadBench --------- Co-authored-by: liushz <qq1791167085@163.com>
This commit is contained in:
parent
70f2c963d3
commit
412199f802
36
examples/eval_OlympiadBench.py
Normal file
36
examples/eval_OlympiadBench.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_gen_be8b13 import olympiadbench_datasets
|
||||||
|
|
||||||
|
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model
|
||||||
|
|
||||||
|
from opencompass.configs.summarizers.OlympiadBench import summarizer
|
||||||
|
|
||||||
|
|
||||||
|
datasets = sum([v for k, v in locals().items() if k.endswith('_datasets') or k == 'datasets'], [])
|
||||||
|
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
|
||||||
|
|
||||||
|
from opencompass.runners import LocalRunner
|
||||||
|
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
|
||||||
|
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
|
||||||
|
|
||||||
|
infer = dict(
|
||||||
|
partitioner=dict(type=NumWorkerPartitioner, num_worker=8),
|
||||||
|
runner=dict(
|
||||||
|
type=LocalRunner,
|
||||||
|
max_num_workers=8,
|
||||||
|
task=dict(type=OpenICLInferTask)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
eval = dict(
|
||||||
|
partitioner=dict(type=NaivePartitioner, n=10),
|
||||||
|
runner=dict(
|
||||||
|
type=LocalRunner,
|
||||||
|
max_num_workers=256,
|
||||||
|
task=dict(type=OpenICLEvalTask)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
work_dir = 'outputs/debug/OlympiadBench'
|
@ -0,0 +1,50 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
from opencompass.datasets import OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2
|
||||||
|
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from .OlympiadBench_categories import categories
|
||||||
|
|
||||||
|
# Create prompter instance for problems
|
||||||
|
olympiadbench_prompter_cfg = dict(
|
||||||
|
type='OlympiadBenchPrompter'
|
||||||
|
)
|
||||||
|
|
||||||
|
olympiadbench_reader_cfg = dict(
|
||||||
|
input_columns=[
|
||||||
|
'problem', 'language', 'subject', 'question_type',
|
||||||
|
'answer_type', 'is_multiple_answer', 'unit', 'questions'
|
||||||
|
],
|
||||||
|
output_column='solution'
|
||||||
|
)
|
||||||
|
|
||||||
|
olympiadbench_datasets = []
|
||||||
|
for _name in categories:
|
||||||
|
olympiadbench_infer_cfg = dict(
|
||||||
|
prompt_template=dict(
|
||||||
|
type='OlympiadBenchTemplate'
|
||||||
|
),
|
||||||
|
retriever=dict(type=ZeroRetriever),
|
||||||
|
inferencer=dict(type=GenInferencer),
|
||||||
|
)
|
||||||
|
|
||||||
|
olympiadbench_eval_cfg = dict(
|
||||||
|
evaluator=dict(type=OlympiadBenchEvaluator, version='v2'),
|
||||||
|
pred_postprocessor=dict(type=olympiadbench_postprocess_v2),
|
||||||
|
)
|
||||||
|
|
||||||
|
olympiadbench_datasets.append(
|
||||||
|
dict(
|
||||||
|
type=OlympiadBenchDataset,
|
||||||
|
abbr=f'OlympiadBench_{_name}',
|
||||||
|
path='opencompass/OlympiadBench',
|
||||||
|
name=_name,
|
||||||
|
reader_cfg=olympiadbench_reader_cfg,
|
||||||
|
infer_cfg=olympiadbench_infer_cfg,
|
||||||
|
eval_cfg=olympiadbench_eval_cfg,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
del _name
|
@ -0,0 +1,7 @@
|
|||||||
|
categories = [
|
||||||
|
'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP
|
||||||
|
'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP
|
||||||
|
'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE
|
||||||
|
'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP
|
||||||
|
'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE
|
||||||
|
]
|
15
opencompass/configs/summarizers/OlympiadBench.py
Normal file
15
opencompass/configs/summarizers/OlympiadBench.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from mmengine.config import read_base
|
||||||
|
|
||||||
|
with read_base():
|
||||||
|
from .groups.OlympiadBench import OlympiadBench_summary_groups
|
||||||
|
|
||||||
|
summarizer = dict(
|
||||||
|
dataset_abbrs=[
|
||||||
|
'OlympiadBench_OE_TO_maths_en_COMP',
|
||||||
|
'OlympiadBench_OE_TO_maths_zh_COMP',
|
||||||
|
'OlympiadBench_OE_TO_maths_zh_CEE',
|
||||||
|
'OlympiadBench_OE_TO_physics_en_COMP',
|
||||||
|
'OlympiadBench_OE_TO_physics_zh_CEE'
|
||||||
|
],
|
||||||
|
summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
|
||||||
|
)
|
11
opencompass/configs/summarizers/groups/OlympiadBench.py
Normal file
11
opencompass/configs/summarizers/groups/OlympiadBench.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
categories = [
|
||||||
|
'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP
|
||||||
|
'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP
|
||||||
|
'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE
|
||||||
|
'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP
|
||||||
|
'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE
|
||||||
|
]
|
||||||
|
|
||||||
|
OlympiadBench_summary_groups = [
|
||||||
|
{'name': 'OlympiadBench', 'subsets': ['OlympiadBench_' + c.replace(' ', '_') for c in categories]},
|
||||||
|
]
|
801
opencompass/datasets/OlympiadBench.py
Normal file
801
opencompass/datasets/OlympiadBench.py
Normal file
@ -0,0 +1,801 @@
|
|||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from os import environ
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import sympy as sp
|
||||||
|
from datasets import Dataset, DatasetDict
|
||||||
|
from sympy import Eq, Pow, simplify, sympify
|
||||||
|
from sympy.parsing.latex import parse_latex
|
||||||
|
|
||||||
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||||
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
|
from opencompass.registry import (ICL_PROMPT_TEMPLATES, LOAD_DATASET,
|
||||||
|
TEXT_POSTPROCESSORS)
|
||||||
|
from opencompass.utils import get_data_path
|
||||||
|
|
||||||
|
from .base import BaseDataset
|
||||||
|
|
||||||
|
# Load Dataset
|
||||||
|
|
||||||
|
|
||||||
|
@LOAD_DATASET.register_module()
|
||||||
|
class OlympiadBenchDataset(BaseDataset):
|
||||||
|
"""Dataset for OlympiadBench.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to dataset directory
|
||||||
|
name (str): Name of specific json file to load
|
||||||
|
e.g. 'OE_TO_maths_en_COMP'
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(path: str, name: str = None, **kwargs):
|
||||||
|
"""Load dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to dataset directory
|
||||||
|
name (str): Name of specific json file to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatasetDict: Dataset with test and train splits
|
||||||
|
"""
|
||||||
|
path = get_data_path(path)
|
||||||
|
dataset = DatasetDict()
|
||||||
|
raw_data = []
|
||||||
|
|
||||||
|
if environ.get('DATASET_SOURCE') == 'ModelScope':
|
||||||
|
from modelscope import MsDataset
|
||||||
|
|
||||||
|
ms_dataset = MsDataset.load(path, split='train')
|
||||||
|
for item in ms_dataset:
|
||||||
|
raw_data.append({
|
||||||
|
'problem':
|
||||||
|
item['question'],
|
||||||
|
'solution':
|
||||||
|
item['final_answer'][0],
|
||||||
|
'language':
|
||||||
|
item['language'],
|
||||||
|
'subject':
|
||||||
|
item['subject'],
|
||||||
|
'question_type':
|
||||||
|
item['question_type'],
|
||||||
|
'answer_type':
|
||||||
|
item['answer_type'],
|
||||||
|
'is_multiple_answer':
|
||||||
|
item['is_multiple_answer'],
|
||||||
|
'unit':
|
||||||
|
item['unit'],
|
||||||
|
'error':
|
||||||
|
item['error'],
|
||||||
|
'questions':
|
||||||
|
item, # may not be used
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Construct file path using name parameter
|
||||||
|
if name is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Must specify 'name' parameter to load specific json file")
|
||||||
|
|
||||||
|
# file_path = os.path.join(path, name, f'{name}.json')
|
||||||
|
file_path = os.path.join(path, f'{name}.json')
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f'File not found: {file_path}')
|
||||||
|
|
||||||
|
# Load the specified json file
|
||||||
|
data = json.load(open(file_path, encoding='utf-8'))
|
||||||
|
for item in data:
|
||||||
|
raw_data.append({
|
||||||
|
'problem':
|
||||||
|
item['question'],
|
||||||
|
'solution':
|
||||||
|
item['final_answer'][0],
|
||||||
|
'language':
|
||||||
|
item['language'],
|
||||||
|
'subject':
|
||||||
|
item['subject'],
|
||||||
|
'question_type':
|
||||||
|
item['question_type'],
|
||||||
|
'answer_type':
|
||||||
|
item['answer_type'],
|
||||||
|
'is_multiple_answer':
|
||||||
|
item['is_multiple_answer'],
|
||||||
|
'unit':
|
||||||
|
item['unit'],
|
||||||
|
'error':
|
||||||
|
item['error'],
|
||||||
|
'questions':
|
||||||
|
item, # may not be used
|
||||||
|
})
|
||||||
|
|
||||||
|
dataset['test'] = Dataset.from_list(raw_data)
|
||||||
|
dataset['train'] = Dataset.from_list(raw_data)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Construct Prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_single_answer_type_text(answer_type, is_chinese):
|
||||||
|
if '-' in answer_type: # No need now
|
||||||
|
answer_type = answer_type[:answer_type.find('-')]
|
||||||
|
chinese_answer_type_dict = {
|
||||||
|
'Numerical': '数值',
|
||||||
|
'Expression': '表达式',
|
||||||
|
'Equation': '方程',
|
||||||
|
'Interval': '区间',
|
||||||
|
}
|
||||||
|
english_answer_type_dict = {
|
||||||
|
'Numerical': 'a numerical value',
|
||||||
|
'Expression': 'an expression',
|
||||||
|
'Equation': 'an equation',
|
||||||
|
'Interval': 'an interval',
|
||||||
|
}
|
||||||
|
|
||||||
|
for t in ['Numerical', 'Expression', 'Equation', 'Interval']:
|
||||||
|
if t in answer_type:
|
||||||
|
if is_chinese:
|
||||||
|
return chinese_answer_type_dict[t]
|
||||||
|
else:
|
||||||
|
return english_answer_type_dict[t]
|
||||||
|
raise ValueError(f'Error parsing answer type {answer_type}!')
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
|
||||||
|
if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type):
|
||||||
|
return ''
|
||||||
|
if not multiple_answer:
|
||||||
|
answer_text = get_single_answer_type_text(answer_type, is_chinese)
|
||||||
|
if is_chinese:
|
||||||
|
return f',答案类型为{answer_text}'
|
||||||
|
else:
|
||||||
|
return (f'The answer of The problem should be '
|
||||||
|
f'{answer_text}. ')
|
||||||
|
# Multiple answers case
|
||||||
|
if ',' not in answer_type: # Same answer type for all answers
|
||||||
|
answer_text = get_single_answer_type_text(answer_type, is_chinese)
|
||||||
|
if is_chinese:
|
||||||
|
return f',题目有多个答案,答案类型均为{answer_text}'
|
||||||
|
else:
|
||||||
|
return (f'The problem has multiple answers, each of them '
|
||||||
|
f'should be {answer_text}. ')
|
||||||
|
# Different answer types
|
||||||
|
answer_types = answer_type.split(',')
|
||||||
|
answer_types = [
|
||||||
|
get_single_answer_type_text(t, is_chinese) for t in answer_types
|
||||||
|
]
|
||||||
|
if len(set(answer_types)) == 1:
|
||||||
|
answer_text = answer_types[0]
|
||||||
|
if is_chinese:
|
||||||
|
return f',题目有多个答案,答案类型均为{answer_text}'
|
||||||
|
else:
|
||||||
|
return (f'The problem has multiple answers, each of them '
|
||||||
|
f'should be {answer_text}. ')
|
||||||
|
else:
|
||||||
|
if is_chinese:
|
||||||
|
answer_text = '、'.join(answer_types)
|
||||||
|
return f',题目有多个答案,答案类型分别为{answer_text}'
|
||||||
|
else:
|
||||||
|
answer_text = ', '.join(answer_types)
|
||||||
|
return (f'The problem has multiple answers, '
|
||||||
|
f'with the answers in order being {answer_text}. ')
|
||||||
|
|
||||||
|
|
||||||
|
class OlympiadBenchPrompter:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def make_prompt(
|
||||||
|
self,
|
||||||
|
language,
|
||||||
|
subject,
|
||||||
|
question_type,
|
||||||
|
answer_type,
|
||||||
|
is_multiple_answer,
|
||||||
|
unit,
|
||||||
|
):
|
||||||
|
self.is_chinese = language == 'Chinese'
|
||||||
|
self.is_math = subject == 'Math'
|
||||||
|
self.is_theorem_proving = question_type == 'Theorem proof'
|
||||||
|
"""Generate prompt based on question properties."""
|
||||||
|
if self.is_chinese:
|
||||||
|
subject_content = '数学' if self.is_math else '物理'
|
||||||
|
if self.is_theorem_proving:
|
||||||
|
prompt = (f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,'
|
||||||
|
f'运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。')
|
||||||
|
else:
|
||||||
|
answer_type_text = get_answer_type_text(
|
||||||
|
answer_type,
|
||||||
|
is_chinese=True,
|
||||||
|
multiple_answer=is_multiple_answer,
|
||||||
|
)
|
||||||
|
if is_multiple_answer:
|
||||||
|
multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}'
|
||||||
|
else:
|
||||||
|
multiple_answer_text = '\\boxed{答案}'
|
||||||
|
unit_text = ''
|
||||||
|
if unit:
|
||||||
|
multiple_answer_text += '(单位)'
|
||||||
|
unit_text = ',注意答案的单位不要放在\\boxed{}中'
|
||||||
|
prompt = (f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}。'
|
||||||
|
f'请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的'
|
||||||
|
f'变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是'
|
||||||
|
f'{multiple_answer_text}。"显式给出结果{unit_text}。')
|
||||||
|
else:
|
||||||
|
subject_content = 'Math' if self.is_math else 'Physics'
|
||||||
|
if self.is_theorem_proving:
|
||||||
|
prompt = (
|
||||||
|
f'The following is a theorem proving problem from an '
|
||||||
|
f'International {subject_content} competition. Please use '
|
||||||
|
f'logical reasoning and common theorems to prove the '
|
||||||
|
f'proposition in the problem according to the given '
|
||||||
|
f'requirements. Please use LaTeX format to represent the '
|
||||||
|
f'variables and formulas used in the proof.')
|
||||||
|
else:
|
||||||
|
if is_multiple_answer:
|
||||||
|
multiple_answer_text = (
|
||||||
|
'\\boxed{multiple answers connected with commas}')
|
||||||
|
else:
|
||||||
|
multiple_answer_text = '\\boxed{answer}'
|
||||||
|
unit_text = ''
|
||||||
|
if unit:
|
||||||
|
multiple_answer_text += '(unit)'
|
||||||
|
unit_text = (', note that the unit of the answer should '
|
||||||
|
'not be included in \\boxed{}')
|
||||||
|
answer_type_text = get_answer_type_text(
|
||||||
|
answer_type,
|
||||||
|
is_chinese=False,
|
||||||
|
multiple_answer=is_multiple_answer,
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
f'The following is an open-ended problem from an '
|
||||||
|
f'International {subject_content} competition. '
|
||||||
|
f'{answer_type_text}Please calculate the answer according '
|
||||||
|
f'to the given requirements and the information provided. '
|
||||||
|
f'Please use LaTeX format to represent the variables and '
|
||||||
|
f'formulas used in the solution process and results. '
|
||||||
|
f'Please end your solution with "So the final answer is '
|
||||||
|
f'{multiple_answer_text}." and give the result explicitly'
|
||||||
|
f'{unit_text}.')
|
||||||
|
# Add problem statement to the prompt
|
||||||
|
prompt = prompt + '\n' + '{problem}' + '\n'
|
||||||
|
# Add step-by-step reasoning instruction
|
||||||
|
if self.is_chinese:
|
||||||
|
prompt += '\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。'
|
||||||
|
else:
|
||||||
|
prompt += ('\nPlease reason step by step, and put your final '
|
||||||
|
'answer within \\boxed{}.')
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
|
||||||
|
|
||||||
|
class MathJudger:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.special_signal_map = {
|
||||||
|
'\\left': '',
|
||||||
|
'\\right': '',
|
||||||
|
'∶': ':',
|
||||||
|
',': ',',
|
||||||
|
'$': '',
|
||||||
|
'\\approx': '=',
|
||||||
|
'\\simeq': '=',
|
||||||
|
'\\sim': '=',
|
||||||
|
'^\\prime': "'",
|
||||||
|
'^{\\prime}': "'",
|
||||||
|
'^\\circ': '',
|
||||||
|
'%': '',
|
||||||
|
}
|
||||||
|
self.pi = parse_latex('\\pi')
|
||||||
|
self.precision = 1e-8
|
||||||
|
|
||||||
|
def split_by_comma(self, expr: str):
|
||||||
|
in_bracket_num = 0
|
||||||
|
splitted_expr = []
|
||||||
|
start_idx = 0
|
||||||
|
for i, char in enumerate(expr):
|
||||||
|
if char == '(' or char == '[':
|
||||||
|
in_bracket_num += 1
|
||||||
|
elif char == ')' or char == ']':
|
||||||
|
in_bracket_num -= 1
|
||||||
|
elif char == ',' and in_bracket_num == 0:
|
||||||
|
splitted_expr.append(expr[start_idx:i].strip())
|
||||||
|
start_idx = i + 1
|
||||||
|
|
||||||
|
if start_idx < len(expr):
|
||||||
|
splitted_expr.append(expr[start_idx:].strip())
|
||||||
|
|
||||||
|
return splitted_expr
|
||||||
|
|
||||||
|
def trans_plus_minus_sign(self, expr_list: list):
|
||||||
|
new_expr_list = []
|
||||||
|
for expr in expr_list:
|
||||||
|
if '\\pm' in expr:
|
||||||
|
new_expr_list.append(expr.replace('\\pm', '+'))
|
||||||
|
new_expr_list.append(expr.replace('\\pm', '-'))
|
||||||
|
else:
|
||||||
|
new_expr_list.append(expr)
|
||||||
|
|
||||||
|
return new_expr_list
|
||||||
|
|
||||||
|
def judge(self, expression1, expression2, precision=1e-8):
|
||||||
|
# (默认 expression1 为 Ground_Truth)
|
||||||
|
precision = precision if type(precision) == list else [precision]
|
||||||
|
|
||||||
|
try:
|
||||||
|
expression1, expression2 = self.preprocess(expression1,
|
||||||
|
expression2)
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
return False
|
||||||
|
if expression1 == expression2:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 去除字符串中的中文字符
|
||||||
|
expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
|
||||||
|
expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
|
||||||
|
|
||||||
|
expression1 = self.split_by_comma(expression1)
|
||||||
|
expression2 = self.split_by_comma(expression2)
|
||||||
|
|
||||||
|
temp_list1 = self.trans_plus_minus_sign(expression1)
|
||||||
|
temp_list2 = self.trans_plus_minus_sign(expression2)
|
||||||
|
|
||||||
|
# 设计误差值列表
|
||||||
|
if len(precision) <= 1:
|
||||||
|
precision = precision * len(temp_list1)
|
||||||
|
|
||||||
|
if len(temp_list1) != len(temp_list2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 判断两个列表中的元素是否可以两两配对,并且两两相等
|
||||||
|
idx = -1
|
||||||
|
while len(temp_list1) != 0:
|
||||||
|
idx = (idx + 1) % len(temp_list1)
|
||||||
|
|
||||||
|
item1 = temp_list1[idx]
|
||||||
|
self.precision = precision[idx]
|
||||||
|
|
||||||
|
for item2 in temp_list2:
|
||||||
|
if self.is_equal(item1, item2):
|
||||||
|
temp_list1.remove(item1)
|
||||||
|
temp_list2.remove(item2)
|
||||||
|
precision.remove(self.precision)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 如果所有元素都匹配并移除,列表可以配对
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_interval(self, epr):
|
||||||
|
return epr.startswith(('(', '[')) and epr.endswith((')', ']'))
|
||||||
|
|
||||||
|
def sympy_sub_pi(self, expression_sympy):
|
||||||
|
return expression_sympy.subs(self.pi, math.pi)
|
||||||
|
|
||||||
|
def is_equal(self, expression1, expression2):
|
||||||
|
if (expression1 == expression2 and expression1 != ''
|
||||||
|
and expression2 != ''):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 先判断是否是两个区间
|
||||||
|
if self.is_interval(expression1) and self.is_interval(expression2):
|
||||||
|
try:
|
||||||
|
if self.interval_equal(expression1, expression2):
|
||||||
|
return True
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 再判断是否在数值上相等
|
||||||
|
try:
|
||||||
|
if self.numerical_equal(expression1, expression2):
|
||||||
|
return True
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 再判断是否是表达式相等
|
||||||
|
try:
|
||||||
|
if self.expression_equal(
|
||||||
|
expression1, expression2) and not ('=' in expression1
|
||||||
|
and '=' in expression2):
|
||||||
|
return True
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 再判断是否是等式相等
|
||||||
|
try:
|
||||||
|
if self.equation_equal(expression1, expression2):
|
||||||
|
return True
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def numerical_equal(
|
||||||
|
self,
|
||||||
|
expression1: str,
|
||||||
|
expression2: str,
|
||||||
|
include_percentage: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
(默认 expression1 为 Ground_Truth)
|
||||||
|
函数: 判读两个数值是否在误差允许范围内相等
|
||||||
|
步骤1: 将可能出现的百分号的情况包含进来
|
||||||
|
步骤2: 使用 math.isclose 函数判断是否相等
|
||||||
|
"""
|
||||||
|
reference = float(expression1)
|
||||||
|
prediction = float(expression2)
|
||||||
|
|
||||||
|
if include_percentage:
|
||||||
|
gt_result = [reference / 100, reference, reference * 100]
|
||||||
|
else:
|
||||||
|
gt_result = [reference]
|
||||||
|
|
||||||
|
for item in gt_result:
|
||||||
|
if abs(item - prediction) <= self.precision * 1.01:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def expression_equal(self, exp1, exp2):
|
||||||
|
"""
|
||||||
|
(默认 expression1 为 Ground_Truth)
|
||||||
|
函数: 判断两个表达式是否在数学意义上等价
|
||||||
|
步骤1: 提取表达式, 防止有的模型会给出"x=1"而不是"1"
|
||||||
|
步骤2: 使用 sympy 库进行等价判断
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 只提取等号右边的表达式
|
||||||
|
def extract_expression(expression):
|
||||||
|
if '=' in expression:
|
||||||
|
expression = expression.split('=')[1]
|
||||||
|
return expression.strip()
|
||||||
|
|
||||||
|
exp1 = extract_expression(exp1)
|
||||||
|
exp2 = extract_expression(exp2)
|
||||||
|
|
||||||
|
# 将表达式转换为 sympy 中能够进行处理的格式
|
||||||
|
expr1_sym = sympify(parse_latex(exp1))
|
||||||
|
expr2_sym = sympify(parse_latex(exp2))
|
||||||
|
|
||||||
|
if expr1_sym == expr2_sym:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
expr1_sym = self.sympy_sub_pi(expr1_sym)
|
||||||
|
expr2_sym = self.sympy_sub_pi(expr2_sym)
|
||||||
|
|
||||||
|
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (
|
||||||
|
not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
|
||||||
|
return False
|
||||||
|
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
|
||||||
|
try:
|
||||||
|
if not (self.can_compute_power(expr1_sym)
|
||||||
|
and self.can_compute_power(expr2_sym)):
|
||||||
|
print(f'These two number can not be calculated by '
|
||||||
|
f'current computer for: '
|
||||||
|
f'"{str(expr1_sym)}" and "{str(expr2_sym)}"')
|
||||||
|
return False
|
||||||
|
|
||||||
|
if (abs(expr1_sym.evalf() - expr2_sym.evalf()) <=
|
||||||
|
self.precision * 1.01):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
simplified_expr = simplify(expr1_sym - expr2_sym)
|
||||||
|
|
||||||
|
num_value = simplified_expr.evalf()
|
||||||
|
|
||||||
|
return abs(num_value) < 1e-3
|
||||||
|
except Exception: # 处理具体异常
|
||||||
|
return False
|
||||||
|
|
||||||
|
def equation_equal(self, expression1, expression2):
|
||||||
|
"""
|
||||||
|
(expression1 is assumed to be Ground_Truth)
|
||||||
|
Function: Check if two equations are mathematically equivalent
|
||||||
|
Step 1: Simplify equations to standard form with right side equal to 0
|
||||||
|
Step 2: Use sympy library to calculate quotient of left sides,
|
||||||
|
if quotient or its reciprocal is integer, equations are equivalent
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Convert equations to sympy format with right side moved to left side
|
||||||
|
def simplify_equation(latex_eq):
|
||||||
|
# Split left and right sides of equation
|
||||||
|
lhs, rhs = latex_eq.split('=')
|
||||||
|
|
||||||
|
# Parse LaTeX expressions using parse_latex
|
||||||
|
lhs_expr = parse_latex(lhs)
|
||||||
|
rhs_expr = parse_latex(rhs)
|
||||||
|
|
||||||
|
# Create equation object
|
||||||
|
equation = Eq(lhs_expr, rhs_expr)
|
||||||
|
|
||||||
|
# Simplify equation by moving right side to left
|
||||||
|
simplified_eq = simplify(equation.lhs - equation.rhs)
|
||||||
|
|
||||||
|
return simplified_eq
|
||||||
|
|
||||||
|
expr1_sym = simplify_equation(expression1)
|
||||||
|
expr2_sym = simplify_equation(expression2)
|
||||||
|
|
||||||
|
division_result_1 = simplify(expr1_sym / expr2_sym)
|
||||||
|
division_result_2 = simplify(expr2_sym / expr1_sym)
|
||||||
|
|
||||||
|
# If division result or its reciprocal is
|
||||||
|
# non-zero integer, equations are equivalent
|
||||||
|
if (division_result_1.is_Integer
|
||||||
|
and division_result_1 != 0) or (division_result_2.is_Integer
|
||||||
|
and division_result_2 != 0):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def interval_equal(self, expression1, expression2):
|
||||||
|
"""
|
||||||
|
Function: Check if two intervals are mathematically equivalent
|
||||||
|
Step 1: Simplify interval expressions,
|
||||||
|
remove irrelevant symbols
|
||||||
|
like "\\left", "\\right", and "x \\in"
|
||||||
|
Step 2: Compare brackets and mathematical expressions in between
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compare_two_interval(inter1, inter2):
|
||||||
|
# First compare brackets on both sides
|
||||||
|
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
inter1 = inter1.strip('[]()')
|
||||||
|
inter2 = inter2.strip('[]()')
|
||||||
|
|
||||||
|
# Split interval into left and right parts
|
||||||
|
items_1 = inter1.split(',')
|
||||||
|
items_2 = inter2.split(',')
|
||||||
|
|
||||||
|
for item_1, item_2 in zip(items_1, items_2):
|
||||||
|
if not self.expression_equal(item_1, item_2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
interval1 = expression1
|
||||||
|
interval2 = expression2
|
||||||
|
|
||||||
|
if interval1 == interval2:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
inter_list1 = interval1.split('\\cup')
|
||||||
|
inter_list2 = interval2.split('\\cup')
|
||||||
|
|
||||||
|
if len(inter_list1) != len(inter_list2):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
for inter1, inter2 in zip(inter_list1, inter_list2):
|
||||||
|
if not compare_two_interval(inter1, inter2):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def preprocess(self, expression1, expression2):
|
||||||
|
"""Extract and preprocess expressions from model output."""
|
||||||
|
|
||||||
|
def extract_boxed_content(latex_str):
|
||||||
|
# Find all \boxed{...} structures
|
||||||
|
boxed_matches = re.finditer(r'\\boxed{', latex_str)
|
||||||
|
results = ''
|
||||||
|
|
||||||
|
for match in boxed_matches:
|
||||||
|
start_index = match.end()
|
||||||
|
end_index = start_index
|
||||||
|
stack = 1
|
||||||
|
|
||||||
|
# Search from after \boxed{ until
|
||||||
|
# finding matching closing brace
|
||||||
|
while stack > 0 and end_index < len(latex_str):
|
||||||
|
if latex_str[end_index] == '{':
|
||||||
|
stack += 1
|
||||||
|
elif latex_str[end_index] == '}':
|
||||||
|
stack -= 1
|
||||||
|
end_index += 1
|
||||||
|
|
||||||
|
if stack == 0:
|
||||||
|
# Extract content inside \boxed{}
|
||||||
|
content = latex_str[start_index:end_index - 1]
|
||||||
|
results += content + ','
|
||||||
|
else:
|
||||||
|
raise ValueError('Mismatched braces in LaTeX string.')
|
||||||
|
|
||||||
|
# If no \boxed{} found, extract formulas from last line
|
||||||
|
if results == '':
|
||||||
|
last_line_ans = latex_str.strip().split('\n')[-1]
|
||||||
|
dollar_pattern = r'\$(.*?)\$'
|
||||||
|
answers = re.findall(dollar_pattern, last_line_ans)
|
||||||
|
|
||||||
|
if answers:
|
||||||
|
for ans in answers:
|
||||||
|
results += ans + ','
|
||||||
|
else:
|
||||||
|
results = latex_str
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def special_symbol_replace(expression):
|
||||||
|
if '\\in ' in expression:
|
||||||
|
expression = expression.split('\\in ')[1]
|
||||||
|
|
||||||
|
# Replace special characters that
|
||||||
|
# don't affect LaTeX parsing (decorative)
|
||||||
|
for signal in self.special_signal_map:
|
||||||
|
expression = expression.replace(
|
||||||
|
signal, self.special_signal_map[signal])
|
||||||
|
|
||||||
|
expression = expression.strip('\n$,.:;^_=+`!@#$%^&*~,。')
|
||||||
|
|
||||||
|
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
|
||||||
|
expression = re.sub(pattern, r'\1', expression)
|
||||||
|
|
||||||
|
return expression
|
||||||
|
|
||||||
|
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(
|
||||||
|
expression2)
|
||||||
|
exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
|
||||||
|
|
||||||
|
return exp1, exp2
|
||||||
|
|
||||||
|
def can_compute_power(self, expr):
|
||||||
|
"""Check if the power expression can be computed.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
expr (sympy expression): The expression to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the expression can be computed, False otherwise.
|
||||||
|
"""
|
||||||
|
# Check if the expression is a power expression
|
||||||
|
if isinstance(expr, Pow):
|
||||||
|
# Extract the base and the exponent
|
||||||
|
base, exp = expr.as_base_exp()
|
||||||
|
|
||||||
|
# Check if the base and the exponent are numbers
|
||||||
|
if base.is_number and exp.is_number:
|
||||||
|
# Set a threshold for the maximum size of the exponent
|
||||||
|
# can be adjusted based on the computing environment
|
||||||
|
MAX_EXP = 1000
|
||||||
|
|
||||||
|
# Check if the exponent is greater than the threshold
|
||||||
|
if abs(exp.evalf()) > MAX_EXP:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# If the base or the exponent is not a number,
|
||||||
|
# we cannot compute the power
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# If the expression is not a power expression,
|
||||||
|
# return True as it is not the case we are checking for
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@TEXT_POSTPROCESSORS.register_module('olympiadbench_postprocess_v2')
|
||||||
|
def olympiadbench_postprocess_v2(text: str,
|
||||||
|
is_chinese: bool = False,
|
||||||
|
is_deepseek: bool = False) -> str:
|
||||||
|
"""Extract answer from model output."""
|
||||||
|
# deepseekmath has special answering format
|
||||||
|
if is_deepseek:
|
||||||
|
if is_chinese:
|
||||||
|
matches = re.findall('## 解题答案(.*)', text)
|
||||||
|
else:
|
||||||
|
matches = re.findall('The answer is: (.*)', text)
|
||||||
|
else:
|
||||||
|
if is_chinese:
|
||||||
|
matches = re.findall('所以最终答案是(.*)', text)
|
||||||
|
else:
|
||||||
|
matches = re.findall('So the final answer is (.*)', text)
|
||||||
|
|
||||||
|
# If found matches, take the last one, otherwise return the whole text
|
||||||
|
if matches:
|
||||||
|
return matches[-1].strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class OlympiadBenchEvaluator(BaseEvaluator):
|
||||||
|
"""Evaluator for OlympiadBench dataset."""
|
||||||
|
|
||||||
|
def __init__(self, version='v1'):
|
||||||
|
assert version in ['v1', 'v2']
|
||||||
|
self.version = version
|
||||||
|
self.judger = MathJudger()
|
||||||
|
|
||||||
|
def score(self, predictions, references): # Remove questions parameter
|
||||||
|
"""Calculate accuracy score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictions (list): List of model predictions
|
||||||
|
references (list): List of ground truth answers
|
||||||
|
"""
|
||||||
|
if len(predictions) != len(references):
|
||||||
|
return {
|
||||||
|
'error': 'predictions and references have different length'
|
||||||
|
}
|
||||||
|
|
||||||
|
correct = 0
|
||||||
|
count = 0
|
||||||
|
details = []
|
||||||
|
|
||||||
|
for pred, ref in zip(predictions, references):
|
||||||
|
detail = {'pred': pred, 'answer': ref, 'correct': False}
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# Get precision/error threshold from reference if available
|
||||||
|
precision = 1e-8
|
||||||
|
if isinstance(ref, dict) and 'error' in ref:
|
||||||
|
if ',' in ref['error']:
|
||||||
|
# Multiple precisions for multiple answers
|
||||||
|
precisions = ref['error'].split(',')
|
||||||
|
precisions = [float(p) if p else 1e-8 for p in precisions]
|
||||||
|
precision = precisions
|
||||||
|
else:
|
||||||
|
precision = float(ref['error'])
|
||||||
|
|
||||||
|
# Check if answer is correct
|
||||||
|
try:
|
||||||
|
if (isinstance(ref, dict) and 'answer_type' in ref
|
||||||
|
and 'Tuple' in ref['answer_type']):
|
||||||
|
# Special handling for tuple type answers
|
||||||
|
is_correct = self.judger.judge(pred,
|
||||||
|
ref['final_answer'][0],
|
||||||
|
precision)
|
||||||
|
else:
|
||||||
|
is_correct = self.judger.judge(pred, ref, precision)
|
||||||
|
|
||||||
|
if is_correct:
|
||||||
|
correct += 1
|
||||||
|
detail['correct'] = True
|
||||||
|
except Exception as e: # 处理具体异常
|
||||||
|
detail['error'] = str(e)
|
||||||
|
|
||||||
|
details.append(detail)
|
||||||
|
|
||||||
|
result = {'accuracy': 100 * correct / count, 'details': details}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ICL_PROMPT_TEMPLATES.register_module()
|
||||||
|
class OlympiadBenchTemplate(PromptTemplate):
|
||||||
|
"""Template for OlympiadBench dataset."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Define basic template structure
|
||||||
|
template = dict(round=[dict(role='HUMAN', prompt='{prompt}')])
|
||||||
|
super().__init__(template=template)
|
||||||
|
self.prompter = OlympiadBenchPrompter()
|
||||||
|
|
||||||
|
def generate_item(self, entry: Dict, *args, **kwargs) -> str:
|
||||||
|
"""Generate prompt for a single item."""
|
||||||
|
problem = entry.get('problem', '')
|
||||||
|
language = entry.get('language', 'English')
|
||||||
|
subject = entry.get('subject', 'Math')
|
||||||
|
question_type = entry.get('question_type', '')
|
||||||
|
answer_type = entry.get('answer_type', '')
|
||||||
|
is_multiple_answer = entry.get('is_multiple_answer', False)
|
||||||
|
unit = entry.get('unit', '')
|
||||||
|
|
||||||
|
prompt = self.prompter.make_prompt(
|
||||||
|
language=language,
|
||||||
|
subject=subject,
|
||||||
|
question_type=question_type,
|
||||||
|
answer_type=answer_type,
|
||||||
|
is_multiple_answer=is_multiple_answer,
|
||||||
|
unit=unit,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_entry = {'prompt': prompt, 'problem': problem}
|
||||||
|
|
||||||
|
return super().generate_item(new_entry, *args, **kwargs)
|
@ -103,6 +103,7 @@ from .natural_question import * # noqa: F401, F403
|
|||||||
from .natural_question_cn import * # noqa: F401, F403
|
from .natural_question_cn import * # noqa: F401, F403
|
||||||
from .NPHardEval import * # noqa: F401, F403
|
from .NPHardEval import * # noqa: F401, F403
|
||||||
from .obqa import * # noqa: F401, F403
|
from .obqa import * # noqa: F401, F403
|
||||||
|
from .OlympiadBench import * # noqa: F401, F403
|
||||||
from .OpenFinData import * # noqa: F401, F403
|
from .OpenFinData import * # noqa: F401, F403
|
||||||
from .piqa import * # noqa: F401, F403
|
from .piqa import * # noqa: F401, F403
|
||||||
from .py150 import * # noqa: F401, F403
|
from .py150 import * # noqa: F401, F403
|
||||||
|
@ -398,9 +398,18 @@ DATASETS_MAPPING = {
|
|||||||
"hf_id": "THUDM/LongBench-v2",
|
"hf_id": "THUDM/LongBench-v2",
|
||||||
"local": "./data/longbenchv2/data.json",
|
"local": "./data/longbenchv2/data.json",
|
||||||
},
|
},
|
||||||
|
"opencompass/OlympiadBench": {
|
||||||
|
"ms_id": "",
|
||||||
|
"hf_id": "",
|
||||||
|
"local": "./data/OlympiadBench",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
DATASETS_URL = {
|
DATASETS_URL = {
|
||||||
|
"/OlympiadBench": {
|
||||||
|
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/OlympiadBench.zip",
|
||||||
|
"md5": "97e8b1ae7f6170d94817288a8930ef00",
|
||||||
|
},
|
||||||
"/longbenchv2":{
|
"/longbenchv2":{
|
||||||
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/longbenchv2.zip",
|
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/longbenchv2.zip",
|
||||||
"md5": "09b7e06e6f98c5cca8ad597b3d7b42f0",
|
"md5": "09b7e06e6f98c5cca8ad597b3d7b42f0",
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Alpaca-eval
|
# Alpaca-eval
|
||||||
alpaca-eval==0.6
|
alpaca-eval==0.6
|
||||||
|
# OlympiadBench
|
||||||
|
antlr4-python3-runtime==4.11
|
||||||
cn2an
|
cn2an
|
||||||
# Dingo
|
# Dingo
|
||||||
dingo-python==1.1.2
|
dingo-python==1.1.2
|
||||||
|
Loading…
Reference in New Issue
Block a user