mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
fix bug of gsm8k_postprocess (#863)
* fix bug of gsm8k_postprocess * update postprocess --------- Co-authored-by: Lei Fei <SENSETIME\leifei1@cn3114002087l.domain.sensetime.com> Co-authored-by: Leymore <zfz-960727@163.com>
This commit is contained in:
parent
444d8d9507
commit
dd444685bb
@ -1,3 +1,7 @@
|
|||||||
|
# GONNA BE DEPRECATED, DON'T USE IT
|
||||||
|
# The postprocessor has the assumption that the prompt is in the format of "Question:blabla"
|
||||||
|
# This config does not follow the above assumption, thus deprecated
|
||||||
|
|
||||||
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
from opencompass.openicl.icl_prompt_template import PromptTemplate
|
||||||
from opencompass.openicl.icl_retriever import ZeroRetriever
|
from opencompass.openicl.icl_retriever import ZeroRetriever
|
||||||
from opencompass.openicl.icl_inferencer import GenInferencer
|
from opencompass.openicl.icl_inferencer import GenInferencer
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict
|
from datasets import Dataset, DatasetDict
|
||||||
|
|
||||||
@ -34,24 +35,10 @@ def gsm8k_dataset_postprocess(text: str) -> str:
|
|||||||
@TEXT_POSTPROCESSORS.register_module('gsm8k')
|
@TEXT_POSTPROCESSORS.register_module('gsm8k')
|
||||||
def gsm8k_postprocess(text: str) -> str:
|
def gsm8k_postprocess(text: str) -> str:
|
||||||
text = text.split('Question:')[0]
|
text = text.split('Question:')[0]
|
||||||
text = text.split(' ')[::-1]
|
numbers = re.findall(r'\-?\d+\.\d+|\-?\d+', text)
|
||||||
flag = False
|
if not numbers:
|
||||||
ret = ''
|
return 'NULL'
|
||||||
for i in range(len(text)):
|
return numbers[-1]
|
||||||
s = text[i]
|
|
||||||
for i in range(len(s)):
|
|
||||||
if s[i].isdigit():
|
|
||||||
flag = True
|
|
||||||
ret = s
|
|
||||||
break
|
|
||||||
if flag:
|
|
||||||
break
|
|
||||||
ret1 = ''
|
|
||||||
for i in range(len(ret)):
|
|
||||||
# deal with potential float number
|
|
||||||
if ret[i].isdigit() or ret[i] == '.':
|
|
||||||
ret1 += ret[i]
|
|
||||||
return ret1.strip('.')
|
|
||||||
|
|
||||||
|
|
||||||
class Gsm8kEvaluator(BaseEvaluator):
|
class Gsm8kEvaluator(BaseEvaluator):
|
||||||
|
Loading…
Reference in New Issue
Block a user