diff --git a/configs/datasets/gsm8k/gsm8k_gen_e9e91e.py b/configs/datasets/gsm8k/gsm8k_gen_e9e91e.py index ba9e07f5..f3c0eb12 100644 --- a/configs/datasets/gsm8k/gsm8k_gen_e9e91e.py +++ b/configs/datasets/gsm8k/gsm8k_gen_e9e91e.py @@ -1,3 +1,7 @@ +# GONNA BE DEPRECATED, DON'T USE IT +# The postprocessor has the assumption that the prompt is in the format of "Question:blabla" +# This config does not follow the above assumption, thus deprecated + from opencompass.openicl.icl_prompt_template import PromptTemplate from opencompass.openicl.icl_retriever import ZeroRetriever from opencompass.openicl.icl_inferencer import GenInferencer diff --git a/opencompass/datasets/gsm8k.py b/opencompass/datasets/gsm8k.py index f8146da4..a3baaff8 100644 --- a/opencompass/datasets/gsm8k.py +++ b/opencompass/datasets/gsm8k.py @@ -1,5 +1,6 @@ import json import os +import re from datasets import Dataset, DatasetDict @@ -34,24 +35,10 @@ def gsm8k_dataset_postprocess(text: str) -> str: @TEXT_POSTPROCESSORS.register_module('gsm8k') def gsm8k_postprocess(text: str) -> str: text = text.split('Question:')[0] - text = text.split(' ')[::-1] - flag = False - ret = '' - for i in range(len(text)): - s = text[i] - for i in range(len(s)): - if s[i].isdigit(): - flag = True - ret = s - break - if flag: - break - ret1 = '' - for i in range(len(ret)): - # deal with potential float number - if ret[i].isdigit() or ret[i] == '.': - ret1 += ret[i] - return ret1.strip('.') + numbers = re.findall(r'\-?\d+\.\d+|\-?\d+', text) + if not numbers: + return 'NULL' + return numbers[-1] class Gsm8kEvaluator(BaseEvaluator):