mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Sync] Update dataset cfg for internMath (#837)
Co-authored-by: liuhongwei <liuhongwei@pjlab.org.cn>
This commit is contained in:
parent
f7d7837ac0
commit
0991dd33a0
33
configs/datasets/gsm8k/gsm8k_gen_701491.py
Normal file
33
configs/datasets/gsm8k/gsm8k_gen_701491.py
Normal file
@ -0,0 +1,33 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.openicl.icl_evaluator import AccEvaluator
|
||||
from opencompass.datasets import GSM8KDataset, gsm8k_postprocess, gsm8k_dataset_postprocess, Gsm8kEvaluator
|
||||
|
||||
# Columns read from each record: the prompt is built from `question`,
# the reference is taken from `answer`.
gsm8k_reader_cfg = {
    'input_columns': ['question'],
    'output_column': 'answer',
}

# Inference: a single HUMAN turn (the template carries no in-context
# examples), generated with at most 512 output tokens.
gsm8k_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {
                    'role': 'HUMAN',
                    'prompt': "Question: {question}\nLet's think step by step\nAnswer:",
                },
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# Evaluation: predictions and references are each post-processed before
# being handed to Gsm8kEvaluator.
gsm8k_eval_cfg = {
    'evaluator': {'type': Gsm8kEvaluator},
    'pred_role': "BOT",
    'pred_postprocessor': {'type': gsm8k_postprocess},
    'dataset_postprocessor': {'type': gsm8k_dataset_postprocess},
}

gsm8k_datasets = [
    {
        'abbr': 'gsm8k',
        'type': GSM8KDataset,
        'path': './data/gsm8k',
        'reader_cfg': gsm8k_reader_cfg,
        'infer_cfg': gsm8k_infer_cfg,
        'eval_cfg': gsm8k_eval_cfg,
    },
]
|
68
configs/datasets/math/math_gen_0957ff.py
Normal file
68
configs/datasets/math/math_gen_0957ff.py
Normal file
@ -0,0 +1,68 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess
|
||||
|
||||
# Columns read from each record: the prompt is built from `problem`,
# the reference is taken from `solution`.
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')

# Four worked HUMAN/BOT example rounds followed by the actual problem.
#
# NOTE: every LaTeX backslash in the prompts is written as `\\`. The previous
# literals relied on invalid escape sequences such as `\s` or `\d` keeping
# their backslash, which emits DeprecationWarning/SyntaxWarning on current
# Python versions. The resulting runtime strings are unchanged.
# NOTE(review): the stray `}` after `$.` in the first example prompt is kept
# as-is so the prompt text sent to the model does not change.
math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(round=[
            dict(
                role="HUMAN",
                prompt=
                "Problem:\nFind the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}\nSolution:"
            ),
            dict(
                role="BOT",
                prompt=
                "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.\n"
            ),
            dict(
                role="HUMAN",
                prompt=
                "Problem:\nIf $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$\nSolution:"
            ),
            dict(
                role="BOT",
                prompt=
                "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.\n"
            ),
            dict(
                role="HUMAN",
                prompt=
                "Problem:\nTerrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?\nSolution:"
            ),
            dict(
                role="BOT",
                prompt=
                "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.\n"
            ),
            dict(
                role="HUMAN",
                prompt=
                "Problem:\nIf the system of equations: \\begin{align*} 6x-4y&=a,\\\\ 6y-9x &=b. \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.\nSolution:"
            ),
            dict(
                role="BOT",
                prompt=
                "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.\n"
            ),
            dict(role="HUMAN", prompt="Problem:\n{problem}\nSolution:\n"),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

# Predictions are post-processed with math_postprocess before scoring.
math_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator),
    pred_postprocessor=dict(type=math_postprocess))

math_datasets = [
    dict(
        type=MATHDataset,
        abbr='math',
        path='./data/math/math.json',
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg)
]
|
28
configs/datasets/math/math_gen_736506.py
Normal file
28
configs/datasets/math/math_gen_736506.py
Normal file
@ -0,0 +1,28 @@
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
from opencompass.datasets import MATHInternDataset, MATHInternEvaluator, math_intern_postprocess
|
||||
|
||||
# Columns read from each record: the prompt is built from `problem`,
# the reference is taken from `solution`.
math_reader_cfg = {
    'input_columns': ['problem'],
    'output_column': 'solution',
}

# Inference: a single HUMAN turn (the template carries no in-context
# examples), generated with at most 512 output tokens.
math_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {
                    'role': 'HUMAN',
                    'prompt': "Question: {problem}\nLet's think step by step\nAnswer:",
                },
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# Predictions are post-processed with math_intern_postprocess before scoring.
math_eval_cfg = {
    'evaluator': {'type': MATHInternEvaluator},
    'pred_postprocessor': {'type': math_intern_postprocess},
}

math_datasets = [
    {
        'type': MATHInternDataset,
        'abbr': 'math',
        'path': './data/math/math.json',
        'reader_cfg': math_reader_cfg,
        'infer_cfg': math_infer_cfg,
        'eval_cfg': math_eval_cfg,
    },
]
|
@ -61,6 +61,7 @@ from .longbench import * # noqa: F401, F403
|
||||
from .mastermath2024v1 import * # noqa: F401, F403
|
||||
from .math import * # noqa: F401, F403
|
||||
from .math401 import * # noqa: F401, F403
|
||||
from .math_intern import * # noqa: F401, F403
|
||||
from .mathbench import * # noqa: F401, F403
|
||||
from .mbpp import * # noqa: F401, F403
|
||||
from .medbench import * # noqa: F401, F403
|
||||
|
342
opencompass/datasets/math_intern.py
Normal file
342
opencompass/datasets/math_intern.py
Normal file
@ -0,0 +1,342 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
||||
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
||||
TEXT_POSTPROCESSORS)
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
def last_boxed_only_string(string):
    r"""Return the last ``\boxed...`` (or, failing that, ``\fbox...``) span.

    The span runs from the last occurrence of the command up to the brace
    that closes the first brace group opened after it.

    Args:
        string: Text to search.

    Returns:
        The substring including the command and its braces, or ``None`` when
        neither command occurs or the brace group is never closed.
    """
    start = string.rfind('\\boxed')
    if start < 0:
        start = string.rfind('\\fbox')
    if start < 0:
        return None

    depth = 0
    for offset, ch in enumerate(string[start:]):
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # Balanced again: this brace closes the boxed group.
                return string[start:start + offset + 1]
    return None
