2023-07-20 11:53:24 +08:00
|
|
|
import argparse
|
|
|
|
import copy
|
|
|
|
import json
|
2024-04-09 17:50:23 +08:00
|
|
|
import os
|
2023-07-20 11:53:24 +08:00
|
|
|
|
|
|
|
import mmengine
|
|
|
|
from mmengine.config import Config, ConfigDict
|
|
|
|
|
|
|
|
from opencompass.utils import build_dataset_from_cfg, get_infer_output_path
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the prediction merger.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, exactly as before; passing a list
            allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``config``, ``work_dir``, ``reuse`` and
        ``clean`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Merge partitioned predictions')
    parser.add_argument('config', help='Evaluation config file path')
    parser.add_argument(
        '-w',
        '--work-dir',
        default=None,
        type=str,
        help='Output root directory; overrides work_dir in the config')
    parser.add_argument(
        '-r',
        '--reuse',
        default='latest',
        type=str,
        help="Sub-directory of work_dir holding the run to merge, or "
        "'latest' for the most recent one; pass an empty string to use "
        'work_dir as-is')
    parser.add_argument(
        '-c',
        '--clean',
        action='store_true',
        help='Delete the partitioned prediction files after merging')
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
class PredictionMerger:
    """Merge the partitioned prediction files of one (model, dataset) pair.

    Partitioned predictions live next to the final output path and are named
    ``<root>_0<ext>``, ``<root>_1<ext>``, ... where ``<root><ext>`` is the
    merged file's path. The shards are concatenated in index order, re-keyed
    with consecutive string indices (``'0'``, ``'1'``, ...) and written as a
    single JSON file.

    Args:
        cfg: Mapping with keys ``model`` and ``dataset`` (both deep-copied),
            ``work_dir`` (root directory of the run) and, optionally,
            ``clean`` (remove the shards after a successful merge; defaults
            to ``False``).
    """

    def __init__(self, cfg: dict) -> None:
        self.cfg = cfg
        # Deep copies, presumably so downstream helpers cannot mutate the
        # caller's config objects, which are shared across merge tasks.
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')

    def run(self):
        """Merge the shards for this model/dataset pair, if needed.

        Does nothing when the merged file already exists. Prints a message
        and returns without writing when the first shard is missing, or when
        the number of merged predictions differs from ``len(dataset.test)``.

        Raises:
            ValueError: If ``work_dir`` was not provided.
            KeyError: If a shard is missing one of its keys ``'0'`` ..
                ``str(len(shard) - 1)``.
        """
        if self.work_dir is None:
            raise ValueError("'work_dir' must be set to merge predictions")
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            os.path.join(self.work_dir, 'predictions'))
        root, ext = os.path.splitext(filename)

        # Already merged (or never partitioned): nothing to do.
        if os.path.exists(os.path.realpath(filename)):
            return
        # Shards are numbered from 0; without the first one there is nothing
        # to merge.
        if not os.path.exists(os.path.realpath(root + '_0' + ext)):
            print(f'{filename} not found')
            return

        preds, partial_filenames = self._load_partials(root, ext)

        dataset = build_dataset_from_cfg(self.dataset_cfg)
        if len(preds) != len(dataset.test):
            # Include counts and the target file: this runs once per
            # (model, dataset) pair, so a bare message is not actionable.
            print(f'length mismatch: {len(preds)} merged predictions vs '
                  f'{len(dataset.test)} dataset samples for {filename}')
            return

        print(f'Merge {partial_filenames} to {filename}')
        # Write to a temporary file and atomically move it into place. A
        # truncated merged file would otherwise be treated as "already
        # merged" by the early return above on every later run.
        tmp_filename = filename + '.tmp'
        try:
            with open(tmp_filename, 'w', encoding='utf-8') as f:
                json.dump(preds, f, indent=4, ensure_ascii=False)
            os.replace(tmp_filename, filename)
        except BaseException:
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise

        # Shards are removed only after the merged file is safely in place.
        if self.cfg.get('clean', False):
            for partial_filename in partial_filenames:
                print(f'Remove {partial_filename}')
                os.remove(partial_filename)

    @staticmethod
    def _load_partials(root, ext):
        """Load consecutive shards ``<root>_<i><ext>`` starting at ``i = 0``.

        Loading stops at the first missing index.

        Returns:
            (preds, partial_filenames): the concatenated predictions keyed by
            consecutive string indices, and the resolved paths of every shard
            that was read.
        """
        preds = {}
        partial_filenames = []
        index = 0
        partial_filename = f'{root}_{index}{ext}'
        while os.path.exists(os.path.realpath(partial_filename)):
            partial_filenames.append(os.path.realpath(partial_filename))
            shard = mmengine.load(partial_filename)
            # Shards are keyed '0'..'n-1'; re-key them after what has
            # already been merged.
            for local_idx in range(len(shard)):
                preds[str(len(preds))] = shard[str(local_idx)]
            index += 1
            partial_filename = f'{root}_{index}{ext}'
        return preds, partial_filenames
|
|
|
|
|
2023-07-20 11:53:24 +08:00
|
|
|
|
|
|
|
def dispatch_tasks(cfg):
    """Run one PredictionMerger for every (model, dataset) combination.

    Args:
        cfg: Mapping providing ``models`` and ``datasets`` (iterables of
            configs), plus ``work_dir`` and ``clean``, which are forwarded
            unchanged to each merger.
    """
    for model_cfg in cfg['models']:
        for dataset_cfg in cfg['datasets']:
            merger_cfg = dict(
                model=model_cfg,
                dataset=dataset_cfg,
                work_dir=cfg['work_dir'],
                clean=cfg['clean'],
            )
            PredictionMerger(merger_cfg).run()
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: merge partitioned predictions for all models/datasets.

    Loads the config named on the command line and resolves ``work_dir``:
    the CLI value overrides the config, with ``./outputs/default`` as the
    fallback. When ``--reuse`` is non-empty, it descends into that run
    sub-directory, where ``'latest'`` selects the lexicographically greatest
    sub-directory. It then dispatches one merge per (model, dataset) pair.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set work_dir: the command-line value takes precedence over the config.
    if args.work_dir is not None:
        cfg['work_dir'] = args.work_dir
    else:
        cfg.setdefault('work_dir', './outputs/default')

    if args.reuse:
        if args.reuse == 'latest':
            # Only sub-directories are candidate runs; a stray file inside
            # work_dir must never be chosen as the "latest" run.
            run_dirs = []
            if os.path.isdir(cfg.work_dir):
                run_dirs = [
                    name for name in os.listdir(cfg.work_dir)
                    if os.path.isdir(os.path.join(cfg.work_dir, name))
                ]
            if not run_dirs:
                print('No previous results to reuse!')
                return
            # Equivalent to sorted(run_dirs)[-1].
            dir_time_str = max(run_dirs)
        else:
            dir_time_str = args.reuse
        cfg['work_dir'] = os.path.join(cfg.work_dir, dir_time_str)

    cfg['clean'] = args.clean

    dispatch_tasks(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
# Run the merger only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|