mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
![]() |
import re
|
||
|
|
||
|
|
||
|
def gsm8k_postprocess(text: str) -> str:
    """Extract the final numeric answer from a GSM8K-style response.

    The response is split on single space characters and the pieces are
    scanned from last to first. The first piece found that contains at
    least one digit is taken as the answer, and every non-digit character
    in it is discarded (so ``'$1,000.'`` becomes ``'1000'``; note that a
    decimal point or minus sign is discarded as well).

    Args:
        text: Raw model output.

    Returns:
        The digits of the last digit-bearing piece, or an empty string
        when the response contains no digit at all.
    """
    # Only ' ' separates pieces; newlines and tabs stay inside a piece.
    for token in reversed(text.split(' ')):
        if any(ch.isdigit() for ch in token):
            return ''.join(ch for ch in token if ch.isdigit())
    return ''
|
||
|
|
||
|
|
||
|
def humaneval_postprocess(text: str) -> str:
    """Reduce a model response to an indented HumanEval function body.

    Steps, in order:

    1. The first line of the response is always dropped and the rest is
       stripped of surrounding whitespace.
    2. If the text contains a Markdown fence, the first fenced block is
       used (minus a language tag such as ``python`` on the opening
       fence). With an unclosed fence, everything after the opening fence
       is used as-is.
    3. If the code starts with ``from``/``import``, everything up to and
       including the line holding the first ``def`` substring is removed.
    4. If the code starts with ``def``, that signature line is removed.
    5. The result is indented by four spaces so that it can be appended
       after a function signature.

    Args:
        text: Raw model output.

    Returns:
        The extracted code, starting with a four-space indent.
    """
    indent = '    '

    text = '\n'.join(text.split('\n')[1:]).strip()
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```python
                text = text[max(text.find('\n') + 1, 0):]
    if text.strip().startswith(('from', 'import')):
        # NOTE: this is a plain substring search, so it also matches 'def'
        # inside a longer identifier.
        def_idx = text.find('def')
        if def_idx != -1:
            text = text[max(text.find('\n', def_idx) + 1, 0):]
    if text.strip().startswith('def'):
        text = '\n'.join(text.split('\n')[1:])
    if not text.startswith(indent):
        if text.startswith(' '):
            # Under-indented first line: normalise its leading whitespace.
            text = indent + text.lstrip()
        else:
            # No indentation at all: indent every line of the body.
            text = '\n'.join([indent + line for line in text.split('\n')])
    return text
|
||
|
|
||
|
|
||
|
def lcsts_postprocess(text: str) -> str:
    """Clean a generated LCSTS summary.

    Surrounding whitespace is stripped, a leading list marker (``'1. '``
    and then ``'- '``) is removed, and finally quote/punctuation characters
    are stripped from both ends.

    Args:
        text: Raw model output.

    Returns:
        The cleaned summary text.
    """
    text = text.strip()
    # Cut only the leading marker: str.replace would also delete every
    # later occurrence of the same characters inside the summary.
    for marker in ('1. ', '- '):
        if text.startswith(marker):
            text = text[len(marker):]
    return text.strip('“,。!”')
|
||
|
|
||
|
|
||
|
def mbpp_postprocess(text: str) -> str:
    """Pull the code out of an MBPP model response.

    A response beginning with ``'Here'`` loses its first line. If the
    remaining text contains a Markdown fence, the first fenced block is
    returned, without the language tag when one follows the opening fence.
    With an unclosed fence, everything after the opening fence is returned
    unchanged. Text without any fence is returned as-is.

    Args:
        text: Raw model output.

    Returns:
        The extracted code.
    """
    fence = '```'

    if text.startswith('Here'):
        # Keep only what follows the first newline (nothing if there is none).
        _, _, remainder = text.partition('\n')
        text = remainder.strip()

    if fence not in text:
        return text

    fenced = re.findall(r'```(.*?)```', text, re.DOTALL)
    if not fenced:
        # Opening fence without a closing one: take what follows it.
        return text.split(fence)[1]

    code = fenced[0]
    if code.startswith('\n'):
        return code
    # The opening fence carried a language tag such as "python"; skip that
    # line. With no newline at all, find() gives -1 and the slice keeps
    # the whole string.
    return code[code.find('\n') + 1:]
|
||
|
|
||
|
|
||
|
def strategyqa_pred_postprocess(text: str) -> str:
    """Extract a yes/no prediction from a StrategyQA model response.

    A response beginning with ``'Here'`` loses its first line. Only the
    text after the last ``'answer is '`` is searched (the whole text when
    that phrase is absent), case-insensitively, for the first standalone
    ``yes`` or ``no``.

    Args:
        text: Raw model output.

    Returns:
        ``'yes'`` or ``'no'``, or an empty string when neither is found.
    """
    if text.startswith('Here'):
        text = '\n'.join(text.split('\n')[1:]).strip()
    text = text.split('answer is ')[-1]
    # Require that no ASCII letter touches the match so that words such as
    # "unknown", "cannot", "not" or "eyes" are not read as an answer.
    # Lookarounds are used instead of \b so that an answer directly next
    # to non-ASCII text (e.g. CJK characters) is still recognised.
    match = re.search(r'(?<![a-z])(yes|no)(?![a-z])', text.lower())
    if match:
        return match.group(1)
    return ''
|