|
||||
|
||||
|
||||
def remove_boxed(s):
    r"""Strip the ``\boxed{`` prefix and the closing ``}`` from ``s``.

    Args:
        s: Candidate string, expected to look like ``\boxed{...}``.

    Returns:
        The text between the braces, or ``None`` when ``s`` is not a string
        of that form.
    """
    left = '\\boxed{'
    # Explicit checks rather than ``assert``: asserts are stripped when
    # Python runs with ``-O``, which would silently disable the validation.
    if not isinstance(s, str):
        return None
    if not s.startswith(left) or not s.endswith('}'):
        return None
    return s[len(left):-1]
|
||||
|
||||
|
||||
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
    r"""Extract the content of the last ``\boxed{...}`` in ``pred_str``.

    Args:
        pred_str: Text that may contain a boxed answer.
        strip_double_curly_brace: When true and the extracted answer is
            itself wrapped in one pair of braces (``{...}``), that outer
            pair is removed as well.

    Returns:
        The extracted answer, or ``None`` when no boxed span is found or it
        cannot be unwrapped.
    """
    boxed_str = last_boxed_only_string(pred_str)
    if boxed_str is None:
        return None
    answer = remove_boxed(boxed_str)
    if answer is None:
        return None
    if strip_double_curly_brace:
        # Raw string: the non-raw form relied on the invalid escape ``\{``.
        match = re.match(r'^\{(.*)\}$', answer)
        if match:
            answer = match.group(1)
    return answer
