import json import re from datasets import Dataset, DatasetDict from opencompass.openicl.icl_evaluator import BaseEvaluator from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS) from .base import BaseDataset def last_boxed_only_string(string): idx = string.rfind('\\boxed') if idx < 0: idx = string.rfind('\\fbox') if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == '{': num_left_braces_open += 1 if string[i] == '}': num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx is None: retval = None else: retval = string[idx:right_brace_idx + 1] return retval def remove_boxed(s): left = '\\boxed{' try: assert s[:len(left)] == left assert s[-1] == '}' return s[len(left):-1] except Exception: return None def extract_boxed_answer(pred_str, strip_double_curly_brace=False): boxed_str = last_boxed_only_string(pred_str) if boxed_str is None: return None answer = remove_boxed(boxed_str) if answer is None: return None if strip_double_curly_brace: match = re.match('^\{(.*)\}$', answer) # noqa: W605 if match: answer = match.group(1) return answer @LOAD_DATASET.register_module() class MATHInternDataset(BaseDataset): @staticmethod def load(path: str): dataset = DatasetDict() data = json.load(open(path)) raw_data = [] for i in data.keys(): raw_data.append({ 'problem': data[i]['problem'], 'solution': extract_boxed_answer(data[i]['solution']) }) dataset['test'] = Dataset.from_list(raw_data) dataset['train'] = Dataset.from_list(raw_data) return dataset @ICL_EVALUATORS.register_module() class MATHInternEvaluator(BaseEvaluator): def score(self, predictions, references): if len(predictions) != len(references): return { 'error': 'predictions and references have different ' 'length' } correct = 0 count = 0 details = [] for i, j in zip(predictions, references): detail = {'pred': i, 'answer': j, 'correct': False} count += 1 if is_equiv(i, j): correct += 1 detail['correct'] = True details.append(detail) result = {'accuracy': 100 * correct / count, 'details': details} return result @TEXT_POSTPROCESSORS.register_module('math_intern_postprocess') def math_intern_postprocess(text: str) -> str: extractor = Extractor() return extractor.extract_answer(text) class Extractor: def extract_matching_bracket(cls, target_str: str): if not target_str: return target_str current_nest_level = 1 for i, ch in enumerate(target_str): if ch == '{': current_nest_level += 1 elif ch == '}': current_nest_level -= 1 if current_nest_level == 0: break return target_str[:i] def clean(cls, target_str: str): opt = target_str.strip().replace('{{', '{').replace('}}', '}') if not opt: return opt if opt[-1] == '.' or opt[-1] == '。': return opt[:-1] return opt def extract_answer(cls, pred: str, extract_last_num=False): if pred.find('The final answer is ') >= 0: x = pred[pred.find('The final answer is ') + len('The final answer is '):] x = x[1:x.find('$.')] # print(x) return cls.clean(x) if pred.find('\n\nQuestion:') >= 0: pred = pred.split('\n\nQuestion:')[0] if pred.find('The answer is'): pred = pred[pred.find('The answer is') + len('The answer is'):] return cls.clean(pred) if pred.find('# Answer') >= 0: return cls.clean(pred[pred.find('# Answer') + len('# Answer'):]) if pred.find('The answer is:') >= 0: return cls.clean(pred[pred.find('The answer is:') + len('The answer is:'):]) if pred.find('####') >= 0: return cls.clean(pred[pred.find('####') + 4:]) left = '\\boxed{' if pred.find(left) >= 0: pred = pred[pred.find(left) + len(left):] return cls.clean(cls.extract_matching_bracket(pred)) if extract_last_num: nums = [] opt = '' def contain_digit(opt): for ch in opt: if ch.isdigit(): return True return False for ch in pred: if ch.isdigit() or ch in ' ,.': opt = opt + ch else: if contain_digit(opt): nums.append(opt) opt = '' if contain_digit(opt): return cls.clean(opt) if nums: return cls.clean(nums[-1]) return None def fix_fracs(string): substrs = string.split('\\frac') new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += '\\frac' if substr[0] == '{': new_str += substr else: try: assert len(substr) >= 2 except AssertionError: return string a = substr[0] b = substr[1] if b != '{': if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}{' + b + '}' + post_substr else: new_str += '{' + a + '}{' + b + '}' else: if len(substr) > 2: post_substr = substr[2:] new_str += '{' + a + '}' + b + post_substr else: new_str += '{' + a + '}' + b string = new_str return string def fix_a_slash_b(string): if len(string.split('/')) != 2: return string a = string.split('/')[0] b = string.split('/')[1] try: a = int(a) b = int(b) assert string == '{}/{}'.format(a, b) new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' return new_string except AssertionError: return string def remove_right_units(string): # "\\text{ " only ever occurs (at least in the val set) if '\\text{ ' in string: splits = string.split('\\text{ ') assert len(splits) == 2 return splits[0] else: return string def fix_sqrt(string): if '\\sqrt' not in string: return string splits = string.split('\\sqrt') new_string = splits[0] for split in splits[1:]: if split[0] != '{': a = split[0] new_substr = '\\sqrt{' + a + '}' + split[1:] else: new_substr = '\\sqrt' + split new_string += new_substr return new_string def strip_string(string): # linebreaks string = string.replace('\n', '') # remove inverse spaces string = string.replace('\\!', '') # replace \\ with \ string = string.replace('\\\\', '\\') # replace tfrac and dfrac with frac string = string.replace('tfrac', 'frac') string = string.replace('dfrac', 'frac') # remove \left and \right string = string.replace('\\left', '') string = string.replace('\\right', '') # Remove circ (degrees) string = string.replace('^{\\circ}', '') string = string.replace('^\\circ', '') # remove dollar signs string = string.replace('\\$', '') # remove units (on the right) string = remove_right_units(string) # remove percentage string = string.replace('\\%', '') string = string.replace('\%', '') # noqa: W605 string = string.replace(' .', ' 0.') string = string.replace('{.', '{0.') # if empty, return empty string if len(string) == 0: return string if string[0] == '.': string = '0' + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split('=')) == 2: if len(string.split('=')[0]) <= 2: string = string.split('=')[1] # fix sqrt3 --> sqrt{3} string = fix_sqrt(string) # remove spaces string = string.replace(' ', '') string = fix_fracs(string) # manually change 0.5 --> \frac{1}{2} if string == '0.5': string = '\\frac{1}{2}' string = fix_a_slash_b(string) string = string.replace('x \\in', '').strip() # noqa: W605 # a_b == a, a_{b} == a_b for bit conversion if string.find('_') >= 0: p = string.split('_') p[1] = p[1].replace('{', '').replace('}', '') string = '_'.join(p) # 10800 == 10,800; we only deal with single number if string.strip().find(' ') == -1 and string.find('(') == -1: string = string.replace(',', '') return string def is_equiv(str1, str2, verbose=False): if str1 is None and str2 is None: # print("WARNING: Both None") return False if str1 is None or str2 is None: return False try: ss1 = strip_string(str1) ss2 = strip_string(str2) return ss1 == ss2 except Exception: return str1 == str2