[Feat] Support multi-modal evaluation on MME benchmark. (#197)

* [Feat] Support multi-modal evaluation on MME benchmark.

* [Fix] Remove debug code.

* [Fix] Remove redundant codes and add type hints.

* [Fix] Rename in config.

* [Fix] Rebase main.

* [Fix] Fix isort and yapf conflict.
Yike Yuan, 2023-08-21 15:53:20 +08:00 (committed by GitHub)
parent 3b29aaee2b
commit a6552224cb
9 changed files with 290 additions and 8 deletions


@@ -0,0 +1,43 @@
from opencompass.multimodal.models.minigpt_4 import (
    MiniGPT4MMEPostProcessor, MiniGPT4MMEPromptConstructor)

# dataloader settings
val_pipeline = [
dict(type='mmpretrain.LoadImageFromFile'),
dict(type='mmpretrain.ToPIL', to_rgb=True),
dict(type='mmpretrain.torchvision/Resize',
size=(224, 224),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'answer', 'task'
])
]
dataset = dict(type='opencompass.MMEDataset',
data_dir='/path/to/MME',
pipeline=val_pipeline)
minigpt_4_mme_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
minigpt_4_model = dict(
type='minigpt-4',
low_resource=False,
llama_model='/path/to/vicuna/',
prompt_constructor=dict(type=MiniGPT4MMEPromptConstructor),
post_processor=dict(type=MiniGPT4MMEPostProcessor))
# evaluation settings
minigpt_4_mme_evaluator = [
dict(type='opencompass.MMEMetric')
]
minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa
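For reference, a minimal usage sketch of how the dicts above map onto real objects. It is illustrative only: the local data path and the MMEDataset import location are assumptions, and OpenCompass normally builds these objects from the config through its registries.

# Illustrative sketch -- './data/MME' and the import path are assumptions.
from mmengine.dataset import DefaultSampler, pseudo_collate
from torch.utils.data import DataLoader

from opencompass.multimodal.datasets import MMEDataset  # assumed import path

# Build the dataset with the same `val_pipeline` as above, then wrap it in a
# PyTorch DataLoader mirroring `minigpt_4_mme_dataloader`.
dataset = MMEDataset(data_dir='./data/MME', pipeline=val_pipeline)
dataloader = DataLoader(dataset,
                        batch_size=1,
                        num_workers=4,
                        collate_fn=pseudo_collate,
                        sampler=DefaultSampler(dataset, shuffle=False))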


@@ -1,4 +1,5 @@
 from .dump_results import DumpResults
+from .mme_score import MMEMetric
 from .seedbench import SEEDBenchAcc

-__all__ = ['DumpResults', 'SEEDBenchAcc']
+__all__ = ['DumpResults', 'SEEDBenchAcc', 'MMEMetric']


@@ -0,0 +1,92 @@
from collections import defaultdict
from typing import Optional
from mmengine.evaluator import BaseMetric
from opencompass.registry import METRICS
@METRICS.register_module()
class MMEMetric(BaseMetric):
"""Dump model's prediction to a file.
Args:
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Default: None.
"""
task_dict = {
'Perception': [
'existence', 'count', 'position', 'color', 'posters', 'celebrity',
'scene', 'landmark', 'artwork', 'OCR'
],
'Cognition': [
'commonsense_reasoning', 'numerical_calculation',
'text_translation', 'code_reasoning'
]
} # noqa
def __init__(self,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device, prefix)
    def process(self, data_batch, data_samples) -> None:
        """Record, for each sample, whether the prediction matches the answer."""
for data_sample in data_samples:
result = dict()
result['img_path'] = data_sample['img_path']
result['task'] = data_sample['task']
result['pred'] = 1 if data_sample['answer'].lower(
) == data_sample['pred_answer'].lower() else 0
self.results.append(result)
    def compute_metrics(self, results: list) -> dict:
        """Aggregate per-sample results into acc, acc+ and score for each task."""
# reorganize results
record = dict()
for task in (self.task_dict['Perception'] +
self.task_dict['Cognition']):
record[task] = defaultdict(int)
for sample in results:
record[sample['task']][sample['img_path']] += sample['pred']
# compute subtask score
metric = dict()
for task in (self.task_dict['Perception'] +
self.task_dict['Cognition']):
single_sum, double_sum = 0., 0.
for v in record[task].values():
assert 0 <= v <= 2
if v == 2:
single_sum += 2
double_sum += 1
elif v == 1:
single_sum += 1
acc = single_sum / 2 / len(record[task])
acc_plus = double_sum / len(record[task])
metric[task] = {
'acc': acc,
'acc_plus': acc_plus,
'score': 100 * (acc + acc_plus)
}
# compute overall score
score = 0
for task in self.task_dict['Perception']:
score += metric[task]['score']
metric['Perception'] = score
score = 0
for task in self.task_dict['Cognition']:
score += metric[task]['score']
metric['Cognition'] = score
metric['Overall'] = metric['Perception'] + metric['Cognition']
return metric
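To make the scoring concrete, a small worked example of the per-task computation above (illustrative numbers only):

# Each MME image carries two yes/no questions; record[img] is the number of
# correct answers (0, 1 or 2) for that image within one sub-task.
record = {'img_a': 2, 'img_b': 1, 'img_c': 0, 'img_d': 2}

single_sum = sum(record.values())                        # 5 correct answers
double_sum = sum(1 for v in record.values() if v == 2)   # 2 fully correct images

acc = single_sum / 2 / len(record)     # 5 / 8 = 0.625  (question-level accuracy)
acc_plus = double_sum / len(record)    # 2 / 4 = 0.5    (image-level accuracy)
score = 100 * (acc + acc_plus)         # 112.5, out of a per-task maximum of 200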


@@ -1,4 +1,5 @@
 from .mmbench import MMBenchDataset
+from .mme import MMEDataset
 from .seedbench import SEEDBenchDataset

-__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset', 'MMEDataset']


@@ -40,7 +40,7 @@ class MMBenchDataset(Dataset):
     def __len__(self) -> None:
         return len(self.df)

-    def __getitem__(self, idx: str) -> dict:
+    def __getitem__(self, idx: int) -> dict:
         index = self.df.iloc[idx]['index']
         image = self.df.iloc[idx]['image']
         image = decode_base64_to_image(image)


@@ -0,0 +1,74 @@
import os
from typing import List
from mmengine.dataset import Compose
from torch.utils.data import Dataset
from opencompass.registry import DATASETS
@DATASETS.register_module()
class MMEDataset(Dataset):
"""Dataset to load MME dataset.
Args:
data_dir (str): The path of the dataset.
        pipeline (List[dict]): The data pre-processing pipeline.
"""
tasks = [
'artwork', 'celebrity', 'code_reasoning', 'color',
'commonsense_reasoning', 'count', 'existence', 'landmark',
'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
'text_translation'
]
sub_dir_name = ('images', 'questions_answers_YN')
def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
self.pipeline = Compose(pipeline)
self.load_data(data_dir)
def load_data(self, data_dir: str):
self.data_list = []
image_dir, question_dir = self.sub_dir_name
for task in self.tasks:
if os.path.exists(os.path.join(data_dir, task, question_dir)):
q_list = os.listdir(os.path.join(data_dir, task, question_dir))
i_list = os.listdir(os.path.join(data_dir, task, image_dir))
q_prefix = os.path.join(data_dir, task, question_dir)
i_prefix = os.path.join(data_dir, task, image_dir)
else:
fn_list = os.listdir(os.path.join(data_dir, task))
q_list = [fn for fn in fn_list if '.txt' in fn]
i_list = [fn for fn in fn_list if fn not in q_list]
q_prefix = i_prefix = os.path.join(data_dir, task)
q_list.sort()
i_list.sort()
assert len(q_list) == len(i_list)
for q_fn, i_fn in zip(q_list, i_list):
assert q_fn.split('.')[0] == i_fn.split('.')[0]
q_path = os.path.join(q_prefix, q_fn)
image_path = os.path.join(i_prefix, i_fn)
with open(q_path, 'r') as f:
q1, a1 = f.readline().strip().split('\t')
q2, a2 = f.readline().strip().split('\t')
self.data_list.append({
'img_path': image_path,
'question': q1,
'answer': a1,
'task': task
})
self.data_list.append({
'img_path': image_path,
'question': q2,
'answer': a2,
'task': task
})
def __len__(self) -> None:
return len(self.data_list)
def __getitem__(self, idx: int) -> dict:
data_sample = self.data_list[idx]
data_sample = self.pipeline(data_sample)
return data_sample
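The loader accepts both layouts used by the official MME release; a sketch of the directory structure load_data expects (file names are illustrative):

# /path/to/MME/
# ├── artwork/
# │   ├── images/                  # e.g. 16006.jpg
# │   └── questions_answers_YN/    # e.g. 16006.txt: two "question<TAB>answer" lines
# └── existence/                   # tasks without sub-dirs keep files side by side
#     ├── 000000006040.jpg
#     └── 000000006040.txt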


@@ -1,15 +1,17 @@
 from .minigpt_4 import MiniGPT4Inferencer
 from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
                              MiniGPT4MMBenchPostProcessor,
+                             MiniGPT4MMEPostProcessor,
                              MiniGPT4ScienceQAPostProcessor,
                              MiniGPT4VQAPostProcessor,
                              MiniGPT4VSRPostProcessor)
+from .prompt_constructor import MiniGPT4VSRPromptConstructor  # noqa
 from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
                                  MiniGPT4MMBenchPromptConstructor,
+                                 MiniGPT4MMEPromptConstructor,
                                  MiniGPT4ScienceQAPromptConstructor,
                                  MiniGPT4SEEDBenchPromptConstructor,
-                                 MiniGPT4VQAPromptConstructor,
-                                 MiniGPT4VSRPromptConstructor)
+                                 MiniGPT4VQAPromptConstructor)

 __all__ = [
     'MiniGPT4Inferencer', 'MiniGPT4MMBenchPostProcessor',
@@ -17,5 +19,6 @@ __all__ = [
     'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
     'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
     'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
-    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
+    'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor',
+    'MiniGPT4MMEPostProcessor', 'MiniGPT4MMEPromptConstructor'
 ]


@@ -119,3 +119,24 @@ class MiniGPT4VSRPostProcessor:
if len(output_text) > 0:
output_text = output_text[0].lower()
return output_text
class MiniGPT4MMEPostProcessor(MiniGPT4MMBenchPostProcessor):
""""Post processor for MiniGPT-4 on MME."""
def __init__(self) -> None:
super().__init__()
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
response = super().__call__(output_token, tokenizer)
# extract yes or no, copy from MME official evaluation script
prefix_pred_ans = response[:4].lower()
if 'yes' in prefix_pred_ans:
pred_label = 'yes'
elif 'no' in prefix_pred_ans:
pred_label = 'no'
else:
pred_label = 'other'
return pred_label
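A few illustrative responses and the label the prefix check above assigns (only the first four characters of the decoded response are inspected):

# 'Yes, there is a dog.'  -> 'yes'
# 'No.'                   -> 'no'
# 'I cannot tell.'        -> 'other'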


@@ -7,8 +7,8 @@ class MiniGPT4MMBenchPromptConstructor:
     """Prompt constructor for MiniGPT-4 on MMBench.

     Args:
-        image_prompt (str): Image prompt.
-        reply_prompt (str): Reply prompt.
+        image_prompt (str): Image prompt. Defaults to `''`.
+        reply_prompt (str): Reply prompt. Defaults to `''`.
     """

     def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
@@ -138,3 +138,50 @@ class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
return prompt
class MiniGPT4MMEPromptConstructor:
    """Prompt constructor for MiniGPT-4 on MME.

    The system prompt and separator are fixed, so no constructor arguments
    are required.
    """
def __init__(self) -> None:
self.system_prompt = (
'Give the following image: <Img>ImageContent</Img>.'
'You will be able to see the image once I provide it to you.'
'Please answer my questions.')
self.sep = '###'
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
question = data_samples[0].get('question')
prompt = self.system_prompt + self.sep
prompt += 'Human: ' + question + ' ' + '<Img><ImageHere></Img>' + ' ' + self.sep # noqa
prompt += 'Assistant: '
return prompt
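For reference, the prompt _process builds for the question "Is there a dog in the image?" is the following single string (wrapped here for readability; note the trailing 'Assistant: '):

# Give the following image: <Img>ImageContent</Img>.You will be able to see the
# image once I provide it to you.Please answer my questions.###Human: Is there
# a dog in the image? <Img><ImageHere></Img> ###Assistant: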