|
||||
|
||||
|
||||
@LOAD_DATASET.register_module()
class MATHInternDataset(BaseDataset):
    """MATH dataset whose references are the boxed answers of the solutions."""

    @staticmethod
    def load(path: str):
        """Load the JSON file at ``path`` into a ``DatasetDict``.

        The file is read as a mapping whose values each carry a ``problem``
        and a ``solution`` entry. The stored ``solution`` is replaced by the
        result of ``extract_boxed_answer`` on it.

        Args:
            path: Path to the JSON file.

        Returns:
            A ``DatasetDict`` whose ``test`` and ``train`` splits are both
            built from the same records.
        """
        # Context manager so the file handle is closed deterministically;
        # JSON is read as UTF-8 regardless of the platform default encoding.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        raw_data = [{
            'problem': item['problem'],
            'solution': extract_boxed_answer(item['solution']),
        } for item in data.values()]

        dataset = DatasetDict()
        dataset['test'] = Dataset.from_list(raw_data)
        dataset['train'] = Dataset.from_list(raw_data)
        return dataset
|
||||
|
||||
|
||||
@ICL_EVALUATORS.register_module()
class MATHInternEvaluator(BaseEvaluator):
    """Accuracy evaluator that compares answers with ``is_equiv``."""

    def score(self, predictions, references):
        """Score ``predictions`` against ``references`` pairwise.

        Args:
            predictions: Sequence of predicted answers.
            references: Sequence of reference answers, same length.

        Returns:
            ``{'error': ...}`` when the lengths differ; otherwise a dict with
            ``accuracy`` (percentage of pairs judged equivalent) and
            ``details`` (one ``pred``/``answer``/``correct`` record per
            pair). With no samples the accuracy is ``0.0``.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        details = []
        correct = 0
        for pred, ref in zip(predictions, references):
            matched = bool(is_equiv(pred, ref))
            if matched:
                correct += 1
            details.append({'pred': pred, 'answer': ref, 'correct': matched})

        count = len(details)
        # Guard the empty case: dividing by zero used to raise here.
        accuracy = 100 * correct / count if count else 0.0
        return {'accuracy': accuracy, 'details': details}
|
||||
|
||||
|
||||
@TEXT_POSTPROCESSORS.register_module('math_intern_postprocess')
def math_intern_postprocess(text: str) -> str:
    """Extract the final answer from a model response.

    Delegates to ``Extractor.extract_answer`` with its default options and
    returns whatever that call returns.
    """
    return Extractor().extract_answer(text)
|
||||
|
||||
|
||||
class Extractor:
    """Heuristics for pulling the final answer out of a model response."""

    @classmethod
    def extract_matching_bracket(cls, target_str: str):
        """Return the text before the brace that closes an open ``{``.

        ``target_str`` is taken to start just after an opening brace, so the
        nesting level starts at one.

        Args:
            target_str: Text following an opening ``{``.

        Returns:
            The prefix up to (not including) the matching ``}``. If no
            matching brace exists, or the input is empty, the input is
            returned unchanged.
        """
        if not target_str:
            return target_str
        depth = 1
        for i, ch in enumerate(target_str):
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return target_str[:i]
        # Unbalanced input: keep everything rather than dropping the last
        # character, which the previous slice-after-loop did.
        return target_str

    @classmethod
    def clean(cls, target_str: str):
        """Trim an extracted answer.

        Strips surrounding whitespace, collapses ``{{``/``}}`` to single
        braces and removes one trailing ``.`` or ``。``.
        """
        # NOTE(review): collapsing '}}' also affects legitimately nested
        # LaTeX such as '\\frac{1}{\\sqrt{2}}' - confirm this is intended.
        opt = target_str.strip().replace('{{', '{').replace('}}', '}')
        if not opt:
            return opt
        if opt[-1] in ('.', '。'):
            return opt[:-1]
        return opt

    @classmethod
    def extract_answer(cls, pred: str, extract_last_num=False):
        """Extract the answer from ``pred`` using a list of known markers.

        Markers are tried in this order: ``The final answer is``, then (only
        if the text contains a follow-up ``Question:`` block) ``The answer
        is`` within the text before that block, ``# Answer``,
        ``The answer is:``, ``####`` and finally ``\\boxed{``.

        Args:
            pred: Raw model response.
            extract_last_num: When no marker matches, fall back to the last
                run of digits/spaces/commas/periods that contains a digit.

        Returns:
            The cleaned answer string, or ``None`` if nothing was found.
        """
        final_marker = 'The final answer is '
        start = pred.find(final_marker)
        if start >= 0:
            answer = pred[start + len(final_marker):]
            # The expected shape is "$<answer>$. ..."; strip that wrapper
            # only when it is actually present instead of unconditionally
            # dropping the first and last characters.
            if answer.startswith('$'):
                answer = answer[1:]
            end = answer.find('$.')
            if end >= 0:
                answer = answer[:end]
            return cls.clean(answer)

        if '\n\nQuestion:' in pred:
            # Ignore any follow-up question the model started to generate.
            pred = pred.split('\n\nQuestion:')[0]
            answer_marker = 'The answer is'
            idx = pred.find(answer_marker)
            # ``str.find`` returns -1 when absent; the old truthiness test
            # treated "absent" as found and "found at index 0" as absent.
            if idx >= 0:
                return cls.clean(pred[idx + len(answer_marker):])

        for marker in ('# Answer', 'The answer is:', '####'):
            idx = pred.find(marker)
            if idx >= 0:
                return cls.clean(pred[idx + len(marker):])

        left = '\\boxed{'
        idx = pred.find(left)
        if idx >= 0:
            return cls.clean(
                cls.extract_matching_bracket(pred[idx + len(left):]))

        if extract_last_num:

            def contain_digit(text):
                return any(ch.isdigit() for ch in text)

            nums = []
            opt = ''
            for ch in pred:
                if ch.isdigit() or ch in ' ,.':
                    opt += ch
                else:
                    if contain_digit(opt):
                        nums.append(opt)
                    opt = ''
            # A number that runs to the end of the text wins over earlier
            # candidates.
            if contain_digit(opt):
                return cls.clean(opt)
            if nums:
                return cls.clean(nums[-1])
        return None
|
||||
|
||||
|
||||
def fix_fracs(string):
    r"""Add the braces that shorthand ``\frac`` arguments omit.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{2}`` becomes
    ``\frac{1}{2}``; occurrences already followed by ``{`` are kept as-is.

    Args:
        string: Text possibly containing ``\frac`` commands.

    Returns:
        The rewritten text. If any ``\frac`` is followed by fewer than two
        characters (and no ``{``), the input is returned unchanged.
    """
    pieces = string.split('\\frac')
    rebuilt = [pieces[0]]
    for tail in pieces[1:]:
        rebuilt.append('\\frac')
        if tail.startswith('{'):
            rebuilt.append(tail)
            continue
        # Explicit length check: covers a trailing "\frac" (empty tail, which
        # used to raise IndexError) and does not depend on ``assert``.
        if len(tail) < 2:
            return string
        numer, denom, rest = tail[0], tail[1], tail[2:]
        if denom != '{':
            rebuilt.append('{' + numer + '}{' + denom + '}' + rest)
        else:
            # The denominator is already braced; only wrap the numerator.
            rebuilt.append('{' + numer + '}' + denom + rest)
    return ''.join(rebuilt)
|
||||
|
||||
|
||||
def fix_a_slash_b(string):
    r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

    Args:
        string: Candidate text.

    Returns:
        ``\frac{a}{b}`` when ``string`` is exactly two canonical integers
        separated by one ``/``; otherwise ``string`` unchanged.
    """
    parts = string.split('/')
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        # Non-integer operands (e.g. "x/y") are left alone; previously the
        # ValueError escaped because only AssertionError was caught.
        return string
    # Reject forms that int() accepts but that are not canonical, such as
    # " 1/2" or "01/2".
    if string != '{}/{}'.format(a, b):
        return string
    return '\\frac{' + str(a) + '}{' + str(b) + '}'
|
||||
|
||||
|
||||
def remove_right_units(string):
    r"""Drop a trailing unit written as ``\text{ ...}``.

    Args:
        string: Text possibly containing ``\text{ `` (with the space).

    Returns:
        Everything before the first ``\text{ ``, or ``string`` unchanged when
        the marker is absent.
    """
    # partition() yields (string, '', '') when the marker is missing. Unlike
    # the former ``assert len(splits) == 2`` it neither raises on a second
    # marker nor behaves differently under ``python -O``.
    head, _sep, _tail = string.partition('\\text{ ')
    return head
|
||||
|
||||
|
||||
def fix_sqrt(string):
    r"""Brace the single-character argument of shorthand ``\sqrt``.

    ``\sqrt3`` becomes ``\sqrt{3}``; occurrences already followed by ``{``
    are kept as-is.

    Args:
        string: Text possibly containing ``\sqrt`` commands.

    Returns:
        The rewritten text.
    """
    if '\\sqrt' not in string:
        return string
    head, *tails = string.split('\\sqrt')
    out = [head]
    for tail in tails:
        # An empty tail means "\sqrt" ended the string (or was directly
        # followed by another "\sqrt"); indexing it used to raise IndexError.
        if tail and tail[0] != '{':
            out.append('\\sqrt{' + tail[0] + '}' + tail[1:])
        else:
            out.append('\\sqrt' + tail)
    return ''.join(out)
|
||||
|
||||
|
||||
def strip_string(string):
    r"""Normalise a LaTeX answer string so equivalent answers compare equal.

    The steps run in a fixed order; later steps rely on earlier ones (for
    example, all spaces are gone before fractions are fixed).

    Args:
        string: Answer text.

    Returns:
        The normalised text.
    """
    # Literal removals/rewrites, applied in this exact order:
    # line breaks, "\!" (negative space), "\\" -> "\", tfrac/dfrac -> frac,
    # \left / \right, degree markers, escaped dollar signs.
    for old, new in (
        ('\n', ''),
        ('\\!', ''),
        ('\\\\', '\\'),
        ('tfrac', 'frac'),
        ('dfrac', 'frac'),
        ('\\left', ''),
        ('\\right', ''),
        ('^{\\circ}', ''),
        ('^\\circ', ''),
        ('\\$', ''),
    ):
        string = string.replace(old, new)

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage (the former second replace of '\%' was the same
    # literal spelled with an invalid escape sequence, hence a no-op)
    string = string.replace('\\%', '')

    # give bare decimals a leading zero
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')
    # if empty, return empty string
    if not string:
        return string
    if string[0] == '.':
        string = '0' + string

    # get rid of e.g. "k = " or "q = " at the beginning
    sides = string.split('=')
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(' ', '')

    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == '0.5':
        string = '\\frac{1}{2}'

    string = fix_a_slash_b(string)

    # Drop a leading "x \in". Spaces are already removed at this point, so
    # the pattern must not contain one (the old 'x \\in' literal could never
    # match). The lookahead keeps longer commands such as "\infty" intact.
    string = re.sub(r'x\\in(?![A-Za-z])', '', string).strip()

    # a_b == a, a_{b} == a_b for bit conversion
    if '_' in string:
        p = string.split('_')
        p[1] = p[1].replace('{', '').replace('}', '')
        string = '_'.join(p)

    # 10800 == 10,800; we only deal with single number
    # (no spaces can remain here, so only tuple-like answers are excluded)
    if '(' not in string:
        string = string.replace(',', '')

    return string
|
||||
|
||||
|
||||
def is_equiv(str1, str2, verbose=False):
    """Tell whether two answers are equal after normalisation.

    Args:
        str1: First answer, or ``None``.
        str2: Second answer, or ``None``.
        verbose: Accepted for compatibility; not used.

    Returns:
        ``False`` if either argument is ``None`` (including both). Otherwise
        the comparison of the ``strip_string``-normalised values, falling
        back to comparing the raw values if normalisation raises.
    """
    if str1 is None or str2 is None:
        return False

    try:
        return strip_string(str1) == strip_string(str2)
    except Exception:
        # Best effort: normalisation failed, compare the inputs verbatim.
        return str1 == str2
|
@ -71,6 +71,7 @@ def first_option_postprocess(text: str, options: str, cushion=True) -> str:
|
||||
f'答案为\s?([{options}])',
|
||||
f'答案选\s?([{options}])',
|
||||
f'选择?\s?([{options}])',
|
||||
f'故选?\s?([{options}])'
|
||||
f'只有选?项?\s?([{options}])\s?是?对',
|
||||
f'只有选?项?\s?([{options}])\s?是?错',
|
||||
f'只有选?项?\s?([{options}])\s?不?正确',
|
||||
|
Loading…
Reference in New Issue
Block a user