mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
195 lines
6.9 KiB
Python
195 lines
6.9 KiB
Python
import argparse
|
|
import copy
|
|
import json
|
|
import os.path as osp
|
|
|
|
import mmengine
|
|
from mmengine.config import Config, ConfigDict
|
|
from mmengine.utils import mkdir_or_exist
|
|
from tqdm import tqdm
|
|
|
|
from opencompass.registry import TEXT_POSTPROCESSORS
|
|
from opencompass.utils import build_dataset_from_cfg, get_infer_output_path
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the case analyzer.

    Returns:
        argparse.Namespace: Namespace with ``config`` (str, positional),
        ``force`` (bool, ``False`` unless ``-f``/``--force`` is given) and
        ``work_dir`` (str, or ``None`` when ``-w``/``--work-dir`` is
        omitted).
    """
    parser = argparse.ArgumentParser(description='Run case analyzer')
    parser.add_argument('config', help='Train config file path')
    parser.add_argument(
        '-f',
        '--force',
        help='Force to run the task even if the results already exist',
        action='store_true',
        default=False)
    # Adjacent string literals are concatenated verbatim, so every fragment
    # has to carry its own trailing space; otherwise the help output glues
    # two sentences together ("etc.If not specified").
    parser.add_argument('-w',
                        '--work-dir',
                        help='Work path, all the outputs will be '
                        'saved in this path, including the slurm logs, '
                        'the evaluation results, the summary results, etc. '
                        'If not specified, the work_dir will be set to '
                        './outputs/default.',
                        default=None,
                        type=str)
    return parser.parse_args()
|
|
|
|
|
|
class BadcaseShower:
    """Dump the per-sample cases of one model/dataset pair as JSON.

    The prediction file of the pair is looked up below
    ``<work_dir>/predictions`` and combined with the reference column of the
    dataset. Two JSON lists are then written below
    ``<work_dir>/case_analysis``: one under ``bad`` and one under ``all``.

    Args:
        cfg (ConfigDict): Mapping with the keys ``model``, ``dataset`` and
            ``work_dir``. ``dataset['eval_cfg']`` provides ``ds_column`` and
            optionally ``ds_split``, ``dataset_postprocessor`` and
            ``pred_postprocessor``.
    """

    def __init__(self, cfg: ConfigDict) -> None:
        self.cfg = cfg
        # Work on private copies of the model/dataset configs; the original
        # entries in ``cfg`` are still used to derive the output paths.
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')
        # Evaluation settings: which split/column holds the references.
        self.eval_cfg = self.dataset_cfg.get('eval_cfg')
        self.ds_split = self.eval_cfg.get('ds_split', None)
        self.ds_column = self.eval_cfg.get('ds_column')

    def run(self):
        """Build the case lists and write them to disk.

        Prints a message and returns without writing anything when no
        prediction file exists or when the number of predictions differs
        from the number of references.
        """
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        root, ext = osp.splitext(filename)
        # Predictions may be stored in one file or in numbered parts
        # ``<root>_0<ext>``, ``<root>_1<ext>``, ...
        partial_filename = root + '_0' + ext

        if not osp.exists(osp.realpath(filename)) and not osp.exists(
                osp.realpath(partial_filename)):
            print(f'{filename} not found')
            return

        dataset = self._build_dataset()
        preds = self._load_predictions(filename, root, ext)
        pred_strs = [preds[str(i)]['prediction'] for i in range(len(preds))]

        # Postprocess predictions if necessary
        if 'pred_postprocessor' in self.eval_cfg:
            proc = TEXT_POSTPROCESSORS.get(
                self.eval_cfg['pred_postprocessor']['type'])
            pred_strs = [proc(s) for s in pred_strs]

        if self.ds_split:
            references = dataset[self.ds_split][self.ds_column]
        else:
            references = dataset[self.ds_column]

        if len(pred_strs) != len(references):
            print('length mismatch')
            return

        # The first record decides the result format. ``.get`` keeps an
        # empty prediction file from raising ``KeyError`` on ``preds['0']``;
        # in that case two empty lists are written.
        if 'in-context examples' in preds.get('0', {}):
            allcase, badcase = self._collect_ppl_cases(
                preds, pred_strs, references)
        else:
            allcase, badcase = self._collect_gen_cases(
                preds, pred_strs, references)

        # ``all`` is written last: ``dispatch_tasks`` treats that file as
        # the marker that this pair has already been analysed.
        self._dump_cases(badcase, 'case_analysis/bad')
        self._dump_cases(allcase, 'case_analysis/all')

    def _build_dataset(self):
        """Build the dataset and postprocess its reference column."""
        dataset = build_dataset_from_cfg(self.dataset_cfg)
        if 'dataset_postprocessor' in self.eval_cfg:
            # Resolve the postprocessor once instead of once per sample.
            ref_proc = TEXT_POSTPROCESSORS.get(
                self.eval_cfg['dataset_postprocessor']['type'])

            def postprocess(sample):
                sample[self.ds_column] = ref_proc(sample[self.ds_column])
                return sample

            dataset = dataset.map(postprocess)
        return dataset

    @staticmethod
    def _load_predictions(filename, root, ext):
        """Load predictions from ``filename`` or from its numbered parts.

        Records of the parts are re-keyed with one running string index so
        the merged dict looks like a single prediction file.
        """
        if osp.exists(osp.realpath(filename)):
            return mmengine.load(filename)

        preds, offset = {}, 0
        part = 0
        part_file = root + f'_{part}' + ext
        while osp.exists(osp.realpath(part_file)):
            chunk = mmengine.load(part_file)
            # Each part is indexed from '0'; shift by the running offset.
            for local_idx in range(len(chunk)):
                preds[str(offset)] = chunk[str(local_idx)]
                offset += 1
            part += 1
            part_file = root + f'_{part}' + ext
        return preds

    @staticmethod
    def _collect_ppl_cases(preds, pred_strs, references):
        """Return ``(allcase, badcase)`` for PPL-style prediction records.

        A case is bad when the predicted label differs from the stringified
        reference. Samples whose record lacks the ``'label: <x>'`` entry
        (or its ``'testing input'``/``'PPL'`` field) for either label are
        skipped.
        """
        allcase, badcase = [], []
        for i, (pred_str,
                reference) in enumerate(zip(tqdm(pred_strs), references)):
            ref_str = str(reference)
            try:
                pred_entry = preds[str(i)]['label: ' + pred_str]
                pred_prompt = pred_entry['testing input']
                pred_PPL = pred_entry['PPL']
                ref_entry = preds[str(i)]['label: ' + ref_str]
                ref_prompt = ref_entry['testing input']
                ref_PPL = ref_entry['PPL']
            except KeyError:
                # Best effort: a sample without both entries cannot be
                # shown side by side, so it is left out.
                continue
            item = {
                'prediction_prompt': pred_prompt,
                'prediction': pred_str,
                'prediction_PPL': pred_PPL,
                'reference_prompt': ref_prompt,
                'reference': ref_str,
                'reference_PPL': ref_PPL
            }
            allcase.append(item)
            if pred_str != ref_str:
                badcase.append(item)
        return allcase, badcase

    @staticmethod
    def _collect_gen_cases(preds, pred_strs, references):
        """Return ``(allcase, badcase)`` for generation-style records."""
        allcase, badcase = [], []
        for i, (pred_str,
                reference) in enumerate(zip(tqdm(pred_strs), references)):
            item = {
                'origin_prompt': preds[str(i)]['origin_prompt'],
                'prediction': pred_str,
                'reference': str(reference)
            }
            # FIXME: we now consider all cases as bad cases
            badcase.append(item)
            allcase.append(item)
        return allcase, badcase

    def _dump_cases(self, cases, subdir):
        """Write ``cases`` as JSON to this pair's file below ``subdir``."""
        out_path = get_infer_output_path(
            self.cfg['model'], self.cfg['dataset'],
            osp.join(self.work_dir, subdir))
        mkdir_or_exist(osp.split(out_path)[0])
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(cases, f, indent=4, ensure_ascii=False)
|
|
|
|
|
|
def dispatch_tasks(cfg, force=False):
    """Run a ``BadcaseShower`` for every model/dataset combination.

    Args:
        cfg: Mapping that provides ``models``, ``datasets`` and
            ``work_dir``.
        force (bool): When true, every pair is analysed. Otherwise a pair
            is skipped if its ``case_analysis/all`` output file already
            exists.
    """
    for model in cfg['models']:
        for dataset in cfg['datasets']:
            if not force:
                # The output path is only resolved when it is needed to
                # decide whether this pair was handled before.
                finished_path = get_infer_output_path(
                    model, dataset,
                    osp.join(cfg['work_dir'], 'case_analysis/all'))
                if osp.exists(finished_path):
                    continue
            task_cfg = {
                'model': model,
                'dataset': dataset,
                'work_dir': cfg['work_dir']
            }
            BadcaseShower(task_cfg).run()
|
|
|
|
|
|
def main():
    """Entry point: load the config, settle the work dir, run all tasks."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # A work dir given on the command line overrides the config; the
    # fallback only applies when the config does not define one either.
    if args.work_dir is None:
        cfg.setdefault('work_dir', './outputs/default')
    else:
        cfg['work_dir'] = args.work_dir
    dispatch_tasks(cfg, force=args.force)
|
|
|
|
|
|
# Run the analyzer only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
|