mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
377 lines
13 KiB
Python
377 lines
13 KiB
Python
import json
|
|
|
|
from datasets import Dataset, DatasetDict
|
|
|
|
from opencompass.openicl.icl_evaluator import BaseEvaluator
|
|
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
|
|
TEXT_POSTPROCESSORS)
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):
    """Loader for a MATH-style JSON file of problems and worked solutions."""

    @staticmethod
    def load(path: str):
        r"""Read the JSON file at ``path`` and build a ``DatasetDict``.

        The file is read as a JSON object whose values each provide a
        ``'problem'`` and a ``'solution'`` entry.  The reference answer kept
        for every item is the content of the last ``\boxed{...}`` found in
        its solution text, or ``None`` when no such box can be extracted.

        Args:
            path (str): Location of the JSON file.

        Returns:
            DatasetDict: ``'test'`` and ``'train'`` splits, both built from
            the same list of ``{'problem', 'solution'}`` records.
        """

        def remove_boxed(s):
            r"""Return the text inside ``\boxed{...}``, or None if ``s`` is
            missing or not wrapped exactly that way."""
            left = '\\boxed{'
            if not isinstance(s, str):
                # last_boxed_only_string() yields None when nothing is found.
                return None
            if s.startswith(left) and s.endswith('}'):
                return s[len(left):-1]
            return None

        def last_boxed_only_string(string):
            r"""Return the last ``\boxed``/``\fbox`` expression including its
            balanced braces, or None if absent or unbalanced."""
            idx = string.rfind('\\boxed')
            if idx < 0:
                idx = string.rfind('\\fbox')
                if idx < 0:
                    return None

            # Walk forward from the command and stop at the brace that closes
            # the first group that was opened.
            depth = 0
            for pos in range(idx, len(string)):
                char = string[pos]
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return string[idx:pos + 1]
            return None

        # Use a context manager so the file handle is closed deterministically
        # (the previous json.load(open(path)) left it to the GC).
        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        raw_data = [{
            'problem':
            item['problem'],
            'solution':
            remove_boxed(last_boxed_only_string(item['solution'])),
        } for item in data.values()]

        dataset = DatasetDict()
        dataset['test'] = Dataset.from_list(raw_data)
        dataset['train'] = Dataset.from_list(raw_data)
        return dataset
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('math_postprocess')
|
|
def math_postprocess(text: str) -> str:
    """Extract and normalize the final answer from a model completion.

    The completion is cut into sentences at periods; the first sentence that
    mentions "final answer" (case-insensitive) is normalized and returned.
    If no sentence mentions it, the first sentence is normalized instead.

    A period that sits between two digits is treated as a decimal point and
    is NOT a sentence boundary, so an answer such as ``3.5`` is kept whole
    instead of being truncated to ``3``.

    Args:
        text (str): Raw model output.

    Returns:
        str: The normalized answer string.
    """
    import re

    SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
                     (r'\ ', ''), (' ', ''), ('mbox', 'text'),
                     (',\\text{and}', ','), ('\\text{and}', ','),
                     ('\\text{m}', '\\text{}'), ('\\le', '<')]
    REMOVED_EXPRESSIONS = [
        'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
        'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 'minutes',
        'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals',
        'edges', 'students', 'childrentickets', 'multiples', '\\text{s}',
        '\\text{.}', '\\text{\ns}', '\\text{}^2', '\\text{}^3', '\\text{\n}',
        '\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!',
        '{,}', '"', '\\dots', '\n', '\r', '\f'
    ]

    def normalize_final_answer(final_answer: str) -> str:
        """Normalize a final answer to a quantitative reasoning question."""
        for before, after in SUBSTITUTIONS:
            final_answer = final_answer.replace(before, after)
        # The last entries of REMOVED_EXPRESSIONS strip '\n', '\r' and '\f',
        # so no line-break characters remain after this loop.
        for expr in REMOVED_EXPRESSIONS:
            final_answer = final_answer.replace(expr, '')

        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
        final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
        final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
        final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)

        # Each pattern below keeps only its LAST match, when there is one.
        # Spaces were removed above, hence the run-together 'finalansweris'.
        for pattern in (r'finalansweris(.*)', r'oxed\{(.*?)\}',
                        r'\$(.*?)\$'):
            matches = re.findall(pattern, final_answer)
            if matches:
                final_answer = matches[-1]
        final_answer = final_answer.strip()

        # A '\frac' written in a non-raw string loses its '\f' (form feed,
        # removed above) and arrives here as a bare 'rac'; restore it.
        if 'rac' in final_answer and '\\frac' not in final_answer:
            final_answer = final_answer.replace('rac', '\\frac')

        # Normalize shorthand TeX:
        # \fracab -> \frac{a}{b}
        # \frac{abc}{bef} -> \frac{abc}{bef}
        # \fracabc -> \frac{a}{b}c
        # \sqrta -> \sqrt{a}
        # \sqrtab -> sqrt{a}b
        final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}',
                              final_answer)
        final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
        final_answer = final_answer.replace('$', '')

        # Normalize 100,000 -> 100000
        without_commas = final_answer.replace(',', '')
        if without_commas.isdigit():
            final_answer = without_commas

        return final_answer

    # Split on every period except one that has a digit on BOTH sides
    # (a decimal point), so '3.5' survives as a single token.
    sentences = re.split(r'(?<!\d)\.|\.(?!\d)', text)
    for maybe_ans in sentences:
        if 'final answer' in maybe_ans.lower():
            return normalize_final_answer(maybe_ans)
    return normalize_final_answer(sentences[0])
