import json
import re

from datasets import Dataset, DatasetDict

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
                                  TEXT_POSTPROCESSORS)

from .base import BaseDataset

# (before, after) pairs applied in order to a candidate answer string.
_SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
                  (r'\ ', ''), (' ', ''), ('mbox', 'text'),
                  (',\\text{and}', ','), ('\\text{and}', ','),
                  ('\\text{m}', '\\text{}'), ('\\le', '<')]

# Substrings deleted, in order, from a candidate answer string after the
# substitutions above have been applied.
_REMOVED_EXPRESSIONS = [
    'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
    'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 'minutes',
    'digits', 'cents', 'degrees', 'cm', 'gm', 'pounds', 'meters', 'meals',
    'edges', 'students', 'childrentickets', 'multiples', '\\text{s}',
    '\\text{.}', '\\text{\ns}', '\\text{}^2', '\\text{}^3', '\\text{\n}',
    '\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!',
    '{,}', '"', '\\dots', '\n', '\r', '\f'
]


def _remove_boxed(s):
    r"""Return the contents of a string of the form ``\boxed{...}``.

    Returns ``None`` when ``s`` is ``None``, does not start with
    ``\boxed{`` or does not end with ``}``.
    """
    left = '\\boxed{'
    if s is None or not s.startswith(left) or not s.endswith('}'):
        return None
    return s[len(left):-1]


def _last_boxed_only_string(string):
    r"""Return the last ``\boxed`` (or, failing that, ``\fbox``) expression.

    The returned slice runs from the start of the command to the brace that
    closes its first brace group. Returns ``None`` when neither command is
    present or when no balanced closing brace is found.
    """
    idx = string.rfind('\\boxed')
    if idx < 0:
        idx = string.rfind('\\fbox')
        if idx < 0:
            return None

    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return string[idx:pos + 1]
    return None


@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):
    """Loader for a MATH-style JSON file of problems and boxed solutions."""

    @staticmethod
    def load(path: str):
        """Load the JSON file at ``path`` into a ``DatasetDict``.

        The file's top-level object is iterated by key; each value must
        provide ``'problem'`` and ``'solution'`` entries. The stored
        ``'solution'`` is the content of the last boxed expression found in
        the raw solution text, or ``None`` when none can be extracted.

        The same records are exposed under both the ``'test'`` and
        ``'train'`` splits.
        """
        # Use a context manager so the file handle is closed promptly.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)

        raw_data = [{
            'problem':
            item['problem'],
            'solution':
            _remove_boxed(_last_boxed_only_string(item['solution']))
        } for item in data.values()]

        dataset = DatasetDict()
        dataset['test'] = Dataset.from_list(raw_data)
        dataset['train'] = Dataset.from_list(raw_data)
        return dataset


