mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Fix] Fix bugs of multiple rounds of inference when using mm_eval (#201)
This commit is contained in:
parent
4fc1701209
commit
3a46b6c64f
@ -6,6 +6,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
from typing import List, Sequence
|
from typing import List, Sequence
|
||||||
|
|
||||||
|
import mmengine
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from mmengine.config import Config, ConfigDict
|
from mmengine.config import Config, ConfigDict
|
||||||
@ -75,8 +76,8 @@ class MultimodalInferTask:
|
|||||||
dataset_name = self.dataloader['dataset']['type']
|
dataset_name = self.dataloader['dataset']['type']
|
||||||
evaluator_name = self.evaluator[0]['type']
|
evaluator_name = self.evaluator[0]['type']
|
||||||
|
|
||||||
return osp.join(model_name,
|
return osp.join(self.cfg.work_dir, model_name, dataset_name,
|
||||||
f'{dataset_name}-{evaluator_name}.{file_extension}')
|
f'{evaluator_name}.{file_extension}')
|
||||||
|
|
||||||
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
def get_output_paths(self, file_extension: str = 'json') -> List[str]:
|
||||||
"""Get the path to the output file.
|
"""Get the path to the output file.
|
||||||
@ -90,7 +91,7 @@ class MultimodalInferTask:
|
|||||||
evaluator_name = self.evaluator[0]['type']
|
evaluator_name = self.evaluator[0]['type']
|
||||||
|
|
||||||
return [
|
return [
|
||||||
osp.join(model_name, dataset_name,
|
osp.join(self.cfg.work_dir, model_name, dataset_name,
|
||||||
f'{evaluator_name}.{file_extension}')
|
f'{evaluator_name}.{file_extension}')
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -134,7 +135,8 @@ class MultimodalInferTask:
|
|||||||
evaluator.process(data_samples)
|
evaluator.process(data_samples)
|
||||||
|
|
||||||
metrics = evaluator.evaluate(len(dataloader.dataset))
|
metrics = evaluator.evaluate(len(dataloader.dataset))
|
||||||
metrics_file = osp.join(cfg.work_dir, 'res.log')
|
metrics_file = self.get_output_paths()[0]
|
||||||
|
mmengine.mkdir_or_exist(osp.split(metrics_file)[0])
|
||||||
with open(metrics_file, 'w') as f:
|
with open(metrics_file, 'w') as f:
|
||||||
json.dump(metrics, f)
|
json.dump(metrics, f)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user