2023-11-13 13:00:37 +08:00
|
|
|
import json
|
|
|
|
import os
|
|
|
|
|
|
|
|
from datasets import Dataset, DatasetDict
|
|
|
|
|
2023-10-27 20:31:22 +08:00
|
|
|
from opencompass.openicl import BaseEvaluator
|
2023-11-13 13:00:37 +08:00
|
|
|
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
|
|
|
|
|
|
|
|
from .base import BaseDataset
|
|
|
|
|
|
|
|
|
|
|
|
@LOAD_DATASET.register_module()
class GSM8KDataset(BaseDataset):
    """GSM8K dataset loaded from ``train.jsonl`` and ``test.jsonl`` files."""

    @staticmethod
    def load(path):
        """Load the ``train`` and ``test`` splits found under ``path``.

        Each split is read from ``<path>/<split>.jsonl``, where every
        non-blank line is parsed as one JSON record.

        Args:
            path (str): Directory containing ``train.jsonl`` and
                ``test.jsonl``.

        Returns:
            DatasetDict: Mapping with the keys ``'train'`` and ``'test'``.

        Raises:
            KeyError: If a record has no ``'answer'`` field.
        """
        datasets = {}
        for split in ['train', 'test']:
            split_path = os.path.join(path, split + '.jsonl')
            records = []
            with open(split_path, 'r', encoding='utf-8') as f:
                for line_no, raw_line in enumerate(f, start=1):
                    raw_line = raw_line.strip()
                    # Tolerate blank lines (e.g. a trailing newline at the
                    # end of the file) instead of failing in json.loads.
                    if not raw_line:
                        continue
                    record = json.loads(raw_line)
                    # Every sample must carry an answer; fail early with a
                    # message that points at the offending line.
                    if 'answer' not in record:
                        raise KeyError(
                            f"'answer' is missing in {split_path} "
                            f'(line {line_no})')
                    records.append(record)
            datasets[split] = Dataset.from_list(records)
        return DatasetDict(datasets)
|
2023-07-05 09:01:25 +08:00
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm8k_dataset')
def gsm8k_dataset_postprocess(text: str) -> str:
    """Extract the answer that follows the ``'#### '`` marker.

    Args:
        text (str): Answer text containing a ``'#### '`` marker.

    Returns:
        str: The segment right after the first marker (up to the next
            marker, if there is one), with all commas removed.

    Raises:
        IndexError: If ``text`` does not contain the marker.
    """
    segments = text.split('#### ')
    answer = segments[1]
    return answer.replace(',', '')
|
|
|
|
|
|
|
|
|
|
|
|
@TEXT_POSTPROCESSORS.register_module('gsm8k')
def gsm8k_postprocess(text: str) -> str:
    """Extract a numeric answer string from a model prediction.

    Only the part of ``text`` before the first blank line (``'\\n\\n'``) is
    considered. Within it, the last space-separated token that contains a
    digit is selected, and everything except digits and ``'.'`` is dropped
    from that token.

    Args:
        text (str): Raw prediction text.

    Returns:
        str: The digits (and inner dots) of the selected token with leading
            and trailing dots stripped, or ``''`` when no token contains a
            digit. Signs and thousands separators are discarded.
    """
    first_paragraph = text.split('\n\n')[0]
    # Walk the tokens backwards so the last number mentioned wins.
    candidate = next(
        (token for token in reversed(first_paragraph.split(' '))
         if any(char.isdigit() for char in token)),
        '',
    )
    # deal with potential float number: keep '.' alongside the digits
    number = ''.join(
        char for char in candidate if char.isdigit() or char == '.')
    # Drop dots that came from sentence punctuation, e.g. '18.'.
    return number.strip('.')
