mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix bug of gsm8k_postprocess (#863)
* fix bug of gsm8k_postprocess * update postprocess --------- Co-authored-by: Lei Fei <SENSETIME\leifei1@cn3114002087l.domain.sensetime.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
444d8d9507
commit
dd444685bb
@ -1,3 +1,7 @@
|
||||
# GONNA BE DEPRECATED, DON'T USE IT
|
||||
# The postprocessor has the assumption that the prompt is in the format of "Question:blabla"
|
||||
# This config does not follow the above assumption, thus deprecated
|
||||
|
||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
|
||||
@ -34,24 +35,10 @@ def gsm8k_dataset_postprocess(text: str) -> str:
|
||||
@TEXT_POSTPROCESSORS.register_module('gsm8k')
|
||||
def gsm8k_postprocess(text: str) -> str:
|
||||
text = text.split('Question:')[0]
|
||||
text = text.split(' ')[::-1]
|
||||
flag = False
|
||||
ret = ''
|
||||
for i in range(len(text)):
|
||||
s = text[i]
|
||||
for i in range(len(s)):
|
||||
if s[i].isdigit():
|
||||
flag = True
|
||||
ret = s
|
||||
break
|
||||
if flag:
|
||||
break
|
||||
ret1 = ''
|
||||
for i in range(len(ret)):
|
||||
# deal with potential float number
|
||||
if ret[i].isdigit() or ret[i] == '.':
|
||||
ret1 += ret[i]
|
||||
return ret1.strip('.')
|
||||
numbers = re.findall(r'\-?\d+\.\d+|\-?\d+', text)
|
||||
if not numbers:
|
||||
return 'NULL'
|
||||
return numbers[-1]
|
||||
|
||||
|
||||
class Gsm8kEvaluator(BaseEvaluator):
|
||||
|
Loading…
Reference in New Issue
Block a user