[Feature] Add mathbench dataset and circular evaluator (#408)

* add_mathbench

* update mathbench

* support non circular eval dataset

---------

Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
Co-authored-by: yingfhu <yingfhu@gmail.com>
This commit is contained in:
liushz 2023-10-18 17:08:31 +08:00 committed by GitHub
parent fccfcb6f5b
commit 2737249f31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 322 additions and 0 deletions

View File

@ -0,0 +1,4 @@
from mmengine.config import read_base
with read_base():
from .mathbench_gen_86de1c import mathbench_datasets # noqa: F401, F403

View File

@ -0,0 +1,114 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import CircularEvaluator, AccEvaluator
from opencompass.datasets import MathBenchDataset, mathbench_postprocess
from opencompass.utils.text_postprocessors import first_capital_postprocess
single_choice_prompts = {
"single_choice_cn_with_reasoning": "以下是一道关于数学的单项选择题,请你一步一步推理并得到最终的答案选项。回答格式为如下:\n答案选项A、B、C、D中你认为正确的一个选项\n计算过程:根据题目得到选项答案的一步步过程\n请严格按照上面的格式回答问题,下面是你要回答的题目:\n{question}\n答案选项:",
"single_choice_cn": "以下是一道关于数学的单项选择题,请你直接给出正确的答案选项。回答格式为如下:\n答案选项A、B、C、D中你认为正确的选项\n下面是你要回答的题目:\n{question}\n答案选项:",
"single_choice_en_with_reasoning": "Here is a multiple-choice question about mathematics. Please provide the final answer option by step-by-step reasoning. Please answer in the following format:\nAnswer option: A, B, C, or D (the option you believe is correct)\nCalculation process: Step-by-step process to derive the answer option based on the question\nPlease strictly follow the above format to answer the question. Here is the question you need to answer:\n{question}\nAnswer option:",
"single_choice_en": "Here is a multiple-choice question about mathematics. Please provide the correct answer option directly. Please answer in the following format:\nAnswer option: A, B, C, or D (the option you believe is correct)\nHere is the question you need to answer:\n{question}\nAnswer option:",
}
cloze_prompts ={
"cloze_cn": [
dict(role='HUMAN', prompt='Q: 林中有15棵树。林务工人员今天将在林中种植树木。完成后将有21棵树。林务工人员今天种植了多少棵树'),
dict(role='BOT', prompt='A: 我们从15棵树开始。后来有21棵树。差值必定是他们种植的树木数量。所以他们必须种植了21 - 15 = 6棵树。答案是 6\n'),
dict(role='HUMAN', prompt='Q: 如果停车场有3辆车又有2辆车进来停车场里有多少辆车'),
dict(role='BOT', prompt='A: 停车场已经有3辆车。又进来了2辆车。现在有3 + 2 = 5辆车。答案是 5\n'),
dict(role='HUMAN', prompt='Q: Leah有32块巧克力她的妹妹有42块。如果他们吃了35块他们总共剩下多少块'),
dict(role='BOT', prompt='A: Leah有32块巧克力Leah的妹妹有42块。这意味着原本有32 + 42 = 74块巧克力。被吃掉了35块。所以他们总共还剩下74 - 35 = 39块巧克力。答案是 39\n'),
dict(role='HUMAN', prompt='Q: Jason有20个棒棒糖。他给Denny一些棒棒糖。现在Jason只剩下12个棒棒糖。Jason给Denny多少个棒棒糖'),
dict(role='BOT', prompt='A: Jason有20个棒棒糖。因为他现在只剩下12个所以他必须把剩下的都给了Denny。他给Denny的棒棒糖数量必定是20 - 12 = 8个。答案是 8\n'),
dict(role='HUMAN', prompt='Q: Shawn有五个玩具。在圣诞节他从他的爸爸和妈妈那里各得到了两个玩具。现在他有多少个玩具'),
dict(role='BOT', prompt='A: 他有5个玩具。他从妈妈那里得到了2个所以之后他有5 + 2 = 7个玩具。然后他从爸爸那里得到了2个所以总共他有7 + 2 = 9个玩具。答案是 9\n'),
dict(role='HUMAN', prompt='Q: 服务器房间里有九台电脑。从周一到周四每天增加五台电脑。现在服务器房里有多少台电脑?'),
dict(role='BOT', prompt='A: 从周一到周四有4天。每天增加5台电脑。这意味着总共增加了4 * 5 = 20台电脑。一开始有9台电脑所以现在有9 + 20 = 29台电脑。答案是 29\n'),
dict(role='HUMAN', prompt='Q: Michael有58个高尔夫球。星期二他丢失了23个高尔夫球。星期三他又丢失了2个。星期三结束时他还剩下多少个高尔夫球'),
dict(role='BOT', prompt='A: Michael一开始有58个球。星期二他丢失了23个所以之后他还剩下58 - 23 = 35个球。星期三他又丢失了2个所以现在他还剩下35 - 2 = 33个球。答案是 33\n'),
dict(role='HUMAN', prompt='Q: Olivia有23美元。她用每个3美元的价格买了五个百吉饼。她还剩下多少钱'),
dict(role='BOT', prompt='A: 她以每个3美元的价格买了5个百吉饼。这意味着她在百吉饼上花费了5 * 3 = 15美元。她一开始有23美元所以现在她还剩下23 - 15 = 8美元。答案是 8\n'),
dict(role='HUMAN', prompt='Q: {question}'),
dict(role='BOT', prompt='A: {answer}'),
],
"cloze_en": [
dict(role='HUMAN', prompt='Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?'),
dict(role='BOT', prompt='A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n'),
dict(role='HUMAN', prompt='Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?'),
dict(role='BOT', prompt='A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n'),
dict(role='HUMAN', prompt='Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?'),
dict(role='BOT', prompt="A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n"),
dict(role='HUMAN', prompt='Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?'),
dict(role='BOT', prompt='A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n'),
dict(role='HUMAN', prompt='Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?'),
dict(role='BOT', prompt='A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n'),
dict(role='HUMAN', prompt='Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?'),
dict(role='BOT', prompt='A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n'),
dict(role='HUMAN', prompt='Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?'),
dict(role='BOT', prompt='A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n'),
dict(role='HUMAN', prompt='Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?'),
dict(role='BOT', prompt='A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.\n'),
dict(role='HUMAN', prompt='Q: {question}'),
dict(role='BOT', prompt='A: {answer}\n'),
],
}
mathbench_sets = {
'college': ['single_choice_cn', 'cloze_en'],
'high': ['single_choice_cn', 'single_choice_en'],
'middle': ['single_choice_cn'],
'primary': ['single_choice_cn', 'cloze_cn'],
}
# Generate reasoning path if set True or just generate the final answer
with_reasoning = True
# Use circular evaluation or not
with_circular_eval = True
mathbench_datasets = []
for _split in list(mathbench_sets.keys()):
for _name in mathbench_sets[_split]:
mathbench_infer_cfg = dict(
ice_template=dict(
type=PromptTemplate,
template=dict(
begin="</E>",
round=[
dict(
role="HUMAN",
prompt=single_choice_prompts[_name + "_with_reasoning"] if with_reasoning else single_choice_prompts[_name],
),
dict(role="BOT", prompt="{answer}")] if 'choice' in _name else cloze_prompts[_name],
),
ice_token="</E>",
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
mathbench_eval_cfg = dict(
evaluator=dict(type=CircularEvaluator if 'choice' in _name else AccEvaluator),
pred_postprocessor=dict(type=first_capital_postprocess) if 'single_choice' in _name else dict(type=mathbench_postprocess, name=_name))
mathbench_datasets.append(
dict(
type=MathBenchDataset,
path=f"./data/mathbench/{_split}",
name=_name,
abbr="mathbench-" + _split + '-' + _name,
reader_cfg=dict(
input_columns=["question"],
output_column="answer"
),
infer_cfg=mathbench_infer_cfg,
eval_cfg=mathbench_eval_cfg,
))
del _split, _name

View File

@ -45,6 +45,7 @@ from .lcsts import * # noqa: F401, F403
from .leval import * # noqa: F401, F403 from .leval import * # noqa: F401, F403
from .longbench import * # noqa: F401, F403 from .longbench import * # noqa: F401, F403
from .math import * # noqa: F401, F403 from .math import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403 from .mbpp import * # noqa: F401, F403
from .mmlu import * # noqa: F401, F403 from .mmlu import * # noqa: F401, F403
from .multirc import * # noqa: F401, F403 from .multirc import * # noqa: F401, F403

View File

@ -0,0 +1,98 @@
import copy
import json
import os.path as osp
import re
from datasets import Dataset
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from .base import BaseDataset
def get_number(options):
result_string = ''
for i, option in enumerate(options, start=ord('A')):
result_string += f'{chr(i)}. {option}\n'
return result_string
def get_circular_example(entry, id):
"""For given example, generate four circular examples."""
# Only 4 options is supported for current circular eval.
circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']
data = []
for c in circular_patterns:
line = copy.deepcopy(entry)
options = []
for i in range(4):
options.append(line['options'][ord(c[i]) - ord('A')])
line['options'] = options
line['answer'] = {
c[0]: 'A',
c[1]: 'B',
c[2]: 'C',
c[3]: 'D'
}[line['answer']]
line['answer'] = str(id) + '--' + line['answer'] + '--' + c
line['question'] = line['question'].strip() + '\n' + get_number(
line['options'])
data.append(line)
return data
@LOAD_DATASET.register_module()
class MathBenchDataset(BaseDataset):
@staticmethod
def load(path: str, name: str, with_circular: bool = True):
"""MathBenth Dataset.
Args:
path (str): Path of the mathbench dataset.
name (str): Name of the target subset.
with_circular (bool): Whether to create circular dataset for
single choice question. Defaults to True.
"""
data = []
filename = osp.join(path, f'{name}.jsonl')
with open(filename, 'r') as infile:
for id, line in enumerate(infile):
entry = json.loads(line)
if 'cloze' in name:
data.append({
'question': entry['question'].strip(),
'answer': entry['answer'].strip()
})
else:
if with_circular:
data.extend(get_circular_example(entry, id))
else:
question = entry['question'].strip(
) + '\n' + get_number(entry['options'])
data.append({
'question': question,
'answer': entry['answer'].strip()
})
dataset = Dataset.from_list(data)
return dataset
@TEXT_POSTPROCESSORS.register_module()
def mathbench_postprocess(text: str, name: str) -> str:
ans = text
if '_cn' in name:
ans_line = ans.split('答案是')
else:
ans_line = ans.split('The answer is')
if len(ans_line) != 1:
ans = ans_line[1].strip()
output = re.sub(r'(\d),(\d)', r'\1\2', ans)
numbers = re.findall(r'-?\d*\.?\d+|\d+', output)
if numbers:
return numbers[-1]
return ans

View File

@ -1,6 +1,7 @@
from .icl_agent_evaluator import * # noqa from .icl_agent_evaluator import * # noqa
from .icl_aucroc_evaluator import AUCROCEvaluator # noqa from .icl_aucroc_evaluator import AUCROCEvaluator # noqa
from .icl_base_evaluator import BaseEvaluator # noqa from .icl_base_evaluator import BaseEvaluator # noqa
from .icl_circular_evaluator import CircularEvaluator # noqa
from .icl_em_evaluator import EMEvaluator # noqa from .icl_em_evaluator import EMEvaluator # noqa
from .icl_hf_evaluator import * # noqa from .icl_hf_evaluator import * # noqa
from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa from .icl_jieba_rouge_evaluator import JiebaRougeEvaluator # noqa

View File

@ -0,0 +1,104 @@
import collections
from opencompass.registry import ICL_EVALUATORS
from .icl_base_evaluator import BaseEvaluator
@ICL_EVALUATORS.register_module()
class CircularEvaluator(BaseEvaluator):
"""Robust circular evaluator for multi-choice questions."""
def __init__(self) -> None:
super().__init__()
self.cp4 = ['ABCD', 'BCDA', 'CDAB', 'DABC']
self.cp1 = ['ABCD']
def score(self, predictions, references):
"""Calculate the accuracy of predictions.
Args:
predictions (list): List of predictions.
references (list): List of references.
Returns:
dict: A dict of evaluation results.
"""
self._metrics = {}
self._metrics.update({'acc_4': 0, 'acc_1': 0})
# Accuracy for patterns with no circular shift / 4 circular shifts
for pred, reference in zip(predictions, references):
index, ref, circular_pattern = reference.split('--')
if circular_pattern in self.cp4:
self._metrics['acc_4'] += 1 if pred == ref else 0
if circular_pattern in self.cp1:
self._metrics['acc_1'] += 1 if pred == ref else 0
for k in ['acc_4', 'acc_1']:
self._metrics[k] = self._metrics[k] / len(predictions) * 4 / int(
k.split('_')[-1]) * 100
# Accuracy for patterns with no circular shift / 4 circular shifts
details = {4: {}, 1: {}}
for pred, reference in zip(predictions, references):
index, ref, circular_pattern = reference.split('--')
if index not in details[4]:
details[4][index] = []
details[1][index] = []
if circular_pattern in self.cp4:
details[4][index].append(True if pred == ref else False)
if circular_pattern in self.cp1:
details[1][index].append(True if pred == ref else False)
# Calculate accuracy for having at least j correct out of i total
for i in [1, 4]:
for j in range(0, i + 1):
count, total = 0, 0
for index in details[i]:
if sum(details[i][index]) >= j:
count += 1
total += 1
self._metrics[f'more_{i}_{j}'] = count / total * 100
# Consider fully correct as correct
for i in [1, 4]:
self._metrics[f'perf_{i}'] = self._metrics[f'more_{i}_{i}']
# Calculate voting accuracy
voting = {'vote_4': {}, 'vote_1': {}}
refs = {}
for pred, reference in zip(predictions, references):
index, ref, circular_pattern = reference.split('--')
c = circular_pattern
back_map = {'A': c[0], 'B': c[1], 'C': c[2], 'D': c[3]}
ref = back_map[ref]
if pred not in ['A', 'B', 'C', 'D']:
pred = '-'
else:
pred = back_map[pred]
if index not in voting['vote_4']:
voting['vote_4'][index] = collections.Counter()
voting['vote_1'][index] = collections.Counter()
refs[index] = ref
if c in self.cp4:
voting['vote_4'][index][pred] += 1
if c in self.cp1:
voting['vote_1'][index][pred] += 1
for k in ['vote_4', 'vote_1']:
voting_count = 0
for index in voting[k]:
if refs[index] == voting[k][index].most_common(1)[0][0]:
voting_count += 1
self._metrics[k] = voting_count / len(voting[k]) * 100
# Calculate the frequency of ABCD in model predictions
prior_counts = {'A': 0, 'B': 0, 'C': 0, 'D': 0, '-': 0}
for pred, reference in zip(predictions, references):
if pred in ['A', 'B', 'C', 'D']:
prior_counts[pred] += 1
else:
prior_counts['-'] += 1
for k in ['A', 'B', 'C', 'D', '-']:
self._metrics[f'prior_{k}'] = prior_counts[k] / len(
predictions) * 100
return self._metrics