import copy import json import os.path as osp import re from datasets import Dataset from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS from .base import BaseDataset def get_number(options): result_string = '' for i, option in enumerate(options, start=ord('A')): result_string += f'{chr(i)}. {option}\n' return result_string def get_circular_example(entry, id): """For given example, generate four circular examples.""" # Only 4 options is supported for current circular eval. circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC'] data = [] for c in circular_patterns: line = copy.deepcopy(entry) options = [] for i in range(4): options.append(line['options'][ord(c[i]) - ord('A')]) line['options'] = options line['answer'] = { c[0]: 'A', c[1]: 'B', c[2]: 'C', c[3]: 'D' }[line['answer']] line['answer'] = str(id) + '--' + line['answer'] + '--' + c line['question'] = line['question'].strip() + '\n' + get_number( line['options']) data.append(line) return data @LOAD_DATASET.register_module() class MathBenchDataset(BaseDataset): @staticmethod def load(path: str, name: str, with_circular: bool = True): """MathBenth Dataset. Args: path (str): Path of the mathbench dataset. name (str): Name of the target subset. with_circular (bool): Whether to create circular dataset for single choice question. Defaults to True. """ data = [] filename = osp.join(path, f'{name}.jsonl') with open(filename, 'r', encoding='utf-8') as infile: for id, line in enumerate(infile): entry = json.loads(line) if 'cloze' in name: data.append({ 'question': entry['question'].strip(), 'answer': entry['answer'].strip() }) else: if with_circular: data.extend(get_circular_example(entry, id)) else: question = entry['question'].strip( ) + '\n' + get_number(entry['options']) info = { 'question': question, 'answer': entry['answer'].strip() } # For PPL evaluation for i in range(4): info[chr(ord('A') + i)] = entry['options'][i].strip() data.append(info) dataset = Dataset.from_list(data) return dataset @TEXT_POSTPROCESSORS.register_module() def mathbench_postprocess(text: str, name: str) -> str: split = False ans = text if '_cn' in name: ans_line = ans.split('答案是') else: ans_line = ans.split('The answer is') if len(ans_line) != 1: ans = ans_line[1].strip() split = True output = re.sub(r'(\d),(\d)', r'\1\2', ans) numbers = re.findall(r'-?\d*\.?/?\d+|\d+', output) if numbers: return numbers[0] if split else numbers[-1] return ans