OpenCompass/tools/prediction_merger.py

102 lines
3.1 KiB
Python
Raw Permalink Normal View History

import argparse
import copy
import json
import os.path as osp
import mmengine
from mmengine.config import Config, ConfigDict
from opencompass.utils import build_dataset_from_cfg, get_infer_output_path
def parse_args(argv=None):
    """Parse the command-line arguments of the prediction merger.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace: Namespace with ``config`` (the positional config
        file path) and ``work_dir`` (the value of ``-w/--work-dir``, or
        ``None`` when the option is not given).
    """
    parser = argparse.ArgumentParser(
        description='Merge partitioned predictions')
    parser.add_argument('config', help='Train config file path')
    parser.add_argument('-w',
                        '--work-dir',
                        # Adjacent literals are concatenated verbatim, so
                        # every fragment carries its own trailing space.
                        help='Work path, all the outputs will be '
                        'saved in this path, including the slurm logs, '
                        'the evaluation results, the summary results, etc. '
                        'If not specified, the work_dir will be set to '
                        './outputs/default.',
                        default=None,
                        type=str)
    return parser.parse_args(argv)
class PredictionMerger:
    """Merge the partial prediction files of one model/dataset pair.

    The merged file path is obtained from ``get_infer_output_path``. Partial
    files are expected next to it, named ``<root>_0<ext>``, ``<root>_1<ext>``
    and so on. Their entries are concatenated in file order and re-keyed
    with consecutive string indices.

    Args:
        cfg (ConfigDict): Mapping with the keys ``model`` and ``dataset``
            and, optionally, ``work_dir``.
    """

    def __init__(self, cfg: ConfigDict) -> None:
        self.cfg = cfg
        # Work on deep copies so the caller's config objects are never
        # shared with this instance.
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')

    def run(self):
        """Write the merged prediction file when it does not exist yet.

        Returns without doing anything when the merged file is already
        present. Prints a message and returns without writing when the first
        partial file is missing, or when the number of merged predictions
        differs from ``len(dataset.test)``.
        """
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        root, ext = osp.splitext(filename)

        # A merged file is already in place: nothing to merge.
        if osp.exists(osp.realpath(filename)):
            return

        partial_filename = root + '_0' + ext
        if not osp.exists(osp.realpath(partial_filename)):
            print(f'{filename} not found')
            return

        # Load predictions from consecutive partial files until the first
        # missing index.
        preds, offset = {}, 0
        partial_filenames = []
        part_idx = 0
        while osp.exists(osp.realpath(partial_filename)):
            partial_filenames.append(osp.realpath(partial_filename))
            partial_preds = mmengine.load(partial_filename)
            # Entries are read by the keys '0', '1', ... of each partial
            # file and appended under a running global index.
            for local_idx in range(len(partial_preds)):
                preds[str(offset)] = partial_preds[str(local_idx)]
                offset += 1
            part_idx += 1
            partial_filename = root + f'_{part_idx}' + ext

        # Refuse to write a merged file that does not cover the test split.
        dataset = build_dataset_from_cfg(self.dataset_cfg)
        if len(preds) != len(dataset.test):
            print('length mismatch')
            return

        print(f'Merge {partial_filenames} to {filename}')
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(preds, f, indent=4, ensure_ascii=False)
def dispatch_tasks(cfg):
    """Run one ``PredictionMerger`` per model/dataset combination.

    Args:
        cfg: Config mapping that provides ``models``, ``datasets`` and
            ``work_dir``.
    """
    for model_cfg in cfg['models']:
        for dataset_cfg in cfg['datasets']:
            # Each merger gets its own small config for a single pair.
            task_cfg = dict(model=model_cfg,
                            dataset=dataset_cfg,
                            work_dir=cfg['work_dir'])
            merger = PredictionMerger(task_cfg)
            merger.run()
def main():
    """Load the config, resolve ``work_dir`` and merge all predictions."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set work_dir: a command-line value wins; otherwise keep the config's
    # own value and fall back to the default output directory.
    if args.work_dir is None:
        cfg.setdefault('work_dir', './outputs/default')
    else:
        cfg['work_dir'] = args.work_dir
    dispatch_tasks(cfg)
# Only run the merger when this file is executed as a script.
if __name__ == '__main__':
    main()