|
2023-10-27 20:31:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Gsm8kEvaluator(BaseEvaluator):
    """Exact-match accuracy evaluator for post-processed GSM8K answers."""

    def score(self, predictions, references):
        """Compare predictions with references element by element.

        Args:
            predictions: Sized sequence of predicted answers.
            references: Sized sequence of reference answers, aligned with
                ``predictions``.

        Returns:
            dict: ``{'accuracy': <percentage>, 'details': [...]}`` where
                every detail entry holds ``'pred'``, ``'answer'`` and
                ``'correct'``. If the inputs differ in length or are empty,
                a dict with a single ``'error'`` message is returned
                instead.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        # Without this guard the percentage below divides by zero.
        if len(predictions) == 0:
            return {'error': 'predictions and references are empty'}

        correct = 0
        details = []
        for pred, answer in zip(predictions, references):
            is_correct = bool(pred == answer)
            correct += is_correct
            details.append({
                'pred': pred,
                'answer': answer,
                'correct': is_correct,
            })
        return {
            'accuracy': 100 * correct / len(details),
            'details': details,
        }
|
2023-11-07 19:11:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Gsm8kAgentEvaluator(BaseEvaluator):
    """Gsm8k agent evaluator for soft condition.

    Besides checking the final answer, the evaluator looks at the recorded
    agent steps to see whether the configured action was used, whether it
    finished without an error message, and whether its result matches the
    reference.

    Args:
        action (str): Action for catching internal prediction.
            Defaults to `PythonInterpreter`.
    """

    def __init__(self, action: str = 'PythonInterpreter'):
        self.action = action

    def is_equal(self, pred, refer):
        """Return True if ``pred`` matches ``refer`` exactly or numerically.

        The numeric comparison converts ``pred`` with ``float`` and
        ``refer`` with ``int`` and uses a tolerance of ``1e-6``. Any
        failure during comparison or conversion counts as "not equal".
        """
        try:
            return bool(pred == refer
                        or abs(float(pred) - int(refer)) < 1e-6)
        except Exception:
            # Best effort: values that cannot be converted never match.
            return False

    def soft_equal(self, pred, refer, step):
        """Return True if the result text of ``step`` matches ``refer``.

        ``pred`` is not used; it is kept so the signature stays unchanged.
        """
        try:
            soft_pred = step['result']['text']
            return abs(float(soft_pred) - int(refer)) < 1e-6
        except Exception:
            # result might not exists
            # text cannot convert to float
            return False

    def get_action(self, step):
        """Return the last entry of ``step`` whose type is ``self.action``.

        Returns ``None`` when no entry matches.
        """
        for s in step[::-1]:
            if s['type'] == self.action:
                return s
        return None

    def score(self, predictions, references, steps):
        """Calculate accuracy.

        Args:
            predictions: Final answers, one per sample.
            references: Reference answers, aligned with ``predictions``.
            steps: Per-sample sequences of step records; each record is
                indexed by ``'type'`` and, for the configured action, by
                ``'errmsg'`` and ``'result'``.

        Returns:
            dict: Percentages over all samples:

            - ``follow_acc``: the final answer is correct.
            - ``reasoning_acc``: the final answer is correct, or it is
              wrong but the error-free action result matches the
              reference.
            - ``code_acc``: the action was used and the final answer is
              correct, or the answer is wrong but the action reported no
              error message.
            - ``action_pct``: the action was used and the final answer is
              correct, or the answer is wrong and the action was used.

            If the inputs differ in length or are empty, a dict with a
            single ``'error'`` message is returned instead.
        """
        total = len(references)
        # zip() below would silently drop samples on a length mismatch
        # while ``total`` still counted them, skewing every metric.
        if len(predictions) != total or len(steps) != total:
            return {
                'error': 'predictions, references and steps have '
                'different length'
            }
        # Without this guard the percentages below divide by zero.
        if total == 0:
            return {'error': 'predictions and references are empty'}

        row_reasoning_scope = 0
        action_scope = 0
        code_scope = 0
        reasoning_scope = 0
        final_scope = 0
        for pred, refer, step in zip(predictions, references, steps):
            action_step = self.get_action(step)
            # if final answer right
            if self.is_equal(pred, refer):
                if action_step:
                    final_scope += 1
                else:
                    row_reasoning_scope += 1
            elif action_step:
                action_scope += 1
                if not action_step['errmsg']:
                    code_scope += 1
                    # whether action result is correct
                    reasoning_scope += int(
                        self.soft_equal(pred, refer, action_step))

        result = dict(
            follow_acc=100 * (row_reasoning_scope + final_scope) / total,
            reasoning_acc=100 *
            (reasoning_scope + final_scope + row_reasoning_scope) / total,
            code_acc=100 * (code_scope + final_scope) / total,
            action_pct=100 * (action_scope + final_scope) / total,
        )
        return result
|