# OpenCompass/opencompass/tasks/openicl_eval.py

import argparse
import fnmatch
import os.path as osp
import time
from collections import Counter
from inspect import signature
from typing import Optional

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist

from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
                                  TEXT_POSTPROCESSORS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)


@TASKS.register_module(force=(__name__ == '__main__'))  # A hack for script run
class OpenICLEvalTask(BaseTask):
    """OpenICL Evaluation Task.

    This task is used to evaluate the metric between predictions and
    references.
    """

    name_prefix = 'OpenICLEval'
    log_subdir = 'logs/eval'
    output_subdir = 'results'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        self.num_gpus = 0
        self.logger = get_logger()

    def get_command(self, cfg_path, template):
        script_path = __file__
        command = f'python3 {script_path} {cfg_path}'
        return template.format(task_cmd=command)

    def run(self):
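        """Score every (model, dataset) pair defined in the config, skipping
        pairs whose result files already exist."""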
        for model_cfg, dataset_cfgs in zip(self.model_cfgs,
                                           self.dataset_cfgs):
            for dataset_cfg in dataset_cfgs:
                self.model_cfg = model_cfg
                self.dataset_cfg = dataset_cfg

                # Load Dataset
                self.eval_cfg = self.dataset_cfg.get('eval_cfg')
                self.output_column = dataset_cfg['reader_cfg']['output_column']

                # overwrite postprocessor if the model has specified one
                ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
                model_postprocessors = self.model_cfg.get(
                    'pred_postprocessor', {})
                for pattern in model_postprocessors.keys():
                    if fnmatch.fnmatch(ds_abbr, pattern):
                        self.eval_cfg[
                            'pred_postprocessor'] = model_postprocessors[
                                pattern]  # noqa
                        break
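                # Skip pairs whose result file already exists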
                out_path = get_infer_output_path(
                    self.model_cfg, self.dataset_cfg,
                    osp.join(self.work_dir, 'results'))
                if osp.exists(out_path):
                    continue
                self._score()

    def _score(self):
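        """Load predictions for the current pair, apply post-processors and
        compute the metric with the configured evaluator."""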
        test_set = build_dataset_from_cfg(self.dataset_cfg).test
        # Postprocess dataset if necessary
        if 'dataset_postprocessor' in self.eval_cfg:
            proc = self.eval_cfg['dataset_postprocessor']['type']
            if isinstance(proc, str):
                proc = TEXT_POSTPROCESSORS.get(proc)

            def postprocess(sample):
                s = sample[self.output_column]
                sample[self.output_column] = proc(s)
                return sample

            test_set = test_set.map(postprocess)

        # Load predictions
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        # in case the prediction is partial
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext

        # Get sc_size if use Self-Consistency
        sc_size = self.eval_cfg.get('sc_size')
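        # When self-consistency is enabled, each prediction is expected to be
        # a list of sampled answers; a majority vote is taken further below.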
        if not osp.exists(osp.realpath(filename)) and not osp.exists(
                osp.realpath(partial_filename)):
            result = {'error': 'No predictions found.'}
        else:
            if osp.exists(osp.realpath(filename)):
                preds = mmengine.load(filename)
                preds = [preds[str(i)] for i in range(len(preds))]
            else:
                filename = partial_filename
                preds = []
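                # Merge sharded prediction files named
                # <root>_0<ext>, <root>_1<ext>, ...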
                i = 1
                while osp.exists(osp.realpath(filename)):
                    sub_preds = mmengine.load(filename)
                    preds.extend(
                        [sub_preds[str(i)] for i in range(len(sub_preds))])
                    filename = root + f'_{i}' + ext
                    i += 1
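            # Transpose the list of per-sample dicts into a dict of lists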
            preds = {k: [pred[k] for pred in preds] for k in preds[0]}
            pred_strs = preds.pop('prediction')

            if ('pred_role' in self.eval_cfg
                    and 'meta_template' in self.model_cfg
                    and not MODELS.get(self.model_cfg['type']).is_api):
                # Create a prompt template for role config parsing
                from opencompass.models.base import LMTemplateParser
                parser = LMTemplateParser(self.model_cfg['meta_template'])
                role = parser.roles[self.eval_cfg['pred_role']]
                if sc_size is not None:
                    for pred in pred_strs:
                        if not isinstance(pred, list):
                            raise TypeError(
                                'The prediction for Self-Consistency '
                                'must be a list.')
                    pred_strs = [[
                        self._extract_role_pred(sc_pred,
                                                role.get('begin', None),
                                                role.get('end', None))
                        for sc_pred in pred
                    ] for pred in pred_strs]
                else:
                    pred_strs = [
                        self._extract_role_pred(pred, role.get('begin', None),
                                                role.get('end', None))
                        for pred in pred_strs
                    ]
            # Postprocess predictions if necessary
            if 'pred_postprocessor' in self.eval_cfg:
                kwargs = self.eval_cfg['pred_postprocessor']
                proc = kwargs.pop('type')
                if isinstance(proc, str):
                    proc = TEXT_POSTPROCESSORS.get(proc)
                if sc_size is not None:
                    pred_strs = [[proc(s, **kwargs) for s in preds]
                                 for preds in pred_strs]
                else:
                    pred_strs = [proc(s, **kwargs) for s in pred_strs]

            # Get majority voting predictions when using self-consistency
            if sc_size is not None:
                pred_strs = [
                    Counter(s).most_common(1)[0][0] for s in pred_strs
                ]
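            # Build the evaluator and pass only the fields accepted by its
            # score() signature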
            icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
            preds['predictions'] = pred_strs
            preds['references'] = test_set[self.output_column]
            preds = {
                k: preds[k]
                for k in signature(icl_evaluator.score).parameters
            }
            result = icl_evaluator.score(**preds)

        if 'error' in result:
            self.logger.error(
                f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
            return
        else:
            self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')

        # Save result
        out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
                                         osp.join(self.work_dir, 'results'))
        mkdir_or_exist(osp.split(out_path)[0])
        mmengine.dump(result, out_path)

    def _extract_role_pred(self, s: str, begin_str: Optional[str],
                           end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
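        Example (illustrative):
            _extract_role_pred('BOT: 42', 'BOT: ', None) returns '42'.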
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str[:1], start)
if end_idx != -1:
end = end_idx
return s[start:end]
def parse_args():
parser = argparse.ArgumentParser(description='Score Calculator')
parser.add_argument('config', help='Config file path')
args = parser.parse_args()
return args
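# Standalone entry point: get_command() above launches this file as
# `python3 <this script> <config>`.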
if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)
    start_time = time.time()
    inferencer = OpenICLEvalTask(cfg)
    inferencer.run()
    end_time = time.time()
    get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')