diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_mme.py b/configs/multimodal/minigpt_4/minigpt_4_7b_mme.py
new file mode 100644
index 00000000..2824a003
--- /dev/null
+++ b/configs/multimodal/minigpt_4/minigpt_4_7b_mme.py
@@ -0,0 +1,43 @@
+from opencompass.multimodal.models.minigpt_4 import (MiniGPT4MMEPostProcessor, MiniGPT4MMEPromptConstructor)
+
+# dataloader settings
+val_pipeline = [
+ dict(type='mmpretrain.LoadImageFromFile'),
+ dict(type='mmpretrain.ToPIL', to_rgb=True),
+ dict(type='mmpretrain.torchvision/Resize',
+ size=(224, 224),
+ interpolation=3),
+ dict(type='mmpretrain.torchvision/ToTensor'),
+ dict(type='mmpretrain.torchvision/Normalize',
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)),
+ dict(type='mmpretrain.PackInputs',
+ algorithm_keys=[
+ 'question', 'answer', 'task'
+ ])
+]
+
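+# data_dir should point to the root of the MME release,
+# which contains one sub-directory per task.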
+dataset = dict(type='opencompass.MMEDataset',
+ data_dir='/path/to/MME',
+ pipeline=val_pipeline)
+
+minigpt_4_mme_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+minigpt_4_model = dict(
+ type='minigpt-4',
+ low_resource=False,
+ llama_model='/path/to/vicuna/',
+ prompt_constructor=dict(type=MiniGPT4MMEPromptConstructor),
+ post_processor=dict(type=MiniGPT4MMEPostProcessor))
+
+# evaluation settings
+minigpt_4_mme_evaluator = [
+ dict(type='opencompass.MMEMetric')
+]
+
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa
diff --git a/opencompass/metrics/__init__.py b/opencompass/metrics/__init__.py
index 68c04467..ec7b8087 100644
--- a/opencompass/metrics/__init__.py
+++ b/opencompass/metrics/__init__.py
@@ -1,4 +1,5 @@
from .dump_results import DumpResults
+from .mme_score import MMEMetric
from .seedbench import SEEDBenchAcc
-__all__ = ['DumpResults', 'SEEDBenchAcc']
+__all__ = ['DumpResults', 'SEEDBenchAcc', 'MMEMetric']
diff --git a/opencompass/metrics/mme_score.py b/opencompass/metrics/mme_score.py
new file mode 100644
index 00000000..28954828
--- /dev/null
+++ b/opencompass/metrics/mme_score.py
@@ -0,0 +1,92 @@
+from collections import defaultdict
+from typing import Optional
+
+from mmengine.evaluator import BaseMetric
+
+from opencompass.registry import METRICS
+
+
+@METRICS.register_module()
+class MMEMetric(BaseMetric):
+ """Dump model's prediction to a file.
+
+ Args:
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Default: None.
+ """
+
+ task_dict = {
+ 'Perception': [
+ 'existence', 'count', 'position', 'color', 'posters', 'celebrity',
+ 'scene', 'landmark', 'artwork', 'OCR'
+ ],
+ 'Cognition': [
+ 'commonsense_reasoning', 'numerical_calculation',
+ 'text_translation', 'code_reasoning'
+ ]
+ } # noqa
+
+ def __init__(self,
+ collect_device: str = 'cpu',
+ prefix: Optional[str] = None) -> None:
+ super().__init__(collect_device, prefix)
+
+ def process(self, data_batch, data_samples) -> None:
+ for data_sample in data_samples:
+ result = dict()
+ result['img_path'] = data_sample['img_path']
+ result['task'] = data_sample['task']
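+            # 1 if the post-processed yes/no prediction matches the answer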
+ result['pred'] = 1 if data_sample['answer'].lower(
+ ) == data_sample['pred_answer'].lower() else 0
+ self.results.append(result)
+
+ def compute_metrics(self, results: list) -> dict:
+
+ # reorganize results
+ record = dict()
+ for task in (self.task_dict['Perception'] +
+ self.task_dict['Cognition']):
+ record[task] = defaultdict(int)
+ for sample in results:
+ record[sample['task']][sample['img_path']] += sample['pred']
+
+ # compute subtask score
+ metric = dict()
+ for task in (self.task_dict['Perception'] +
+ self.task_dict['Cognition']):
+ single_sum, double_sum = 0., 0.
+ for v in record[task].values():
+ assert 0 <= v <= 2
+ if v == 2:
+ single_sum += 2
+ double_sum += 1
+ elif v == 1:
+ single_sum += 1
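+        # MME pairs two yes/no questions per image: acc is per-question
+        # accuracy, acc_plus requires both answers of an image to be
+        # correct, and the subtask score is 100 * (acc + acc_plus).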
+ acc = single_sum / 2 / len(record[task])
+ acc_plus = double_sum / len(record[task])
+
+ metric[task] = {
+ 'acc': acc,
+ 'acc_plus': acc_plus,
+ 'score': 100 * (acc + acc_plus)
+ }
+
+ # compute overall score
+ score = 0
+ for task in self.task_dict['Perception']:
+ score += metric[task]['score']
+ metric['Perception'] = score
+
+ score = 0
+ for task in self.task_dict['Cognition']:
+ score += metric[task]['score']
+ metric['Cognition'] = score
+
+ metric['Overall'] = metric['Perception'] + metric['Cognition']
+
+ return metric
diff --git a/opencompass/multimodal/datasets/__init__.py b/opencompass/multimodal/datasets/__init__.py
index c9374c47..dcb96607 100644
--- a/opencompass/multimodal/datasets/__init__.py
+++ b/opencompass/multimodal/datasets/__init__.py
@@ -1,4 +1,5 @@
from .mmbench import MMBenchDataset
+from .mme import MMEDataset
from .seedbench import SEEDBenchDataset
-__all__ = ['MMBenchDataset', 'SEEDBenchDataset']
+__all__ = ['MMBenchDataset', 'SEEDBenchDataset', 'MMEDataset']
diff --git a/opencompass/multimodal/datasets/mmbench.py b/opencompass/multimodal/datasets/mmbench.py
index a1ba4c58..aa2fb5c3 100644
--- a/opencompass/multimodal/datasets/mmbench.py
+++ b/opencompass/multimodal/datasets/mmbench.py
@@ -40,7 +40,7 @@ class MMBenchDataset(Dataset):
def __len__(self) -> None:
return len(self.df)
- def __getitem__(self, idx: str) -> dict:
+ def __getitem__(self, idx: int) -> dict:
index = self.df.iloc[idx]['index']
image = self.df.iloc[idx]['image']
image = decode_base64_to_image(image)
diff --git a/opencompass/multimodal/datasets/mme.py b/opencompass/multimodal/datasets/mme.py
new file mode 100644
index 00000000..c5105175
--- /dev/null
+++ b/opencompass/multimodal/datasets/mme.py
@@ -0,0 +1,74 @@
+import os
+from typing import List
+
+from mmengine.dataset import Compose
+from torch.utils.data import Dataset
+
+from opencompass.registry import DATASETS
+
+
+@DATASETS.register_module()
+class MMEDataset(Dataset):
+ """Dataset to load MME dataset.
+
+ Args:
+ data_dir (str): The path of the dataset.
+ pipeline (List[dict]): The data augmentation.
+ """
+ tasks = [
+ 'artwork', 'celebrity', 'code_reasoning', 'color',
+ 'commonsense_reasoning', 'count', 'existence', 'landmark',
+ 'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
+ 'text_translation'
+ ]
+ sub_dir_name = ('images', 'questions_answers_YN')
+
+ def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
+ self.pipeline = Compose(pipeline)
+ self.load_data(data_dir)
+
+ def load_data(self, data_dir: str):
+ self.data_list = []
+ image_dir, question_dir = self.sub_dir_name
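+        # Some subtasks ship images and question files in separate
+        # sub-directories; the others keep the .txt files next to the images.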
+ for task in self.tasks:
+ if os.path.exists(os.path.join(data_dir, task, question_dir)):
+ q_list = os.listdir(os.path.join(data_dir, task, question_dir))
+ i_list = os.listdir(os.path.join(data_dir, task, image_dir))
+ q_prefix = os.path.join(data_dir, task, question_dir)
+ i_prefix = os.path.join(data_dir, task, image_dir)
+ else:
+ fn_list = os.listdir(os.path.join(data_dir, task))
+ q_list = [fn for fn in fn_list if '.txt' in fn]
+ i_list = [fn for fn in fn_list if fn not in q_list]
+ q_prefix = i_prefix = os.path.join(data_dir, task)
+
+ q_list.sort()
+ i_list.sort()
+ assert len(q_list) == len(i_list)
+ for q_fn, i_fn in zip(q_list, i_list):
+ assert q_fn.split('.')[0] == i_fn.split('.')[0]
+ q_path = os.path.join(q_prefix, q_fn)
+ image_path = os.path.join(i_prefix, i_fn)
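+                # each question file holds two tab-separated question/answer lines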
+ with open(q_path, 'r') as f:
+ q1, a1 = f.readline().strip().split('\t')
+ q2, a2 = f.readline().strip().split('\t')
+ self.data_list.append({
+ 'img_path': image_path,
+ 'question': q1,
+ 'answer': a1,
+ 'task': task
+ })
+ self.data_list.append({
+ 'img_path': image_path,
+ 'question': q2,
+ 'answer': a2,
+ 'task': task
+ })
+
+    def __len__(self) -> int:
+ return len(self.data_list)
+
+ def __getitem__(self, idx: int) -> dict:
+ data_sample = self.data_list[idx]
+ data_sample = self.pipeline(data_sample)
+ return data_sample
diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py
index 715eb01e..20082111 100644
--- a/opencompass/multimodal/models/minigpt_4/__init__.py
+++ b/opencompass/multimodal/models/minigpt_4/__init__.py
@@ -1,15 +1,17 @@
from .minigpt_4 import MiniGPT4Inferencer
from .post_processor import (MiniGPT4COCOCaptionPostProcessor,
MiniGPT4MMBenchPostProcessor,
+ MiniGPT4MMEPostProcessor,
MiniGPT4ScienceQAPostProcessor,
MiniGPT4VQAPostProcessor,
MiniGPT4VSRPostProcessor)
+from .prompt_constructor import MiniGPT4VSRPromptConstructor # noqa
from .prompt_constructor import (MiniGPT4COCOCaotionPromptConstructor,
MiniGPT4MMBenchPromptConstructor,
+ MiniGPT4MMEPromptConstructor,
MiniGPT4ScienceQAPromptConstructor,
MiniGPT4SEEDBenchPromptConstructor,
- MiniGPT4VQAPromptConstructor,
- MiniGPT4VSRPromptConstructor)
+ MiniGPT4VQAPromptConstructor)
__all__ = [
'MiniGPT4Inferencer', 'MiniGPT4MMBenchPostProcessor',
@@ -17,5 +19,6 @@ __all__ = [
'MiniGPT4COCOCaptionPostProcessor', 'MiniGPT4ScienceQAPromptConstructor',
'MiniGPT4ScienceQAPostProcessor', 'MiniGPT4VQAPromptConstructor',
'MiniGPT4VQAPostProcessor', 'MiniGPT4VSRPostProcessor',
- 'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor'
+ 'MiniGPT4VSRPromptConstructor', 'MiniGPT4SEEDBenchPromptConstructor',
+ 'MiniGPT4MMEPostProcessor', 'MiniGPT4MMEPromptConstructor'
]
diff --git a/opencompass/multimodal/models/minigpt_4/post_processor.py b/opencompass/multimodal/models/minigpt_4/post_processor.py
index 85d1f83f..b1f3428e 100644
--- a/opencompass/multimodal/models/minigpt_4/post_processor.py
+++ b/opencompass/multimodal/models/minigpt_4/post_processor.py
@@ -119,3 +119,24 @@ class MiniGPT4VSRPostProcessor:
if len(output_text) > 0:
output_text = output_text[0].lower()
return output_text
+
+
+class MiniGPT4MMEPostProcessor(MiniGPT4MMBenchPostProcessor):
+ """"Post processor for MiniGPT-4 on MME."""
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+ response = super().__call__(output_token, tokenizer)
+ # extract yes or no, copy from MME official evaluation script
+ prefix_pred_ans = response[:4].lower()
+
+ if 'yes' in prefix_pred_ans:
+ pred_label = 'yes'
+ elif 'no' in prefix_pred_ans:
+ pred_label = 'no'
+ else:
+ pred_label = 'other'
+
+ return pred_label
diff --git a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
index 55c8300e..f6d8604d 100644
--- a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
+++ b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -7,8 +7,8 @@ class MiniGPT4MMBenchPromptConstructor:
"""Prompt constructor for MiniGPT-4 on MMBench.
Args:
- image_prompt (str): Image prompt.
- reply_prompt (str): Reply prompt.
+ image_prompt (str): Image prompt. Defaults to `''`.
+ reply_prompt (str): Reply prompt. Defaults to `''`.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
@@ -138,3 +138,50 @@ class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
question = questions[0]
prompt = self.image_prompt + ' ' + question + ' ' + self.reply_prompt
return prompt
+
+
+class MiniGPT4MMEPromptConstructor:
+ """Prompt constructor for MiniGPT-4 on MME.
+
+ Args:
+ image_prompt (str): Image prompt. Defaults to `''`.
+ reply_prompt (str): Reply prompt. Defaults to `''`.
+ """
+
+ def __init__(self) -> None:
+ self.system_prompt = (
+            'Give the following image: <Img>ImageContent</Img>.'
+ 'You will be able to see the image once I provide it to you.'
+ 'Please answer my questions.')
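+        # '###' separates conversation turns in MiniGPT-4's Vicuna-style template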
+ self.sep = '###'
+
+ def __call__(self, inputs: dict) -> dict:
+ """Construct prompt.
+
+ Args:
+ inputs (dict): Input data containing image and data_samples.
+
+ Returns:
+ dict: A dict containing prompt, images and data_samples.
+ """
+ data_samples = inputs['data_samples']
+ prompt = self._process(data_samples)
+ inputs.update({'prompt': prompt})
+
+ return inputs
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ """Process data sample to prompt.
+
+ Args:
+ data_samples (List[DataSample]): A list of data_samples.
+
+ Returns:
+ str: Prompt.
+ """
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ question = data_samples[0].get('question')
+ prompt = self.system_prompt + self.sep
+        prompt += 'Human: ' + question + ' ' + '<Img><ImageHere></Img>' + ' ' + self.sep  # noqa
+ prompt += 'Assistant: '
+ return prompt