fix bug of gsm8k_postprocess (#863)

* fix bug of gsm8k_postprocess

* update postprocess

---------

Co-authored-by: Lei Fei <SENSETIME\leifei1@cn3114002087l.domain.sensetime.com>
Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
hailsham 2024-02-06 23:52:47 +08:00 committed by GitHub
parent 444d8d9507
commit dd444685bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 18 deletions

View File

@ -1,3 +1,7 @@
# GONNA BE DEPRECATED, DON'T USE IT
# The postprocessor has the assumption that the prompt is in the format of "Question:blabla"
# This config does not follow the above assumption, thus deprecated
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

View File

@ -1,5 +1,6 @@
import json
import os
import re
from datasets import Dataset, DatasetDict
@ -34,24 +35,10 @@ def gsm8k_dataset_postprocess(text: str) -> str:
@TEXT_POSTPROCESSORS.register_module('gsm8k')
def gsm8k_postprocess(text: str) -> str:
text = text.split('Question:')[0]
text = text.split(' ')[::-1]
flag = False
ret = ''
for i in range(len(text)):
s = text[i]
for i in range(len(s)):
if s[i].isdigit():
flag = True
ret = s
break
if flag:
break
ret1 = ''
for i in range(len(ret)):
# deal with potential float number
if ret[i].isdigit() or ret[i] == '.':
ret1 += ret[i]
return ret1.strip('.')
numbers = re.findall(r'\-?\d+\.\d+|\-?\d+', text)
if not numbers:
return 'NULL'
return numbers[-1]
class Gsm8kEvaluator(BaseEvaluator):