[Fix] Fix bugs of multiple rounds of inference when using mm_eval (#201)

This commit is contained in:
Yike Yuan 2023-08-16 11:15:11 +08:00 committed by GitHub
parent 4fc1701209
commit 3a46b6c64f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,6 +6,7 @@ import random
import time
from typing import List, Sequence
import mmengine
import torch
import torch.distributed as dist
from mmengine.config import Config, ConfigDict
@ -75,8 +76,8 @@ class MultimodalInferTask:
dataset_name = self.dataloader['dataset']['type']
evaluator_name = self.evaluator[0]['type']
return osp.join(model_name,
f'{dataset_name}-{evaluator_name}.{file_extension}')
return osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{evaluator_name}.{file_extension}')
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the path to the output file.
@ -90,7 +91,7 @@ class MultimodalInferTask:
evaluator_name = self.evaluator[0]['type']
return [
osp.join(model_name, dataset_name,
osp.join(self.cfg.work_dir, model_name, dataset_name,
f'{evaluator_name}.{file_extension}')
]
@ -134,7 +135,8 @@ class MultimodalInferTask:
evaluator.process(data_samples)
metrics = evaluator.evaluate(len(dataloader.dataset))
metrics_file = osp.join(cfg.work_dir, 'res.log')
metrics_file = self.get_output_paths()[0]
mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
with open(metrics_file, 'w') as f:
json.dump(metrics, f)