[Feature] Support OlympiadBench Benchmark (#1841)

* Support OlympiadBench Benchmark

* Support OlympiadBench Benchmark

* Support OlympiadBench Benchmark

* update dataset path

* Update olmpiadBench

* Update olmpiadBench

* Update olmpiadBench

---------

Co-authored-by: liushz <qq1791167085@163.com>
This commit is contained in:
Shudong Liu 2025-01-24 10:00:01 +08:00 committed by GitHub
parent 70f2c963d3
commit 412199f802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 932 additions and 0 deletions

View File

@ -0,0 +1,36 @@
from mmengine.config import read_base
with read_base():
from opencompass.configs.datasets.OlympiadBench.OlympiadBench_0shot_gen_be8b13 import olympiadbench_datasets
from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import models as lmdeploy_qwen2_5_7b_instruct_model
from opencompass.configs.summarizers.OlympiadBench import summarizer
datasets = sum([v for k, v in locals().items() if k.endswith('_datasets') or k == 'datasets'], [])
models = sum([v for k, v in locals().items() if k.endswith('_model')], [])
from opencompass.runners import LocalRunner
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
infer = dict(
partitioner=dict(type=NumWorkerPartitioner, num_worker=8),
runner=dict(
type=LocalRunner,
max_num_workers=8,
task=dict(type=OpenICLInferTask)
),
)
eval = dict(
partitioner=dict(type=NaivePartitioner, n=10),
runner=dict(
type=LocalRunner,
max_num_workers=256,
task=dict(type=OpenICLEvalTask)
),
)
work_dir = 'outputs/debug/OlympiadBench'

View File

@ -0,0 +1,50 @@
from mmengine.config import read_base
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import OlympiadBenchDataset, OlympiadBenchEvaluator, olympiadbench_postprocess_v2
with read_base():
from .OlympiadBench_categories import categories
# Create prompter instance for problems
olympiadbench_prompter_cfg = dict(
type='OlympiadBenchPrompter'
)
olympiadbench_reader_cfg = dict(
input_columns=[
'problem', 'language', 'subject', 'question_type',
'answer_type', 'is_multiple_answer', 'unit', 'questions'
],
output_column='solution'
)
olympiadbench_datasets = []
for _name in categories:
olympiadbench_infer_cfg = dict(
prompt_template=dict(
type='OlympiadBenchTemplate'
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
olympiadbench_eval_cfg = dict(
evaluator=dict(type=OlympiadBenchEvaluator, version='v2'),
pred_postprocessor=dict(type=olympiadbench_postprocess_v2),
)
olympiadbench_datasets.append(
dict(
type=OlympiadBenchDataset,
abbr=f'OlympiadBench_{_name}',
path='opencompass/OlympiadBench',
name=_name,
reader_cfg=olympiadbench_reader_cfg,
infer_cfg=olympiadbench_infer_cfg,
eval_cfg=olympiadbench_eval_cfg,
)
)
del _name

View File

@ -0,0 +1,7 @@
categories = [
'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP
'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP
'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE
'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP
'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE
]

View File

@ -0,0 +1,15 @@
from mmengine.config import read_base
with read_base():
from .groups.OlympiadBench import OlympiadBench_summary_groups
summarizer = dict(
dataset_abbrs=[
'OlympiadBench_OE_TO_maths_en_COMP',
'OlympiadBench_OE_TO_maths_zh_COMP',
'OlympiadBench_OE_TO_maths_zh_CEE',
'OlympiadBench_OE_TO_physics_en_COMP',
'OlympiadBench_OE_TO_physics_zh_CEE'
],
summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
)

View File

@ -0,0 +1,11 @@
categories = [
'OE_TO_maths_en_COMP', # OpenEnded - TextOnly - maths - COMP
'OE_TO_maths_zh_COMP', # OpenEnded - TextOnly - maths - COMP
'OE_TO_maths_zh_CEE', # OpenEnded - TextOnly - maths - CEE
'OE_TO_physics_en_COMP', # OpenEnded - TextOnly - physics - COMP
'OE_TO_physics_zh_CEE' # OpenEnded - TextOnly - physics - CEE
]
OlympiadBench_summary_groups = [
{'name': 'OlympiadBench', 'subsets': ['OlympiadBench_' + c.replace(' ', '_') for c in categories]},
]

View File

@ -0,0 +1,801 @@
import json
import math
import os
import re
from os import environ
from typing import Dict
import sympy as sp
from datasets import Dataset, DatasetDict
from sympy import Eq, Pow, simplify, sympify
from sympy.parsing.latex import parse_latex
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.registry import (ICL_PROMPT_TEMPLATES, LOAD_DATASET,
TEXT_POSTPROCESSORS)
from opencompass.utils import get_data_path
from .base import BaseDataset
# Load Dataset
@LOAD_DATASET.register_module()
class OlympiadBenchDataset(BaseDataset):
"""Dataset for OlympiadBench.
Args:
path (str): Path to dataset directory
name (str): Name of specific json file to load
e.g. 'OE_TO_maths_en_COMP'
"""
@staticmethod
def load(path: str, name: str = None, **kwargs):
"""Load dataset.
Args:
path (str): Path to dataset directory
name (str): Name of specific json file to load
Returns:
DatasetDict: Dataset with test and train splits
"""
path = get_data_path(path)
dataset = DatasetDict()
raw_data = []
if environ.get('DATASET_SOURCE') == 'ModelScope':
from modelscope import MsDataset
ms_dataset = MsDataset.load(path, split='train')
for item in ms_dataset:
raw_data.append({
'problem':
item['question'],
'solution':
item['final_answer'][0],
'language':
item['language'],
'subject':
item['subject'],
'question_type':
item['question_type'],
'answer_type':
item['answer_type'],
'is_multiple_answer':
item['is_multiple_answer'],
'unit':
item['unit'],
'error':
item['error'],
'questions':
item, # may not be used
})
else:
# Construct file path using name parameter
if name is None:
raise ValueError(
"Must specify 'name' parameter to load specific json file")
# file_path = os.path.join(path, name, f'{name}.json')
file_path = os.path.join(path, f'{name}.json')
if not os.path.exists(file_path):
raise FileNotFoundError(f'File not found: {file_path}')
# Load the specified json file
data = json.load(open(file_path, encoding='utf-8'))
for item in data:
raw_data.append({
'problem':
item['question'],
'solution':
item['final_answer'][0],
'language':
item['language'],
'subject':
item['subject'],
'question_type':
item['question_type'],
'answer_type':
item['answer_type'],
'is_multiple_answer':
item['is_multiple_answer'],
'unit':
item['unit'],
'error':
item['error'],
'questions':
item, # may not be used
})
dataset['test'] = Dataset.from_list(raw_data)
dataset['train'] = Dataset.from_list(raw_data)
return dataset
# Construct Prompt
def get_single_answer_type_text(answer_type, is_chinese):
if '-' in answer_type: # No need now
answer_type = answer_type[:answer_type.find('-')]
chinese_answer_type_dict = {
'Numerical': '数值',
'Expression': '表达式',
'Equation': '方程',
'Interval': '区间',
}
english_answer_type_dict = {
'Numerical': 'a numerical value',
'Expression': 'an expression',
'Equation': 'an equation',
'Interval': 'an interval',
}
for t in ['Numerical', 'Expression', 'Equation', 'Interval']:
if t in answer_type:
if is_chinese:
return chinese_answer_type_dict[t]
else:
return english_answer_type_dict[t]
raise ValueError(f'Error parsing answer type {answer_type}!')
def get_answer_type_text(answer_type, is_chinese, multiple_answer):
if ('Need_human_evaluate' in answer_type) or ('Tuple' in answer_type):
return ''
if not multiple_answer:
answer_text = get_single_answer_type_text(answer_type, is_chinese)
if is_chinese:
return f',答案类型为{answer_text}'
else:
return (f'The answer of The problem should be '
f'{answer_text}. ')
# Multiple answers case
if ',' not in answer_type: # Same answer type for all answers
answer_text = get_single_answer_type_text(answer_type, is_chinese)
if is_chinese:
return f',题目有多个答案,答案类型均为{answer_text}'
else:
return (f'The problem has multiple answers, each of them '
f'should be {answer_text}. ')
# Different answer types
answer_types = answer_type.split(',')
answer_types = [
get_single_answer_type_text(t, is_chinese) for t in answer_types
]
if len(set(answer_types)) == 1:
answer_text = answer_types[0]
if is_chinese:
return f',题目有多个答案,答案类型均为{answer_text}'
else:
return (f'The problem has multiple answers, each of them '
f'should be {answer_text}. ')
else:
if is_chinese:
answer_text = ''.join(answer_types)
return f',题目有多个答案,答案类型分别为{answer_text}'
else:
answer_text = ', '.join(answer_types)
return (f'The problem has multiple answers, '
f'with the answers in order being {answer_text}. ')
class OlympiadBenchPrompter:
def __init__(self):
pass
def make_prompt(
self,
language,
subject,
question_type,
answer_type,
is_multiple_answer,
unit,
):
self.is_chinese = language == 'Chinese'
self.is_math = subject == 'Math'
self.is_theorem_proving = question_type == 'Theorem proof'
"""Generate prompt based on question properties."""
if self.is_chinese:
subject_content = '数学' if self.is_math else '物理'
if self.is_theorem_proving:
prompt = (f'以下是中国{subject_content}竞赛中的证明题。请根据题目的要求,'
f'运用逻辑推理及常用定理证明题目中的命题。证明过程中使用的变量和公式请使用LaTeX格式表示。')
else:
answer_type_text = get_answer_type_text(
answer_type,
is_chinese=True,
multiple_answer=is_multiple_answer,
)
if is_multiple_answer:
multiple_answer_text = '\\boxed{用英文逗号连接的多个答案}'
else:
multiple_answer_text = '\\boxed{答案}'
unit_text = ''
if unit:
multiple_answer_text += '(单位)'
unit_text = ',注意答案的单位不要放在\\boxed{}'
prompt = (f'以下是中国{subject_content}竞赛中的解答题{answer_type_text}'
f'请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的'
f'变量和公式请使用LaTeX格式表示。请在最后以"所以最终答案是'
f'{multiple_answer_text}"显式给出结果{unit_text}')
else:
subject_content = 'Math' if self.is_math else 'Physics'
if self.is_theorem_proving:
prompt = (
f'The following is a theorem proving problem from an '
f'International {subject_content} competition. Please use '
f'logical reasoning and common theorems to prove the '
f'proposition in the problem according to the given '
f'requirements. Please use LaTeX format to represent the '
f'variables and formulas used in the proof.')
else:
if is_multiple_answer:
multiple_answer_text = (
'\\boxed{multiple answers connected with commas}')
else:
multiple_answer_text = '\\boxed{answer}'
unit_text = ''
if unit:
multiple_answer_text += '(unit)'
unit_text = (', note that the unit of the answer should '
'not be included in \\boxed{}')
answer_type_text = get_answer_type_text(
answer_type,
is_chinese=False,
multiple_answer=is_multiple_answer,
)
prompt = (
f'The following is an open-ended problem from an '
f'International {subject_content} competition. '
f'{answer_type_text}Please calculate the answer according '
f'to the given requirements and the information provided. '
f'Please use LaTeX format to represent the variables and '
f'formulas used in the solution process and results. '
f'Please end your solution with "So the final answer is '
f'{multiple_answer_text}." and give the result explicitly'
f'{unit_text}.')
# Add problem statement to the prompt
prompt = prompt + '\n' + '{problem}' + '\n'
# Add step-by-step reasoning instruction
if self.is_chinese:
prompt += '\n请通过逐步推理来解答问题,并把最终答案放置于\\boxed{}中。'
else:
prompt += ('\nPlease reason step by step, and put your final '
'answer within \\boxed{}.')
return prompt
# Evaluate
class MathJudger:
def __init__(self):
self.special_signal_map = {
'\\left': '',
'\\right': '',
'': ':',
'': ',',
'$': '',
'\\approx': '=',
'\\simeq': '=',
'\\sim': '=',
'^\\prime': "'",
'^{\\prime}': "'",
'^\\circ': '',
'%': '',
}
self.pi = parse_latex('\\pi')
self.precision = 1e-8
def split_by_comma(self, expr: str):
in_bracket_num = 0
splitted_expr = []
start_idx = 0
for i, char in enumerate(expr):
if char == '(' or char == '[':
in_bracket_num += 1
elif char == ')' or char == ']':
in_bracket_num -= 1
elif char == ',' and in_bracket_num == 0:
splitted_expr.append(expr[start_idx:i].strip())
start_idx = i + 1
if start_idx < len(expr):
splitted_expr.append(expr[start_idx:].strip())
return splitted_expr
def trans_plus_minus_sign(self, expr_list: list):
new_expr_list = []
for expr in expr_list:
if '\\pm' in expr:
new_expr_list.append(expr.replace('\\pm', '+'))
new_expr_list.append(expr.replace('\\pm', '-'))
else:
new_expr_list.append(expr)
return new_expr_list
def judge(self, expression1, expression2, precision=1e-8):
# (默认 expression1 为 Ground_Truth)
precision = precision if type(precision) == list else [precision]
try:
expression1, expression2 = self.preprocess(expression1,
expression2)
except Exception: # 处理具体异常
return False
if expression1 == expression2:
return True
# 去除字符串中的中文字符
expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
expression1 = self.split_by_comma(expression1)
expression2 = self.split_by_comma(expression2)
temp_list1 = self.trans_plus_minus_sign(expression1)
temp_list2 = self.trans_plus_minus_sign(expression2)
# 设计误差值列表
if len(precision) <= 1:
precision = precision * len(temp_list1)
if len(temp_list1) != len(temp_list2):
return False
# 判断两个列表中的元素是否可以两两配对,并且两两相等
idx = -1
while len(temp_list1) != 0:
idx = (idx + 1) % len(temp_list1)
item1 = temp_list1[idx]
self.precision = precision[idx]
for item2 in temp_list2:
if self.is_equal(item1, item2):
temp_list1.remove(item1)
temp_list2.remove(item2)
precision.remove(self.precision)
break
else:
return False
# 如果所有元素都匹配并移除,列表可以配对
return True
def is_interval(self, epr):
return epr.startswith(('(', '[')) and epr.endswith((')', ']'))
def sympy_sub_pi(self, expression_sympy):
return expression_sympy.subs(self.pi, math.pi)
def is_equal(self, expression1, expression2):
if (expression1 == expression2 and expression1 != ''
and expression2 != ''):
return True
# 先判断是否是两个区间
if self.is_interval(expression1) and self.is_interval(expression2):
try:
if self.interval_equal(expression1, expression2):
return True
except Exception: # 处理具体异常
return False
# 再判断是否在数值上相等
try:
if self.numerical_equal(expression1, expression2):
return True
except Exception: # 处理具体异常
pass
# 再判断是否是表达式相等
try:
if self.expression_equal(
expression1, expression2) and not ('=' in expression1
and '=' in expression2):
return True
except Exception: # 处理具体异常
pass
# 再判断是否是等式相等
try:
if self.equation_equal(expression1, expression2):
return True
except Exception: # 处理具体异常
pass
return False
def numerical_equal(
self,
expression1: str,
expression2: str,
include_percentage: bool = True,
):
"""
(默认 expression1 Ground_Truth)
函数: 判读两个数值是否在误差允许范围内相等
步骤1: 将可能出现的百分号的情况包含进来
步骤2: 使用 math.isclose 函数判断是否相等
"""
reference = float(expression1)
prediction = float(expression2)
if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]
for item in gt_result:
if abs(item - prediction) <= self.precision * 1.01:
return True
return False
def expression_equal(self, exp1, exp2):
"""
(默认 expression1 Ground_Truth)
函数: 判断两个表达式是否在数学意义上等价
步骤1: 提取表达式, 防止有的模型会给出"x=1"而不是"1"
步骤2: 使用 sympy 库进行等价判断
"""
# 只提取等号右边的表达式
def extract_expression(expression):
if '=' in expression:
expression = expression.split('=')[1]
return expression.strip()
exp1 = extract_expression(exp1)
exp2 = extract_expression(exp2)
# 将表达式转换为 sympy 中能够进行处理的格式
expr1_sym = sympify(parse_latex(exp1))
expr2_sym = sympify(parse_latex(exp2))
if expr1_sym == expr2_sym:
return True
else:
expr1_sym = self.sympy_sub_pi(expr1_sym)
expr2_sym = self.sympy_sub_pi(expr2_sym)
if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (
not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
return False
elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
try:
if not (self.can_compute_power(expr1_sym)
and self.can_compute_power(expr2_sym)):
print(f'These two number can not be calculated by '
f'current computer for: '
f'"{str(expr1_sym)}" and "{str(expr2_sym)}"')
return False
if (abs(expr1_sym.evalf() - expr2_sym.evalf()) <=
self.precision * 1.01):
return True
else:
return False
except Exception: # 处理具体异常
return False
else:
try:
simplified_expr = simplify(expr1_sym - expr2_sym)
num_value = simplified_expr.evalf()
return abs(num_value) < 1e-3
except Exception: # 处理具体异常
return False
def equation_equal(self, expression1, expression2):
"""
(expression1 is assumed to be Ground_Truth)
Function: Check if two equations are mathematically equivalent
Step 1: Simplify equations to standard form with right side equal to 0
Step 2: Use sympy library to calculate quotient of left sides,
if quotient or its reciprocal is integer, equations are equivalent
"""
# Convert equations to sympy format with right side moved to left side
def simplify_equation(latex_eq):
# Split left and right sides of equation
lhs, rhs = latex_eq.split('=')
# Parse LaTeX expressions using parse_latex
lhs_expr = parse_latex(lhs)
rhs_expr = parse_latex(rhs)
# Create equation object
equation = Eq(lhs_expr, rhs_expr)
# Simplify equation by moving right side to left
simplified_eq = simplify(equation.lhs - equation.rhs)
return simplified_eq
expr1_sym = simplify_equation(expression1)
expr2_sym = simplify_equation(expression2)
division_result_1 = simplify(expr1_sym / expr2_sym)
division_result_2 = simplify(expr2_sym / expr1_sym)
# If division result or its reciprocal is
# non-zero integer, equations are equivalent
if (division_result_1.is_Integer
and division_result_1 != 0) or (division_result_2.is_Integer
and division_result_2 != 0):
return True
else:
return False
def interval_equal(self, expression1, expression2):
"""
Function: Check if two intervals are mathematically equivalent
Step 1: Simplify interval expressions,
remove irrelevant symbols
like "\\left", "\\right", and "x \\in"
Step 2: Compare brackets and mathematical expressions in between
"""
def compare_two_interval(inter1, inter2):
# First compare brackets on both sides
if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
return False
inter1 = inter1.strip('[]()')
inter2 = inter2.strip('[]()')
# Split interval into left and right parts
items_1 = inter1.split(',')
items_2 = inter2.split(',')
for item_1, item_2 in zip(items_1, items_2):
if not self.expression_equal(item_1, item_2):
return False
return True
interval1 = expression1
interval2 = expression2
if interval1 == interval2:
return True
else:
inter_list1 = interval1.split('\\cup')
inter_list2 = interval2.split('\\cup')
if len(inter_list1) != len(inter_list2):
return False
else:
for inter1, inter2 in zip(inter_list1, inter_list2):
if not compare_two_interval(inter1, inter2):
return False
return True
def preprocess(self, expression1, expression2):
"""Extract and preprocess expressions from model output."""
def extract_boxed_content(latex_str):
# Find all \boxed{...} structures
boxed_matches = re.finditer(r'\\boxed{', latex_str)
results = ''
for match in boxed_matches:
start_index = match.end()
end_index = start_index
stack = 1
# Search from after \boxed{ until
# finding matching closing brace
while stack > 0 and end_index < len(latex_str):
if latex_str[end_index] == '{':
stack += 1
elif latex_str[end_index] == '}':
stack -= 1
end_index += 1
if stack == 0:
# Extract content inside \boxed{}
content = latex_str[start_index:end_index - 1]
results += content + ','
else:
raise ValueError('Mismatched braces in LaTeX string.')
# If no \boxed{} found, extract formulas from last line
if results == '':
last_line_ans = latex_str.strip().split('\n')[-1]
dollar_pattern = r'\$(.*?)\$'
answers = re.findall(dollar_pattern, last_line_ans)
if answers:
for ans in answers:
results += ans + ','
else:
results = latex_str
return results
def special_symbol_replace(expression):
if '\\in ' in expression:
expression = expression.split('\\in ')[1]
# Replace special characters that
# don't affect LaTeX parsing (decorative)
for signal in self.special_signal_map:
expression = expression.replace(
signal, self.special_signal_map[signal])
expression = expression.strip('\n$,.:;^_=+`!@#$%^&*~,。')
pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
expression = re.sub(pattern, r'\1', expression)
return expression
exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(
expression2)
exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
return exp1, exp2
def can_compute_power(self, expr):
"""Check if the power expression can be computed.
Parameters:
expr (sympy expression): The expression to check.
Returns:
bool: True if the expression can be computed, False otherwise.
"""
# Check if the expression is a power expression
if isinstance(expr, Pow):
# Extract the base and the exponent
base, exp = expr.as_base_exp()
# Check if the base and the exponent are numbers
if base.is_number and exp.is_number:
# Set a threshold for the maximum size of the exponent
# can be adjusted based on the computing environment
MAX_EXP = 1000
# Check if the exponent is greater than the threshold
if abs(exp.evalf()) > MAX_EXP:
return False
else:
return True
else:
# If the base or the exponent is not a number,
# we cannot compute the power
return False
else:
# If the expression is not a power expression,
# return True as it is not the case we are checking for
return True
@TEXT_POSTPROCESSORS.register_module('olympiadbench_postprocess_v2')
def olympiadbench_postprocess_v2(text: str,
is_chinese: bool = False,
is_deepseek: bool = False) -> str:
"""Extract answer from model output."""
# deepseekmath has special answering format
if is_deepseek:
if is_chinese:
matches = re.findall('## 解题答案(.*)', text)
else:
matches = re.findall('The answer is: (.*)', text)
else:
if is_chinese:
matches = re.findall('所以最终答案是(.*)', text)
else:
matches = re.findall('So the final answer is (.*)', text)
# If found matches, take the last one, otherwise return the whole text
if matches:
return matches[-1].strip()
return text
class OlympiadBenchEvaluator(BaseEvaluator):
"""Evaluator for OlympiadBench dataset."""
def __init__(self, version='v1'):
assert version in ['v1', 'v2']
self.version = version
self.judger = MathJudger()
def score(self, predictions, references): # Remove questions parameter
"""Calculate accuracy score.
Args:
predictions (list): List of model predictions
references (list): List of ground truth answers
"""
if len(predictions) != len(references):
return {
'error': 'predictions and references have different length'
}
correct = 0
count = 0
details = []
for pred, ref in zip(predictions, references):
detail = {'pred': pred, 'answer': ref, 'correct': False}
count += 1
# Get precision/error threshold from reference if available
precision = 1e-8
if isinstance(ref, dict) and 'error' in ref:
if ',' in ref['error']:
# Multiple precisions for multiple answers
precisions = ref['error'].split(',')
precisions = [float(p) if p else 1e-8 for p in precisions]
precision = precisions
else:
precision = float(ref['error'])
# Check if answer is correct
try:
if (isinstance(ref, dict) and 'answer_type' in ref
and 'Tuple' in ref['answer_type']):
# Special handling for tuple type answers
is_correct = self.judger.judge(pred,
ref['final_answer'][0],
precision)
else:
is_correct = self.judger.judge(pred, ref, precision)
if is_correct:
correct += 1
detail['correct'] = True
except Exception as e: # 处理具体异常
detail['error'] = str(e)
details.append(detail)
result = {'accuracy': 100 * correct / count, 'details': details}
return result
@ICL_PROMPT_TEMPLATES.register_module()
class OlympiadBenchTemplate(PromptTemplate):
"""Template for OlympiadBench dataset."""
def __init__(self):
# Define basic template structure
template = dict(round=[dict(role='HUMAN', prompt='{prompt}')])
super().__init__(template=template)
self.prompter = OlympiadBenchPrompter()
def generate_item(self, entry: Dict, *args, **kwargs) -> str:
"""Generate prompt for a single item."""
problem = entry.get('problem', '')
language = entry.get('language', 'English')
subject = entry.get('subject', 'Math')
question_type = entry.get('question_type', '')
answer_type = entry.get('answer_type', '')
is_multiple_answer = entry.get('is_multiple_answer', False)
unit = entry.get('unit', '')
prompt = self.prompter.make_prompt(
language=language,
subject=subject,
question_type=question_type,
answer_type=answer_type,
is_multiple_answer=is_multiple_answer,
unit=unit,
)
new_entry = {'prompt': prompt, 'problem': problem}
return super().generate_item(new_entry, *args, **kwargs)

View File

@ -103,6 +103,7 @@ from .natural_question import * # noqa: F401, F403
from .natural_question_cn import * # noqa: F401, F403
from .NPHardEval import * # noqa: F401, F403
from .obqa import * # noqa: F401, F403
from .OlympiadBench import * # noqa: F401, F403
from .OpenFinData import * # noqa: F401, F403
from .piqa import * # noqa: F401, F403
from .py150 import * # noqa: F401, F403

View File

@ -398,9 +398,18 @@ DATASETS_MAPPING = {
"hf_id": "THUDM/LongBench-v2",
"local": "./data/longbenchv2/data.json",
},
"opencompass/OlympiadBench": {
"ms_id": "",
"hf_id": "",
"local": "./data/OlympiadBench",
},
}
DATASETS_URL = {
"/OlympiadBench": {
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/OlympiadBench.zip",
"md5": "97e8b1ae7f6170d94817288a8930ef00",
},
"/longbenchv2":{
"url": "http://opencompass.oss-cn-shanghai.aliyuncs.com/datasets/data/longbenchv2.zip",
"md5": "09b7e06e6f98c5cca8ad597b3d7b42f0",

View File

@ -1,5 +1,7 @@
# Alpaca-eval
alpaca-eval==0.6
# OlympiadBench
antlr4-python3-runtime==4.11
cn2an
# Dingo
dingo-python==1.1.2