# mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
import ast
from collections import defaultdict

import numpy as np
from mmengine import load

from ..schema import ResponseDataSample
from ..utils.format_load import format_load
from ..utils.template import parse_string

class InstructEvaluator:
    """Instruction-following evaluation.

    Args:
        dataset_path (str): File path of the evaluation dataset.
    """

    def __init__(
        self,
        dataset_path: str,
        **kwargs,
    ) -> None:
        self.dataset_path = dataset_path

    def _load_dataset(self):
        self.dataset = []
        dataset = load(self.dataset_path)

        for key in dataset.keys():
            datum = dataset[key]
            data_sample = self._process_response(datum)

            self.dataset.append(
                dict(
                    origin_prompt=datum["origin_prompt"],
                    response_data_sample=data_sample))
        self.num_samples = len(self.dataset)
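    # Note (assumption): `load` is mmengine's generic file loader, so the
    # dataset may be stored as JSON/YAML/pickle as long as it deserializes
    # to a dict of {sample_id: record}.
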
    def _process_response(
        self,
        datum: dict,
    ) -> ResponseDataSample:
        """Process the response into the needed format.

        Args:
            datum (dict): Raw input record.

        Returns:
            ResponseDataSample: Processed response data sample.
        """
        # Dict with keyword-only arguments.
        template = datum['template']
        # Generated response.
        pred_data = datum['prediction']
        # Ground-truth response.
        gt_data = datum['ground_truth']
        meta_data = datum['meta_data']

        return ResponseDataSample(
            template=template, pred=pred_data, gt=gt_data, meta_data=meta_data)

    def _evaluate(self, data_sample: ResponseDataSample) -> dict:
        metrics_result = dict()
        response_format = data_sample.meta_data['response_format']
        if response_format == 'json':
            pred_data = self.json_format_parse(data_sample)
        else:
            pred_data = self.string_format_parse(data_sample)

        if pred_data is None:
            # The prediction could not be parsed: set all metrics to 0.
            metrics_result[f'{response_format}_format_metric'] = 0
            metrics_result[f'{response_format}_args_em_metric'] = 0
            return metrics_result

        # The format parsed successfully; score arguments by exact matching.
        metrics_result[f'{response_format}_format_metric'] = 1
        metrics_result[f'{response_format}_args_em_metric'] = self.compute_args_em_metric(
            gt_action=data_sample.gt['action'], pred_action=pred_data['action'],
            gt_args=data_sample.gt['args'], pred_args=pred_data['args']
        )
        return metrics_result
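    # Example (illustrative): a JSON-format sample whose prediction parses and
    # matches the ground truth exactly would yield
    # {'json_format_metric': 1, 'json_args_em_metric': 1.0}.
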
    def compute_args_em_metric(self, gt_action, pred_action, gt_args, pred_args):
        cnt = 0.
        if gt_action == pred_action:
            cnt += 1.
        num_args = len(gt_args) + 1  # +1 accounts for the action-name match
        for gt_key in gt_args:
            pred_val = pred_args.get(gt_key, "")
            if pred_val == gt_args[gt_key]:
                cnt += 1.
        return cnt / num_args
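    # Worked example (illustrative): with a matching action name,
    # gt_args = {'a': 1, 'b': 2} and pred_args = {'a': 1, 'b': 3},
    # the score is (1 action + 1 arg) / (2 args + 1) = 2 / 3 ≈ 0.6667.
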
    def string_format_parse(self, data_sample):
        pred_data = data_sample.pred
        template = data_sample.template
        thought_start = template['thought_start']
        thought_end = template['thought_end']
        action_start = template['action_start']
        action_end = template['action_end']
        args_start = template['args_start']
        args_end = template['args_end']

        parse_template = thought_start + "{thought}" + thought_end \
            + action_start + "{action}" + action_end \
            + args_start + "{args}" + args_end
        res = parse_string(parse_template, pred_data, allow_newline=True)
        try:
            if res is not None:
                args = ast.literal_eval(res['args'].strip())
                res['args'] = args if isinstance(args, dict) else {}
                res['action'] = res['action'].strip()
            return res
        except Exception:
            # `args` could not be parsed as a Python literal; keep the
            # thought/action fields and fall back to empty args.
            return dict(
                thought=res['thought'],
                action=res['action'].strip(),
                args=dict())
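    # Example (illustrative): with markers such as thought_start='Thought: ',
    # thought_end='\n', action_start='Action: ', action_end='\n',
    # args_start='Args: ', args_end='', a prediction like
    # "Thought: find x\nAction: search\nArgs: {'q': 'x'}" parses into
    # {'thought': 'find x', 'action': 'search', 'args': {'q': 'x'}}.
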
    def json_format_parse(self, data_sample):
        try:
            pred_data = format_load(data_sample.pred)
            template = data_sample.template
            new_data = dict()
            new_data['thought'] = pred_data[template['thought']]
            new_data['action'] = pred_data[template['action']]
            args = pred_data[template['args']]
            new_data['args'] = args if isinstance(args, dict) else {}
        except Exception:
            # Malformed JSON or missing keys: signal a format failure.
            return None

        return new_data
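    # Example (illustrative): if the template maps the fields to the JSON keys
    # 'thought', 'name' and 'parameters', a prediction of
    # '{"thought": "t", "name": "search", "parameters": {"q": "x"}}' yields
    # {'thought': 't', 'action': 'search', 'args': {'q': 'x'}}.
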
    def evaluate(self):
        self._load_dataset()
        results_list = []
        for data_sample in self.dataset:
            metrics_result = self._evaluate(data_sample['response_data_sample'])
            results_list.append(metrics_result)
        return self._post_process(results_list)

    def _post_process(self, results_list):
        # Convert a list of dicts into a dict of lists.
        results_dict = defaultdict(list)
        for sub in results_list:
            for key in sub:
                results_dict[key].append(sub[key])
        metric_list = ['json_format_metric', 'json_args_em_metric',
                       'string_format_metric', 'string_args_em_metric']
        for metric in metric_list:
            # Each metric is averaged only over the samples it applies to.
            results_dict[metric] = np.round(np.mean(results_dict[metric]), decimals=4)
        return results_dict
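
# Minimal usage sketch (illustrative; the path and record layout below are
# assumptions: each record is expected to carry 'origin_prompt', 'template',
# 'prediction', 'ground_truth' and 'meta_data'):
#
#   evaluator = InstructEvaluator(dataset_path='data/instruct_eval.json')
#   scores = evaluator.evaluate()
#   print(scores)  # e.g. {'json_format_metric': 0.95, ...}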