[Feature] Support LiveMathBench (#1727)

Junnan Liu 2024-11-30 00:07:19 +08:00 committed by GitHub
parent b063779034
commit fe6d76fb13
6 changed files with 523 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
# LiveMathBench
## Details of Datasets
| dataset | language | #single-choice | #multiple-choice | #fill-in-the-blank | #problem-solving |
| -- | -- | -- | -- | -- | -- |
| AIMC | cn | 46 | 0 | 0 | 0 |
| AIMC | en | 46 | 0 | 0 | 0 |
| CEE | cn | 28 | 9 | 13 | 3 |
| CEE | en | 28 | 9 | 13 | 3 |
| CMO | cn | 0 | 0 | 0 | 18 |
| CMO | en | 0 | 0 | 0 | 18 |
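
The loader reads one JSONL file per split and language from the data directory, named `{split}_{language}.jsonl`. Each record carries at least `question_type`, `question`, optional `options`, and `answer` (Chinese files use 单选/多选/填空/问答 as question types). A sketch of the expected layout, inferred from the loader added in this PR:
```
/path/to/data/dir
├── AIMC_cn.jsonl
├── AIMC_en.jsonl
├── CEE_cn.jsonl
├── CEE_en.jsonl
├── CMO_cn.jsonl
└── CMO_en.jsonl
```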
## How to use
```python
from mmengine.config import read_base

with read_base():
    from opencompass.datasets.livemathbench import livemathbench_datasets

livemathbench_datasets[0].update(
    {
        'path': '/path/to/data/dir',
        'k': 32,  # maximum k used for the k@pass metrics (default in the shipped config)
        'n': 5,   # number of independent runs (default in the shipped config)
    }
)

livemathbench_datasets[0]['eval_cfg']['evaluator'].update(
    {
        'model_name': 'Qwen/Qwen2.5-72B-Instruct',
        'url': [
            'http://0.0.0.0:23333/v1',
            '...'
        ]  # URL(s) of the judge/evaluation model services
    }
)
```
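
To launch an evaluation you additionally need a top-level run config that pairs `livemathbench_datasets` with the model under test. A rough sketch is shown below, reusing the `OpenAISDK` wrapper that this PR also builds inside the evaluator; the `abbr`, endpoint URL, and generation settings are placeholders rather than part of this change:
```python
from mmengine.config import read_base
from opencompass.models import OpenAISDK

with read_base():
    # dataset definition added by this PR
    from opencompass.datasets.livemathbench import livemathbench_datasets

datasets = livemathbench_datasets

# hypothetical model under test, served behind an OpenAI-compatible endpoint
models = [
    dict(
        type=OpenAISDK,
        abbr='qwen2.5-72b-instruct',
        path='Qwen/Qwen2.5-72B-Instruct',
        openai_api_base='http://0.0.0.0:23333/v1',  # placeholder URL
        key='EMPTY',
        query_per_second=2,
        max_seq_len=2048,
    ),
]
```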
> ❗️ At present, answers are extracted from model responses with `extract_boxed_answer`. You can also use an LLM for extraction by setting the parameters below, but this code path has not been tested yet.
```python
livemathbench_datasets[0]['eval_cfg']['evaluator'].update(
    {
        'model_name': 'Qwen/Qwen2.5-72B-Instruct',
        'url': [
            'http://0.0.0.0:23333/v1',
            '...'
        ],  # URL(s) of the judge/evaluation model services

        # for LLM-based answer extraction
        'use_extract_model': True,
        'post_model_name': 'oc-extractor',
        'post_url': [
            'http://0.0.0.0:21006/v1',
            '...'
        ]
    }
)
```
## Output Samples
| dataset | version | metric | mode | Qwen2.5-72B-Instruct |
|----- | ----- | ----- | ----- | -----|
| LiveMathBench | caed8f | 1@pass | gen | 26.07 |
| LiveMathBench | caed8f | 1@pass/std | gen | xx.xx |
| LiveMathBench | caed8f | 2@pass | gen | xx.xx |
| LiveMathBench | caed8f | 2@pass/std | gen | xx.xx |
| LiveMathBench | caed8f | pass-rate | gen | xx.xx |
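
For reference, the `i@pass` numbers come from the evaluator added in this PR: every problem is duplicated `k * n` times at load time, the judge model marks each sample correct or not, and the samples are regrouped into `n` runs of `k` before averaging (the table reports these values averaged over problems and scaled to 0-100). A condensed sketch of the per-problem aggregation, simplified from the `score` method below with made-up verdicts:
```python
import numpy as np

k, n = 4, 2  # illustrative values; the shipped config uses k=32, n=5
# judge verdicts for one problem, grouped into n runs of k samples (made-up values)
if_pass = np.array([sorted(run, reverse=True)
                    for run in [[1., 0., 0., 1.], [0., 0., 1., 0.]]])

i = 1
while i <= k:
    # i@pass: mean over runs of the fraction of the top-i samples judged correct
    run_means = if_pass[:, :i].mean(axis=1)
    print(f'{i}@pass = {run_means.mean():.3f}, {i}@pass/std = {run_means.std():.3f}')
    i *= 2

# pass-rate: overall fraction of correct samples across all k * n attempts
print(f'pass-rate = {if_pass.sum() / if_pass.size:.3f}')
```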

View File

@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .livemathbench_gen_caed8f import livemathbench_datasets  # noqa: F401, F403

View File

@@ -0,0 +1,49 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.livemathbench import LiveMathBenchDataset, LiveMathBenchEvaluator

livemathbench_reader_cfg = dict(
    input_columns=['prompt'],
    output_column='answer'
)

livemathbench_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{prompt}'),
            ]
        )
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(
        type=GenInferencer,
        max_out_len=2048,
        temperature=1.0
    )
)

livemathbench_eval_cfg = dict(
    evaluator=dict(
        type=LiveMathBenchEvaluator,
        model_name='Qwen/Qwen2.5-72B-Instruct',
        url=[]
    )
)

livemathbench_datasets = [
    dict(
        type=LiveMathBenchDataset,
        abbr='LiveMathBench',
        path='',
        k=32,
        n=5,
        reader_cfg=livemathbench_reader_cfg,
        infer_cfg=livemathbench_infer_cfg,
        eval_cfg=livemathbench_eval_cfg
    )
]

View File

@@ -0,0 +1,2 @@
from .livemathbench import LiveMathBenchDataset # noqa: F401, F403
from .livemathbench import LiveMathBenchEvaluator # noqa: F401, F403

View File

@@ -0,0 +1,324 @@
import concurrent.futures
import os
import re
from copy import deepcopy
from itertools import product
from typing import Any, Dict, List

import jsonlines
import numpy as np
from datasets import Dataset

from opencompass.models import OpenAISDK
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS

from ..base import BaseDataset
from .prompts import (EXTRACT_PROMPT_CN, EXTRACT_PROMPT_EN, JUDGE_PROMPT_CN,
                      JUDGE_PROMPT_EN, PROMPT_CN, PROMPT_EN)


@LOAD_DATASET.register_module()
class LiveMathBenchDataset(BaseDataset):
    dataset_splits = ['AIMC', 'CEE', 'CMO']
    dataset_languages = ['cn', 'en']

    @staticmethod
    def load(
        path: str,
        k: int,
        n: int,
    ) -> List[Dict[str, Any]]:
        dataset = []
        dataset_info = {}
        for split, language in product(LiveMathBenchDataset.dataset_splits,
                                       LiveMathBenchDataset.dataset_languages):
            file_path = os.path.join(path, f'{split}_{language}.jsonl')
            dataset_info[f'{split}_{language}'] = {
                'single-choice': 0,
                'multiple-choice': 0,
                'fill-in-the-blank': 0,
                'problem-solving': 0
            }
            question_type_mapping = {
                '单选': 'single-choice',
                '多选': 'multiple-choice',
                '填空': 'fill-in-the-blank',
                '问答': 'problem-solving'
            }
            with jsonlines.open(file_path, 'r') as file:
                for example_idx, example in enumerate(file):
                    dataset_info[f'{split}_{language}'][
                        example['question_type'] if language == 'en' else
                        question_type_mapping[example['question_type']]] += 1

                    prompt = PROMPT_EN if language == 'en' else PROMPT_CN
                    example.update({
                        'dataset_key':
                        f'{split}_{language}_{example_idx}',
                        'prompt':
                        prompt.format(question_type=example['question_type'],
                                      question=example['question'] +
                                      ('' if 'options' not in example else
                                       ' '.join(example['options']))),
                        'k':
                        k,
                        'n':
                        n
                    })

                    for idx in range(k * n):
                        duplicated_example = deepcopy(example)
                        duplicated_example.update({'duplicated_idx': idx})
                        dataset.append(duplicated_example)

        return Dataset.from_list(dataset)


@ICL_EVALUATORS.register_module()
class LiveMathBenchEvaluator(BaseEvaluator):
    api_meta_template = dict(round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ])

    def __init__(self,
                 model_name,
                 url,
                 with_postprocess=True,
                 use_extract_model=False,
                 post_url=[],
                 post_model_name='',
                 **kwargs):
        if isinstance(url, str):
            url = [url]

        self.model = [
            MODELS.build(
                dict(
                    type=OpenAISDK,
                    path=model_name,
                    openai_api_base=url,
                    key='EMPTY',
                    query_per_second=2,
                    meta_template=self.api_meta_template,
                    temperature=kwargs.get('temperature', 0.01),
                    max_seq_len=kwargs.get('max_tokens', 2048),
                )) for url in url
        ]
        self.with_postprocess = with_postprocess
        self.use_extract_model = use_extract_model
        self.post_url = post_url
        self.post_model_name = post_model_name

    def batch_response(self, models: List[OpenAISDK],
                       inputs: List[str]) -> List[str]:
        batch_num = len(models)
        batch_size = (len(inputs) + batch_num - 1) // batch_num
        result_responses = []

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=batch_num) as executor:
            futures = [
                executor.submit(models[i].generate,
                                inputs[i * batch_size:(i + 1) * batch_size])
                for i in range(batch_num)
            ]
            for response in executor.map(lambda f: f.result(), futures):
                result_responses.extend(response)

        return result_responses

    def postprocess(self, questions: List[str], predictions: List[str],
                    question_types: List[str],
                    languages: List[str]) -> List[str]:
        if self.use_extract_model:
            assert len(self.post_url) > 0 and self.post_model_name != ''
            post_model = [
                MODELS.build(
                    dict(
                        type=OpenAISDK,
                        path=self.post_model_name,
                        openai_api_base=url,
                        key='EMPTY',
                        query_per_second=2,
                        meta_template=self.api_meta_template,
                        temperature=0.01,
                        max_seq_len=1024,
                    )) for url in self.post_url
            ]
            input_prompts = []
            for question, prediction, question_type, language in zip(
                    questions, predictions, question_types, languages):
                prompt = (EXTRACT_PROMPT_EN
                          if language == 'en' else EXTRACT_PROMPT_CN)
                input_prompts.append(
                    prompt.format(question=question,
                                  response=prediction,
                                  question_type=question_type))
            result_responses = self.batch_response(post_model, input_prompts)

            return result_responses

        def last_boxed_only_string(string):
            idx = string.rfind('\\boxed')
            if idx < 0:
                idx = string.rfind('\\fbox')
                if idx < 0:
                    return None

            i = idx
            right_brace_idx = None
            num_left_braces_open = 0
            while i < len(string):
                if string[i] == '{':
                    num_left_braces_open += 1
                if string[i] == '}':
                    num_left_braces_open -= 1
                    if num_left_braces_open == 0:
                        right_brace_idx = i
                        break
                i += 1

            if right_brace_idx is None:
                retval = None
            else:
                retval = string[idx:right_brace_idx + 1]

            return retval

        def remove_boxed(s):
            left = '\\boxed{'
            try:
                assert s[:len(left)] == left
                assert s[-1] == '}'
                return s[len(left):-1]
            except Exception:
                return None

        def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
            boxed_str = last_boxed_only_string(pred_str)
            if boxed_str is None:
                return None
            answer = remove_boxed(boxed_str)
            if answer is None:
                return None
            if strip_double_curly_brace:
                match = re.match('^\{(.*)\}$', answer)  # noqa: W605
                if match:
                    answer = match.group(1)
            return answer

        predictions = [
            extract_boxed_answer(prediction) for prediction in predictions
        ]
        return predictions

    def extract_boxed_answer(self, text):
        match = re.findall(r'\\boxed{(.+?)}', text)
        if match:
            return match[-1]
        return None

    def score(self, predictions, references, origin_prompt, test_set):
        if len(predictions) != len(references):
            return {'error': 'preds and refs have different length'}

        questions = test_set['question']
        question_types = test_set['question_type']
        languages = [key.split('_')[1] for key in test_set['dataset_key']]

        if self.with_postprocess:
            predictions = self.postprocess(questions, predictions,
                                           question_types, languages)

        inputs = []
        for prediction, reference, question, language in zip(
                predictions, references, questions, languages):
            prompt = JUDGE_PROMPT_EN if language == 'en' else JUDGE_PROMPT_CN
            inputs.append(
                prompt.format(answer=prediction,
                              gold_answer=reference,
                              question=question))
        result_responses = self.batch_response(self.model, inputs)
        results = [
            self.extract_boxed_answer(result) == 'yes'
            for result in result_responses
        ]

        K = test_set['k'][0]
        N = test_set['n'][0]
        key2example = {}

        for example, result_response, result, prediction in zip(
                test_set, result_responses, results, predictions):
            if example['dataset_key'] not in key2example:
                key2example[example['dataset_key']] = []
            example.update({
                'eval_response': result_response,
                'prediction': prediction,
                'correct': result
            })
            key2example[example['dataset_key']].append(example)
        for key in key2example:
            key2example[key] = [
                key2example[key][i * K:(i + 1) * K] for i in range(N)
            ]

        count = []
        total_pass_num = []
        details = []
        for key, examples in key2example.items():
            detail = {
                'question': examples[0][0]['question'],
                'answer': examples[0][0]['answer'],
                'responses': []
            }
            if_pass_list = []
            for single_run_examples in examples:
                detail['responses'].append([])
                if_pass_list.append([])
                for example in single_run_examples:
                    detail['responses'][-1].append({
                        'prediction':
                        example['prediction'],
                        'eval_response':
                        example['eval_response']
                    })
                    if_pass_list[-1].append(1.0 if example['correct'] else 0.0)

            if_pass_list = [
                sorted(if_pass, reverse=True) for if_pass in if_pass_list
            ]
            if_pass_list = np.array(if_pass_list)

            i = 1
            while i <= K:
                detail.update({
                    f'{i}@pass':
                    if_pass_list[:, :i].mean(axis=1).mean(axis=0).item(),
                    f'{i}@pass/std':
                    if_pass_list[:, :i].mean(axis=1).std(axis=0).item()
                })
                i = i * 2

            count.append(np.ones_like(if_pass_list).sum(axis=1))
            total_pass_num.append(if_pass_list.sum(axis=1))

            details.append(detail)

        detailed_result = {'details': details}
        i = 1
        while i <= K:
            detailed_result.update({
                f'{i}@pass':
                100. * np.mean([detail[f'{i}@pass'] for detail in details]),
                f'{i}@pass/std':
                100. * np.mean([detail[f'{i}@pass/std'] for detail in details])
            })
            i = i * 2
        detailed_result.update(
            {'pass-rate': 100. * np.mean(sum(total_pass_num) / sum(count))})

        return detailed_result

View File

@@ -0,0 +1,70 @@
# flake8: noqa
EXTRACT_PROMPT_CN = '''你是一个乐于助人的助手,任务是从给定的回答句子中提取精确的关键答案。你必须只提供提取的关键答案,不包括任何额外的文字。
我将为你提供一个问题、回答句子和问题类型。回答句子是对所提供问题的回应。利用提供的信息,你必须准确而精确地确定并从回答句子中提取预期的关键答案。请不要对问题发表主观看法。
对于单选题,答案应该是选项字母,例如 "A"。
对于多选题,答案应该是一个选项字母的列表,例如 ["A"] 或 ["A", "B", "C"]。
对于填空题,答案应该是一个填入空白处的答案列表,列表的数量应该与问题中的空白数量相同,同一空白的答案可能有多个,请在同一个 string 中用逗号隔开表示,如 ['sqrt(x) 且 x > 10', '1/2, 1/3', '1/4'] 代表问题包含三小问,第一小问包含取值范围信息,第二小问有两个答案,第三小问有一个答案。
对于解答题,类似填空题,答案应该是一个答案列表,每小问的答案间用逗号隔开,同样需要注意某些小问答案多个的情况。
如果回答句子提供了多个不同的答案,请仔细判断后面提供的答案是否是对前面答案的修正或修改。如果是这样,提取这个修正或修改后的答案作为最终答案。相反,如果回答句子在多个答案之间波动而没有明确的最终答案,你应该输出 [No valid answer]。
问题类型: {question_type}
原始问题: {question}
回答: {response}
提取的关键答案:
'''
EXTRACT_PROMPT_EN = '''You are a helpful assistant whose task is to extract precise key answers from given response sentences. You must only provide the extracted key answers without any additional text.
I will provide you with a question, a response sentence, and the question type. The response sentence is a reply to the provided question. Using the provided information, you must accurately and precisely identify and extract the expected key answers from the response sentence. Please do not provide subjective opinions about the question.
For single-choice questions, the answer should be the letter of the option, such as "A".
For multiple-choice questions, the answer should be a list of option letters, such as ["A"] or ["A", "B", "C"].
For fill-in-the-blank questions, the answer should be a list of answers to fill in the blanks. The number of items in the list should match the number of blanks in the question. If there are multiple answers for the same blank, separate them with a comma within the same string, like ['sqrt(x) and x > 10', '1/2, 1/3', '1/4'], which represents three sub-questions where the first sub-question includes a range, the second sub-question has two answers, and the third sub-question has one answer.
For problem-solving questions, similar to fill-in-the-blank questions, the answer should be a list of answers. Separate answers for different sub-questions with commas, and note that some sub-questions may have multiple answers.
If the response sentence provides multiple different answers, carefully determine whether a later provided answer is a correction or modification of an earlier answer. If so, extract this corrected or modified answer as the final answer. Conversely, if the response sentence fluctuates between multiple answers without a clear final answer, you should output [No valid answer].
Question type: {question_type}
Question: {question}
Output sentences: {response}
Key extracted answer:
'''
JUDGE_PROMPT_CN = '''请你作为一个数学阅卷专家,判断下面的答案是否与标准答案一致,即考生是否回答正确。下面是一些评判标准:
1. 有些答案可能包含多项内容,可能有单选题、多选题、填空题和问答题,只要答案与标准答案一致即可, 对于多选题和多个空的填空题,需要考生对应的选项或空都回答正确才算正确。
2. 有些答案可能通过不同的方式表达,比如有些答案可能是一个数学表达式,有些答案可能是一个文字描述,只要表达的意思一致即可。且有些公式通过不同的方式表达,但等价,也是正确的。
3. 你不需要重新计算问题答案,因为标准答案已经给出,只需要根据问题形式来判断考生的答案是否与标准答案一致、是否正确即可。
请你根据上述标准,判断下面的答案是否与标准答案一致,如果一致,请在最后输出\\boxed{{yes}}, 否则输出\\boxed{{no}}, 如果难以判断,请输出\\boxed{{no}}.
原问题:{question}
标准答案:{gold_answer}
考生答案:{answer}
分析:
'''
JUDGE_PROMPT_EN = '''Please act as an expert in grading mathematics exam papers, and judge whether the following answers match the standard answers, i.e., whether the examinee answered correctly. Here are some evaluation criteria:
1. Some answers may contain multiple parts, such as single-choice questions, multiple-choice questions, fill-in-the-blank questions, and problem-solving questions. As long as the answer matches the standard answer, it is considered correct. For multiple-choice questions and fill-in-the-blank questions with multiple blanks, the examinee must answer all corresponding options or blanks correctly to be considered correct.
2. Some answers may be expressed in different ways; for example, some answers may be mathematical expressions, while others may be textual descriptions. As long as the meaning conveyed is consistent, it is considered correct. Additionally, some formulas may be expressed differently but are equivalent, which is also considered correct.
3. You do not need to recalculate the problem answers, as the standard answers are already provided. You only need to judge whether the examinee's answer matches the standard answer based on the form of the question and whether it is correct.
Please judge whether the following answer matches the standard answer according to the above criteria. If they match, output \\boxed{{yes}}, otherwise output \\boxed{{no}}. If it is difficult to judge, also output \\boxed{{no}}.
Original Question: {question}
Standard Answer: {gold_answer}
Examinee's Answer: {answer}
Analysis:
'''
PROMPT_CN = '''下面是一个{question_type}类型的数学问题,请逐步推理,并把最终答案放置于\\boxed{{}}中。
{question}
'''
PROMPT_EN = '''Here is a {question_type} type math problem, please reason step by step, and put your answer in \\boxed{{}}.
{question}
'''