|
|
|
|
|
|
@ICL_EVALUATORS.register_module()
class MATHEvaluator(BaseEvaluator):
    """Score MATH predictions by string equality after LaTeX normalization."""

    def score(self, predictions, references):
        """Compute accuracy of ``predictions`` against ``references``.

        Args:
            predictions: Sequence of predicted answer strings (or None).
            references: Sequence of reference answer strings (or None).

        Returns:
            dict: ``{'accuracy': <percentage>, 'details': [...]}`` where each
            detail holds ``pred``, ``answer`` and ``correct``; or
            ``{'error': ...}`` when the two inputs differ in length.
            Empty input yields an accuracy of 0.0 instead of dividing by
            zero.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        correct = 0
        details = []
        for pred, ref in zip(predictions, references):
            is_correct = bool(self.is_equiv(pred, ref))
            correct += is_correct
            details.append({
                'pred': pred,
                'answer': ref,
                'correct': is_correct
            })
        accuracy = 100 * correct / len(details) if details else 0.0
        return {'accuracy': accuracy, 'details': details}

    def _fix_fracs(self, string):
        r"""Add the braces omitted in shorthand fractions.

        ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{72}`` becomes
        ``\frac{1}{72}``.  If a ``\frac`` is followed by fewer than two
        characters (and no brace), the input is returned unchanged.
        """
        head, *tails = string.split('\\frac')
        new_str = head
        for substr in tails:
            new_str += '\\frac'
            if substr.startswith('{'):
                # Numerator is already braced; keep the rest verbatim.
                new_str += substr
                continue
            if len(substr) < 2:
                # Not enough characters for a numerator and a denominator.
                return string
            numer, denom, rest = substr[0], substr[1], substr[2:]
            if denom != '{':
                new_str += '{' + numer + '}{' + denom + '}' + rest
            else:
                # Denominator is braced already: only wrap the numerator.
                new_str += '{' + numer + '}' + denom + rest
        return new_str

    def _fix_a_slash_b(self, string):
        r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

        Anything that is not exactly two canonical integers around a single
        slash is returned unchanged.
        """
        parts = string.split('/')
        if len(parts) != 2:
            return string
        try:
            a = int(parts[0])
            b = int(parts[1])
        except ValueError:
            # Non-integer operands such as 'x/2': leave them alone rather
            # than aborting the whole normalization.
            return string
        # Reject non-canonical spellings such as '01/2' or ' 1/2'.
        if string != '{}/{}'.format(a, b):
            return string
        return '\\frac{' + str(a) + '}{' + str(b) + '}'

    def _remove_right_units(self, string):
        r"""Drop a trailing ``\text{ ...}`` unit annotation.

        Raises:
            AssertionError: If ``\text{ `` occurs more than once.
        """
        # "\\text{ " only ever occurs (at least in the val set) when describing
        # units
        if '\\text{ ' not in string:
            return string
        splits = string.split('\\text{ ')
        if len(splits) != 2:
            # Raised explicitly (rather than via `assert`) so that the check
            # also holds when Python runs with -O.
            raise AssertionError('expected a single "\\text{ " unit marker')
        return splits[0]

    def _fix_sqrt(self, string):
        r"""Brace the single character following a bare ``\sqrt``.

        ``\sqrt3`` becomes ``\sqrt{3}``; an already braced argument, or a
        ``\sqrt`` with nothing after it, is left as is.
        """
        if '\\sqrt' not in string:
            return string
        head, *tails = string.split('\\sqrt')
        pieces = [head]
        for tail in tails:
            if tail and tail[0] != '{':
                pieces.append('\\sqrt{' + tail[0] + '}' + tail[1:])
            else:
                pieces.append('\\sqrt' + tail)
        return ''.join(pieces)

    def _strip_string(self, string):
        """Normalize an answer string so that equivalent forms compare equal.

        The steps run in a fixed order; later steps rely on earlier ones
        (e.g. spaces are removed only after unit stripping and sqrt fixing).
        """
        # linebreaks
        string = string.replace('\n', '')

        # remove inverse spaces
        string = string.replace('\\!', '')

        # replace \\ with \
        string = string.replace('\\\\', '\\')

        # replace tfrac and dfrac with frac
        string = string.replace('tfrac', 'frac')
        string = string.replace('dfrac', 'frac')

        # remove \left and \right
        string = string.replace('\\left', '')
        string = string.replace('\\right', '')

        # Remove circ (degrees)
        string = string.replace('^{\\circ}', '')
        string = string.replace('^\\circ', '')

        # remove dollar signs
        string = string.replace('\\$', '')

        # remove units (on the right)
        string = self._remove_right_units(string)

        # remove percentage ('\\%' is the two characters backslash + '%';
        # the former extra '\%' literal was the same string spelled with an
        # invalid escape sequence)
        string = string.replace('\\%', '')

        # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
        # add "0" if "." is the start of the string
        string = string.replace(' .', ' 0.')
        string = string.replace('{.', '{0.')
        # if empty, return empty string
        if not string:
            return string
        if string[0] == '.':
            string = '0' + string

        # to consider: get rid of e.g. "k = " or "q = " at beginning
        sides = string.split('=')
        if len(sides) == 2 and len(sides[0]) <= 2:
            string = sides[1]

        # fix sqrt3 --> sqrt{3}
        string = self._fix_sqrt(string)

        # remove spaces
        string = string.replace(' ', '')

        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
        # with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
        string = self._fix_fracs(string)

        # manually change 0.5 --> \frac{1}{2}
        if string == '0.5':
            string = '\\frac{1}{2}'

        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
        # in case the model output is X/Y
        string = self._fix_a_slash_b(string)

        return string

    def is_equiv(self, str1, str2, verbose=False):
        """Return whether two answers are equal after normalization.

        Two ``None`` values count as equal (with a printed warning); a single
        ``None`` never matches.  If normalization itself fails, the raw
        inputs are compared instead.
        """
        if str1 is None and str2 is None:
            print('WARNING: Both None')
            return True
        if str1 is None or str2 is None:
            return False

        try:
            ss1 = self._strip_string(str1)
            ss2 = self._strip_string(str2)
        except Exception:
            # Best effort: fall back to comparing the untouched inputs, but
            # let KeyboardInterrupt / SystemExit propagate.
            return str1 == str2
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
|
|
|
|
|
|
class MATHAgentEvaluator(MATHEvaluator):
    """math agent evaluator for soft condition.

    Args:
        action (str): Action for catching internal prediction.
            Defaults to `PythonInterpreter`.
    """

    def __init__(self, action: str = 'PythonInterpreter'):
        # NOTE(review): assumes the inherited initializer needs no required
        # arguments (MATHEvaluator defines none of its own) -- confirm
        # against BaseEvaluator.
        super().__init__()
        self.action = action

    def soft_equal(self, pred, refer, step):
        """Return True if the action step's result text matches ``refer``.

        Args:
            pred: Final prediction; only printed for diagnostics.
            refer: Reference answer.
            step: Action step; its ``['result']['text']`` is compared.

        Returns:
            bool: True on a match; False otherwise, including when the step
            carries no usable result.
        """
        # Bound up front so the diagnostic print below cannot itself fail
        # when the lookup is what raised.
        soft_pred = None
        try:
            soft_pred = step['result']['text']
            if self.is_equiv(soft_pred, refer):
                return True
        except Exception:
            # result might not exists
            print(pred, soft_pred, refer)
        return False

    def get_action(self, step):
        """Return the last entry of ``step`` whose ``'type'`` equals
        ``self.action``, or None when there is none."""
        for s in step[::-1]:
            if s['type'] == self.action:
                return s
        return None

    def score(self, predictions, references, steps):
        """Calculate accuracy.

        Every sample falls into one of these counters:

        * final answer correct, action used        -> ``final_scope``
        * final answer correct, no action          -> ``row_reasoning_scope``
        * final answer wrong, action used          -> ``action_scope``
          * ...and the action has a falsy errmsg   -> ``code_scope``
          * ...and its result text matches         -> ``reasoning_scope``

        Returns:
            dict: ``follow_acc``, ``reasoning_acc``, ``code_acc`` and
            ``action_pct`` as percentages, or ``{'error': ...}`` when
            predictions and references differ in length.  A ratio whose
            denominator is zero (no samples, or no sample used the action)
            is reported as 0.0.
        """
        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}

        def percent(numerator, denominator):
            # Guard against empty input / no action usage.
            return 100 * numerator / denominator if denominator else 0.0

        row_reasoning_scope = 0
        action_scope = 0
        code_scope = 0
        reasoning_scope = 0
        final_scope = 0
        total = len(references)
        for pred, refer, step in zip(predictions, references, steps):
            action_step = self.get_action(step)
            # if final answer right
            if self.is_equiv(pred, refer):
                if action_step:
                    final_scope += 1
                else:
                    row_reasoning_scope += 1
            elif action_step:
                action_scope += 1
                if not action_step['errmsg']:
                    code_scope += 1
                    # whether action result is correct
                    reasoning_scope += self.soft_equal(
                        pred, refer, action_step)

        result = dict(
            follow_acc=percent(row_reasoning_scope + final_scope, total),
            reasoning_acc=percent(
                reasoning_scope + final_scope + row_reasoning_scope, total),
            code_acc=percent(code_scope + final_scope,
                             action_scope + final_scope),
            action_pct=percent(action_scope + final_scope, total),
        )
        return result
|