2023-10-27 20:31:22 +08:00
|
|
|
from opencompass.openicl import BaseEvaluator
|
2023-07-05 09:01:25 +08:00
|
|
|
from opencompass.registry import TEXT_POSTPROCESSORS
|
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm8k_dataset')
|
|
|
|
def gsm8k_dataset_postprocess(text: str) -> str:
    """Return the segment after the first ``'#### '`` marker, commas removed.

    The text is split on the literal marker ``'#### '`` and the second piece
    (everything between the first marker and the next one, if any) is
    returned with every ``','`` stripped out.

    Args:
        text: Text expected to contain a ``'#### '`` marker.

    Returns:
        The segment following the first marker, without commas.

    Raises:
        IndexError: If ``text`` contains no ``'#### '`` marker.
    """
    segments = text.split('#### ')
    answer = segments[1]
    return answer.replace(',', '')
|
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm8k')
|
|
|
|
def gsm8k_postprocess(text: str) -> str:
    """Extract the digits of the last number-bearing token in the first paragraph.

    Only the part of ``text`` before the first blank line (``'\\n\\n'``) is
    considered. That part is split on single spaces and scanned from the end;
    the first token containing at least one digit is selected, and every
    non-digit character in it is dropped.

    Note that dropping non-digits also removes signs, decimal points and
    thousands separators, so ``'3.5'`` becomes ``'35'`` and ``'-5'`` becomes
    ``'5'``.

    Args:
        text: Raw text to extract an answer from.

    Returns:
        The digit characters of the selected token, or ``''`` when no token
        in the first paragraph contains a digit.
    """
    first_paragraph = text.split('\n\n')[0]
    # Walk the tokens back to front so the last number mentioned wins.
    for token in reversed(first_paragraph.split(' ')):
        if any(char.isdigit() for char in token):
            return ''.join(char for char in token if char.isdigit())
    return ''
|
2023-10-27 20:31:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Gsm8kEvaluator(BaseEvaluator):
    """Evaluator that scores predictions by exact equality with references."""

    def score(self, predictions, references):
        """Compute exact-match accuracy over paired predictions and references.

        Args:
            predictions: Sized iterable of predicted answers.
            references: Sized iterable of reference answers, aligned with
                ``predictions``.

        Returns:
            A dict with an ``'error'`` message when the two inputs differ in
            length. Otherwise a dict with ``'accuracy'`` (percentage of pairs
            that compare equal, ``0.0`` when there are no pairs) and
            ``'details'`` (one ``{'pred', 'answers', 'correct'}`` record per
            pair, in input order).
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        details = [{
            'pred': pred,
            'answers': ref,
            'correct': bool(pred == ref),
        } for pred, ref in zip(predictions, references)]

        count = len(details)
        correct = sum(1 for detail in details if detail['correct'])
        # Guard against empty input, which would otherwise divide by zero.
        accuracy = 100 * correct / count if count else 0.0
        return {'accuracy': accuracy, 'details': details}
|