Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Feat] Support multi-modal evaluation on MME benchmark. (#197)
* [Feat] Support multi-modal evaluation on MME benchmark.
* [Fix] Remove debug code.
* [Fix] Remove redundant codes and add type hints.
* [Fix] Rename in config.
* [Fix] Rebase main.
* [Fix] Fix isort and yapf conflict.
This commit is contained in:
parent 3b29aaee2b
commit a6552224cb
configs/multimodal/minigpt_4/minigpt_4_7b_mme.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from opencompass.multimodal.models.minigpt_4 import (MiniGPT4MMEPostProcessor, MiniGPT4MMEPromptConstructor)

# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.LoadImageFromFile'),
    dict(type='mmpretrain.ToPIL', to_rgb=True),
    dict(type='mmpretrain.torchvision/Resize',
         size=(224, 224),
         interpolation=3),
    dict(type='mmpretrain.torchvision/ToTensor'),
    dict(type='mmpretrain.torchvision/Normalize',
         mean=(0.48145466, 0.4578275, 0.40821073),
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'answer', 'task'
         ])
]

dataset = dict(type='opencompass.MMEDataset',
               data_dir='/path/to/MME',
               pipeline=val_pipeline)

minigpt_4_mme_dataloader = dict(batch_size=1,
                                num_workers=4,
                                dataset=dataset,
                                collate_fn=dict(type='pseudo_collate'),
                                sampler=dict(type='DefaultSampler', shuffle=False))

# model settings
minigpt_4_model = dict(
    type='minigpt-4',
    low_resource=False,
    llama_model='/path/to/vicuna/',
    prompt_constructor=dict(type=MiniGPT4MMEPromptConstructor),
    post_processor=dict(type=MiniGPT4MMEPostProcessor))

# evaluation settings
minigpt_4_mme_evaluator = [
    dict(type='opencompass.MMEMetric')
]

minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth'  # noqa
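For orientation: these variables are the pieces an OpenCompass multimodal task config stitches together (the task wiring itself is not part of this commit). A minimal sketch of inspecting the config with mmengine, assuming OpenCompass and its dependencies are installed:

from mmengine.config import Config

cfg = Config.fromfile('configs/multimodal/minigpt_4/minigpt_4_7b_mme.py')
# Keep batch_size at 1: MiniGPT4MMEPromptConstructor (added below) asserts a
# batch size of 1 when building the prompt.
print(cfg.minigpt_4_mme_dataloader['batch_size'])  # 1
print(cfg.minigpt_4_model['type'])                 # 'minigpt-4'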
opencompass/metrics/__init__.py
@@ -1,4 +1,5 @@
 from .dump_results import DumpResults
+from .mme_score import MMEMetric
 from .seedbench import SEEDBenchAcc

-__all__ = ['DumpResults', 'SEEDBenchAcc']
+__all__ = ['DumpResults', 'SEEDBenchAcc', 'MMEMetric']
opencompass/metrics/mme_score.py (new file, 92 lines)
@@ -0,0 +1,92 @@
from collections import defaultdict
from typing import Optional

from mmengine.evaluator import BaseMetric

from opencompass.registry import METRICS


@METRICS.register_module()
class MMEMetric(BaseMetric):
    """Metric for the MME benchmark.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Default: None.
    """

    task_dict = {
        'Perception': [
            'existence', 'count', 'position', 'color', 'posters', 'celebrity',
            'scene', 'landmark', 'artwork', 'OCR'
        ],
        'Cognition': [
            'commonsense_reasoning', 'numerical_calculation',
            'text_translation', 'code_reasoning'
        ]
    }  # noqa

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device, prefix)

    def process(self, data_batch, data_samples) -> None:
        for data_sample in data_samples:
            result = dict()
            result['img_path'] = data_sample['img_path']
            result['task'] = data_sample['task']
            result['pred'] = 1 if data_sample['answer'].lower(
            ) == data_sample['pred_answer'].lower() else 0
            self.results.append(result)

    def compute_metrics(self, results: list) -> dict:

        # reorganize results
        record = dict()
        for task in (self.task_dict['Perception'] +
                     self.task_dict['Cognition']):
            record[task] = defaultdict(int)
        for sample in results:
            record[sample['task']][sample['img_path']] += sample['pred']

        # compute subtask score
        metric = dict()
        for task in (self.task_dict['Perception'] +
                     self.task_dict['Cognition']):
            single_sum, double_sum = 0., 0.
            for v in record[task].values():
                assert 0 <= v <= 2
                if v == 2:
                    single_sum += 2
                    double_sum += 1
                elif v == 1:
                    single_sum += 1
            acc = single_sum / 2 / len(record[task])
            acc_plus = double_sum / len(record[task])

            metric[task] = {
                'acc': acc,
                'acc_plus': acc_plus,
                'score': 100 * (acc + acc_plus)
            }

        # compute overall score
        score = 0
        for task in self.task_dict['Perception']:
            score += metric[task]['score']
        metric['Perception'] = score

        score = 0
        for task in self.task_dict['Cognition']:
            score += metric[task]['score']
        metric['Cognition'] = score

        metric['Overall'] = metric['Perception'] + metric['Cognition']

        return metric
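For reference, the scoring rule implemented in compute_metrics above works out as follows: every MME image carries two yes/no questions, 'acc' is per-question accuracy, 'acc_plus' only credits images whose two questions are both answered correctly, and a subtask scores 100 * (acc + acc_plus), i.e. at most 200 points per subtask (so at most 2000 over the ten Perception subtasks and 800 over the four Cognition ones). A standalone sketch with invented values:

from collections import defaultdict

# Toy per-question results for one subtask; 'pred' is 1 for a correct answer.
results = [
    {'img_path': 'existence/0001.jpg', 'task': 'existence', 'pred': 1},
    {'img_path': 'existence/0001.jpg', 'task': 'existence', 'pred': 1},  # both right
    {'img_path': 'existence/0002.jpg', 'task': 'existence', 'pred': 1},
    {'img_path': 'existence/0002.jpg', 'task': 'existence', 'pred': 0},  # one wrong
]

record = defaultdict(int)  # correct answers per image (0, 1 or 2)
for sample in results:
    record[sample['img_path']] += sample['pred']

single_sum = sum(record.values())                       # correctly answered questions
double_sum = sum(1 for v in record.values() if v == 2)  # images with both right
acc = single_sum / 2 / len(record)                      # 3 / 2 / 2 = 0.75
acc_plus = double_sum / len(record)                     # 1 / 2 = 0.5
score = 100 * (acc + acc_plus)                          # 125.0 out of 200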
opencompass/multimodal/datasets/__init__.py
@@ -1,4 +1,5 @@
 from .mmbench import MMBenchDataset
+from .mme import MMEDataset
 from .seedbench import SEEDBenchDataset

-__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset', 'MMEDataset']
opencompass/multimodal/datasets/mmbench.py
@@ -40,7 +40,7 @@ class MMBenchDataset(Dataset):
     def __len__(self) -> None:
         return len(self.df)

-    def __getitem__(self, idx: str) -> dict:
+    def __getitem__(self, idx: int) -> dict:
         index = self.df.iloc[idx]['index']
         image = self.df.iloc[idx]['image']
         image = decode_base64_to_image(image)
opencompass/multimodal/datasets/mme.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import os
from typing import List

from mmengine.dataset import Compose
from torch.utils.data import Dataset

from opencompass.registry import DATASETS


@DATASETS.register_module()
class MMEDataset(Dataset):
    """Dataset to load MME dataset.

    Args:
        data_dir (str): The path of the dataset.
        pipeline (List[dict]): The data augmentation.
    """
    tasks = [
        'artwork', 'celebrity', 'code_reasoning', 'color',
        'commonsense_reasoning', 'count', 'existence', 'landmark',
        'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
        'text_translation'
    ]
    sub_dir_name = ('images', 'questions_answers_YN')

    def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
        self.pipeline = Compose(pipeline)
        self.load_data(data_dir)

    def load_data(self, data_dir: str):
        self.data_list = []
        image_dir, question_dir = self.sub_dir_name
        for task in self.tasks:
            if os.path.exists(os.path.join(data_dir, task, question_dir)):
                q_list = os.listdir(os.path.join(data_dir, task, question_dir))
                i_list = os.listdir(os.path.join(data_dir, task, image_dir))
                q_prefix = os.path.join(data_dir, task, question_dir)
                i_prefix = os.path.join(data_dir, task, image_dir)
            else:
                fn_list = os.listdir(os.path.join(data_dir, task))
                q_list = [fn for fn in fn_list if '.txt' in fn]
                i_list = [fn for fn in fn_list if fn not in q_list]
                q_prefix = i_prefix = os.path.join(data_dir, task)

            q_list.sort()
            i_list.sort()
            assert len(q_list) == len(i_list)
            for q_fn, i_fn in zip(q_list, i_list):
                assert q_fn.split('.')[0] == i_fn.split('.')[0]
                q_path = os.path.join(q_prefix, q_fn)
                image_path = os.path.join(i_prefix, i_fn)
                with open(q_path, 'r') as f:
                    q1, a1 = f.readline().strip().split('\t')
                    q2, a2 = f.readline().strip().split('\t')
                self.data_list.append({
                    'img_path': image_path,
                    'question': q1,
                    'answer': a1,
                    'task': task
                })
                self.data_list.append({
                    'img_path': image_path,
                    'question': q2,
                    'answer': a2,
                    'task': task
                })

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> dict:
        data_sample = self.data_list[idx]
        data_sample = self.pipeline(data_sample)
        return data_sample
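The loader above accepts either of two on-disk layouts: a task folder with 'images/' and 'questions_answers_YN/' sub-directories, or a flat task folder mixing image files with same-named .txt files. Each .txt file holds two tab-separated question/answer lines, so every image contributes two samples. A hypothetical usage sketch (the data path is a placeholder):

from opencompass.multimodal.datasets import MMEDataset

# An empty pipeline leaves the raw sample dicts untouched.
dataset = MMEDataset(data_dir='/path/to/MME', pipeline=[])
print(len(dataset))  # two samples per image across the 14 subtasks
print(dataset[0])    # {'img_path': ..., 'question': ..., 'answer': ..., 'task': ...}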
opencompass/multimodal/models/minigpt_4/__init__.py
@@ -1,15 +1,17 @@
 from .minigpt_4 import MiniGPT4Inferencer
 from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
                              MiniGPT4MMBenchPostProcessor,
+                             MiniGPT4MMEPostProcessor,
                              MiniGPT4ScienceQAPostProcessor,
                              MiniGPT4VQAPostProcessor,
                              MiniGPT4VSRPostProcessor)
+from .prompt_constructor import MiniGPT4VSRPromptConstructor  # noqa
 from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
                                  MiniGPT4MMBenchPromptConstructor,
+                                 MiniGPT4MMEPromptConstructor,
                                  MiniGPT4ScienceQAPromptConstructor,
                                  MiniGPT4SEEDBenchPromptConstructor,
-                                 MiniGPT4VQAPromptConstructor,
-                                 MiniGPT4VSRPromptConstructor)
+                                 MiniGPT4VQAPromptConstructor)

 __all__ = [
     'MiniGPT4Inferencer', 'MiniGPT4MMBenchPostProcessor',
@@ -17,5 +19,6 @@ __all__ = [
     'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
     'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
     'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
-    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
+    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor',
+    'MiniGPT4MMEPostProcessor', 'MiniGPT4MMEPromptConstructor'
 ]
opencompass/multimodal/models/minigpt_4/post_processor.py
@@ -119,3 +119,24 @@ class MiniGPT4VSRPostProcessor:
        if len(output_text) > 0:
            output_text = output_text[0].lower()
        return output_text


class MiniGPT4MMEPostProcessor(MiniGPT4MMBenchPostProcessor):
    """Post processor for MiniGPT-4 on MME."""

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
        response = super().__call__(output_token, tokenizer)
        # extract yes or no, copied from the MME official evaluation script
        prefix_pred_ans = response[:4].lower()

        if 'yes' in prefix_pred_ans:
            pred_label = 'yes'
        elif 'no' in prefix_pred_ans:
            pred_label = 'no'
        else:
            pred_label = 'other'

        return pred_label
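The yes/no extraction above, restated as a self-contained function for clarity (the sample replies are invented):

def extract_yes_no(response: str) -> str:
    # Only the first four characters of the reply are inspected,
    # mirroring the MME official evaluation script.
    prefix = response[:4].lower()
    if 'yes' in prefix:
        return 'yes'
    if 'no' in prefix:
        return 'no'
    return 'other'

assert extract_yes_no('Yes, there is a dog.') == 'yes'
assert extract_yes_no('No.') == 'no'
assert extract_yes_no('It is hard to tell.') == 'other'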
opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -7,8 +7,8 @@ class MiniGPT4MMBenchPromptConstructor:
     """Prompt constructor for MiniGPT-4 on MMBench.

     Args:
-        image_prompt (str): Image prompt.
-        reply_prompt (str): Reply prompt.
+        image_prompt (str): Image prompt. Defaults to `''`.
+        reply_prompt (str): Reply prompt. Defaults to `''`.
     """

     def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
@@ -138,3 +138,50 @@ class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
        question = questions[0]
        prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
        return prompt


class MiniGPT4MMEPromptConstructor:
    """Prompt constructor for MiniGPT-4 on MME.

    The system prompt and separator are fixed, so the constructor takes no
    arguments.
    """

    def __init__(self) -> None:
        self.system_prompt = (
            'Give the following image: <Img>ImageContent</Img>.'
            'You will be able to see the image once I provide it to you.'
            'Please answer my questions.')
        self.sep = '###'

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})

        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        prompt = self.system_prompt + self.sep
        prompt += 'Human: ' + question + ' ' + '<Img><ImageHere></Img>' + ' ' + self.sep  # noqa
        prompt += 'Assistant: '
        return prompt
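Put together, _process above produces a prompt of the following shape for one sample (the question text is a made-up example):

question = 'Is there a dog in the image? Please answer yes or no.'
prompt = (
    'Give the following image: <Img>ImageContent</Img>.'
    'You will be able to see the image once I provide it to you.'
    'Please answer my questions.'
    '###'
    'Human: ' + question + ' <Img><ImageHere></Img> ###'
    'Assistant: ')
# MiniGPT4MMEPostProcessor then reduces the model's reply to 'yes', 'no' or 'other'.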