def _normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question."""
    for before, after in _SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    # This pass also strips every '\n', '\r' and '\f', so the string is
    # free of those control characters from here on.
    for expr in _REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, '')

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', final_answer)
    final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
    final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
    final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)

    # Successively narrow down to the last match of each pattern, if any.
    for pattern in (r'finalansweris(.*)', r'oxed\{(.*?)\}', r'\$(.*?)\$'):
        matches = re.findall(pattern, final_answer)
        if matches:
            final_answer = matches[-1]
    final_answer = final_answer.strip()

    # Restore '\frac' when only the 'rac' remnant is present.
    if 'rac' in final_answer and '\\frac' not in final_answer:
        final_answer = final_answer.replace('rac', '\\frac')

    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}', final_answer)
    final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
    final_answer = final_answer.replace('$', '')

    # Normalize 100,000 -> 100000
    without_commas = final_answer.replace(',', '')
    if without_commas.isdigit():
        final_answer = without_commas

    return final_answer


@TEXT_POSTPROCESSORS.register_module('math')
def math_postprocess(text: str) -> str:
    """Extract and normalize the final answer from generated ``text``.

    ``text`` is split on ``'.'``. The first piece containing
    ``'final answer'`` (case-insensitive) is normalized and returned; when
    no piece matches, the first piece is normalized instead.
    """
    pieces = text.split('.')
    for maybe_ans in pieces:
        if 'final answer' in maybe_ans.lower():
            return _normalize_final_answer(maybe_ans)
    return _normalize_final_answer(pieces[0])


@ICL_EVALUATORS.register_module()
class MATHEvaluator(BaseEvaluator):
    """Exact-match accuracy over LaTeX answers after string normalization."""

    def score(self, predictions, references):
        """Return ``{'accuracy': percent}`` of equivalent pairs.

        Returns an ``{'error': ...}`` dict instead when the two sequences
        differ in length.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        total = len(predictions)
        correct = sum(1 for pred, ref in zip(predictions, references)
                      if self.is_equiv(pred, ref))
        return {'accuracy': 100 * correct / total}

    def _fix_fracs(self, string):
        r"""Add braces to shorthand fractions, e.g. ``\frac12``.

        ``\frac1{72}`` is also handled. The input is returned unchanged as
        soon as a ``\frac`` is followed by fewer than two characters and no
        opening brace.
        """
        head, *tails = string.split('\\frac')
        new_str = head
        for substr in tails:
            new_str += '\\frac'
            if substr.startswith('{'):
                new_str += substr
                continue
            # Covers the empty fragment left by a trailing '\frac', which
            # previously raised IndexError on substr[0].
            if len(substr) < 2:
                return string
            a, b = substr[0], substr[1]
            if b != '{':
                new_str += '{' + a + '}{' + b + '}' + substr[2:]
            else:
                new_str += '{' + a + '}' + b + substr[2:]
        return new_str

    def _fix_a_slash_b(self, string):
        r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

        Anything else is returned unchanged: more or fewer than one ``/``,
        non-integer operands, or integers not written in canonical form.
        """
        parts = string.split('/')
        if len(parts) != 2:
            return string
        try:
            a = int(parts[0])
            b = int(parts[1])
        except ValueError:
            # Non-integer operands such as 'x/2' are left untouched rather
            # than aborting the whole normalization.
            return string
        if string != '{}/{}'.format(a, b):
            return string
        return '\\frac{' + str(a) + '}{' + str(b) + '}'

    def _remove_right_units(self, string):
        r"""Drop a trailing ``\text{ ...`` unit annotation.

        Raises:
            ValueError: if ``\text{ `` occurs more than once. ``is_equiv``
                catches this and falls back to comparing the raw strings.
        """
        # "\\text{ " only ever occurs (at least in the val set) when describing
        # units
        if '\\text{ ' not in string:
            return string
        splits = string.split('\\text{ ')
        if len(splits) != 2:
            raise ValueError('expected at most one "\\text{ " in answer')
        return splits[0]

    def _fix_sqrt(self, string):
        r"""Add braces to a single-character ``\sqrt`` argument."""
        if '\\sqrt' not in string:
            return string
        head, *tails = string.split('\\sqrt')
        new_string = head
        for tail in tails:
            # An empty tail (trailing '\sqrt') has nothing to wrap.
            if tail and tail[0] != '{':
                new_string += '\\sqrt{' + tail[0] + '}' + tail[1:]
            else:
                new_string += '\\sqrt' + tail
        return new_string

    def _strip_string(self, string):
        """Normalize a LaTeX answer string for exact comparison.

        The steps are order-dependent and run in the sequence shown below.
        """
        # linebreaks
        string = string.replace('\n', '')

        # remove inverse spaces
        string = string.replace('\\!', '')

        # replace \\ with \
        string = string.replace('\\\\', '\\')

        # replace tfrac and dfrac with frac
        string = string.replace('tfrac', 'frac')
        string = string.replace('dfrac', 'frac')

        # remove \left and \right
        string = string.replace('\\left', '')
        string = string.replace('\\right', '')

        # Remove circ (degrees)
        string = string.replace('^{\\circ}', '')
        string = string.replace('^\\circ', '')

        # remove dollar signs
        string = string.replace('\\$', '')

        # remove units (on the right)
        string = self._remove_right_units(string)

        # remove percentage
        string = string.replace('\\%', '')

        # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
        # add "0" if "." is the start of the string
        string = string.replace(' .', ' 0.')
        string = string.replace('{.', '{0.')
        # if empty, return empty string
        if not string:
            return string
        if string[0] == '.':
            string = '0' + string

        # to consider: get rid of e.g. "k = " or "q = " at beginning
        sides = string.split('=')
        if len(sides) == 2 and len(sides[0]) <= 2:
            string = sides[1]

        # fix sqrt3 --> sqrt{3}
        string = self._fix_sqrt(string)

        # remove spaces
        string = string.replace(' ', '')

        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works
        # with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
        string = self._fix_fracs(string)

        # manually change 0.5 --> \frac{1}{2}
        if string == '0.5':
            string = '\\frac{1}{2}'

        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix
        # in case the model output is X/Y
        string = self._fix_a_slash_b(string)

        return string

    def is_equiv(self, str1, str2, verbose=False):
        """Return whether two answers are equal after normalization.

        Two ``None`` values count as equal (with a printed warning); a
        single ``None`` never matches. If normalization raises, the raw
        inputs are compared instead.
        """
        if str1 is None and str2 is None:
            print('WARNING: Both None')
            return True
        if str1 is None or str2 is None:
            return False

        try:
            ss1 = self._strip_string(str1)
            ss2 = self._strip_string(str2)
            if verbose:
                print(ss1, ss2)
            return ss1 == ss2
        except Exception:
            # Best-effort fallback; KeyboardInterrupt and SystemExit are
            # deliberately no longer swallowed.
            return str1 == str2