diff --git a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
index 90a6df61..43ecb801 100644
--- a/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
+++ b/configs/multimodal/minigpt_4/minigpt_4_7b_mmbench.py
@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.minigpt_4 import (
+ MiniGPT4MMBenchPromptConstructor, MiniGPT4PostProcessor)
+
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
@@ -9,8 +12,8 @@ val_pipeline = [
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
- 'question', 'category', 'l2-category', 'context',
- 'index', 'options_dict', 'options', 'split'
+ 'question', 'category', 'l2-category', 'context', 'index',
+ 'options_dict', 'options', 'split'
])
]
@@ -27,11 +30,12 @@ minigpt_4_dataloader = dict(batch_size=1,
# model settings
minigpt_4_model = dict(
type='minigpt-4-mmbench',
- low_resource=True,
- llama_model='/path/to/vicuna',
- sys_prompt= # noqa: E251
- '###Human: What is the capital of China? There are several options:\nA. Beijing\nB. Shanghai\nC. Guangzhou\nD. Shenzhen\n###Assistant: A\n'
-)
+ low_resource=False,
+ llama_model='/path/to/vicuna-7b/',
+ prompt_constructor=dict(type=MiniGPT4MMBenchPromptConstructor,
+ image_prompt='###Human: <Img><ImageHere></Img>',
+ reply_prompt='###Assistant:'),
+ post_processor=dict(type=MiniGPT4PostProcessor))
# evaluation settings
minigpt_4_evaluator = [
@@ -39,4 +43,4 @@ minigpt_4_evaluator = [
save_path='work_dirs/minigpt-4-7b-mmbench.xlsx')
]
-minigpt_4_load_from = '/path/to/minigpt-4' # noqa
+minigpt_4_load_from = '/path/to/prerained_minigpt4_7b.pth' # noqa
diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py
index b8fd75fd..94273b96 100644
--- a/configs/multimodal/tasks.py
+++ b/configs/multimodal/tasks.py
@@ -10,6 +10,6 @@ models = [minigpt_4_model]
datasets = [minigpt_4_dataloader]
evaluators = [minigpt_4_evaluator]
load_froms = [minigpt_4_load_from]
-num_gpus = 1
-num_procs = 1
-launcher = 'slurm'
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch'
diff --git a/opencompass/multimodal/models/minigpt_4/__init__.py b/opencompass/multimodal/models/minigpt_4/__init__.py
index 3104855c..6604c669 100644
--- a/opencompass/multimodal/models/minigpt_4/__init__.py
+++ b/opencompass/multimodal/models/minigpt_4/__init__.py
@@ -1,3 +1,8 @@
from .minigpt_4 import MiniGPT4MMBench
+from .post_processor import MiniGPT4PostProcessor
+from .prompt_constructor import MiniGPT4MMBenchPromptConstructor
-__all__ = ['MiniGPT4MMBench']
+__all__ = [
+ 'MiniGPT4MMBench', 'MiniGPT4PostProcessor',
+ 'MiniGPT4MMBenchPromptConstructor'
+]
diff --git a/opencompass/multimodal/models/minigpt_4/minigpt_4.py b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
index 306fec58..ee4d4c8c 100644
--- a/opencompass/multimodal/models/minigpt_4/minigpt_4.py
+++ b/opencompass/multimodal/models/minigpt_4/minigpt_4.py
@@ -1,7 +1,7 @@
import os
-import re
import sys
+import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
@@ -43,15 +43,16 @@ class MiniGPT4MMBench(MiniGPT4):
Args:
llama_model (str): The path of the vicuna model.
- sys_prompt (str): The prompt added to the beginning
- of each query. Defaults to ''.
+ prompt_constructor (dict): The config of prompt constructor.
+ post_processor (dict): The config of post processor.
low_resource (bool): Whether loaded in low precision.
Defaults to False.
"""
def __init__(self,
llama_model: str,
- sys_prompt: str = '',
+ prompt_constructor: dict,
+ post_processor: dict,
low_resource: bool = False) -> None:
super().__init__(llama_model=llama_model, low_resource=low_resource)
@@ -62,7 +63,10 @@ class MiniGPT4MMBench(MiniGPT4):
]
self.stopping_criteria = StoppingCriteriaList(
[StoppingCriteriaSub(stops=stop_words_ids)])
- self.sys_prompt = sys_prompt
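+ # Build the prompt constructor and post processor from their configs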
+ self.prompt_constructor = mmengine.registry.build_from_cfg(
+ prompt_constructor, MM_MODELS)
+ self.post_processor = mmengine.registry.build_from_cfg(
+ post_processor, MM_MODELS)
def encode_img(self, image):
device = image.device
@@ -96,38 +100,13 @@ class MiniGPT4MMBench(MiniGPT4):
def generate(self, batch):
inputs = self.pack_inputs(batch)
- image = inputs.pop('image')
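+ # Delegate prompt building to the configured prompt constructor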
+ inputs = self.prompt_constructor(inputs)
+ image = inputs['image']
+ prompt = inputs['prompt']
data_samples = inputs['data_samples']
- samples = {'image': image}
- question = [
- data_sample.get('question') for data_sample in data_samples
- ]
- options = [data_sample.get('options') for data_sample in data_samples]
- samples.update({'question': question[0]})
- samples.update({'options': options[0]})
- if data_samples[0].get('context') is not None:
- context = [
- data_sample.get('context') for data_sample in data_samples
- ]
- samples.update({'context': context})
- data_sample = data_samples[0]
- img_prompt = '###Human: <Img><ImageHere></Img> '
- if 'context' in samples:
- context_prompt = samples['context'][0]
- question = samples['question']
- options = samples['options']
- if 'context' in samples:
- prompt = img_prompt + ' ' + context_prompt + ' ' + question + ' ' + options # noqa
- else:
- prompt = img_prompt + ' ' + question + ' ' + options
-
- # prompt = self.sys_prompt + prompt
- prompt = prompt + '###Assistant:'
-
- image = samples['image']
+ # The main process of generation
img_embeds, _ = self.encode_img(image)
-
prompt_segs = prompt.split('<ImageHere>')
prompt_seg_tokens = [
self.llama_tokenizer(seg,
@@ -157,25 +136,10 @@ class MiniGPT4MMBench(MiniGPT4):
stopping_criteria=self.stopping_criteria,
num_return_sequences=1)
- output_token = outputs[0]
- if output_token[0] == 0:
- output_token = output_token[1:]
- if output_token[0] == 1:
- output_token = output_token[1:]
- output_text = self.llama_tokenizer.decode(output_token,
- add_special_tokens=False)
- output_text = self.post_process(output_text)
- data_sample.pred_answer = output_text
- return data_sample
-
- def post_process(self, output_text):
- output_text = output_text.split('###')[0]
- output_text = output_text.split('Assistant:')[-1].strip()
- output_text = output_text.strip('</s><s>')
- output_text = output_text.strip('</Img>')
- output_text = output_text.strip()
- pattern = re.compile(r'([A-Z]\.)')
- res = pattern.findall(output_text)
- if len(res) > 0:
- output_text = res[0][:-1]
- return output_text
+ for i, data_sample in enumerate(data_samples):
+ output_token = outputs[i]
+ output_text = self.post_processor(output_token,
+ self.llama_tokenizer)
+ data_sample.pred_answer = output_text
+ data_samples[i] = data_sample
+ return data_samples
diff --git a/opencompass/multimodal/models/minigpt_4/post_processor.py b/opencompass/multimodal/models/minigpt_4/post_processor.py
new file mode 100644
index 00000000..301a3422
--- /dev/null
+++ b/opencompass/multimodal/models/minigpt_4/post_processor.py
@@ -0,0 +1,34 @@
+import re
+
+import torch
+
+
+class MiniGPT4PostProcessor:
+ """"Post processor for MiniGPT-4 on MMBench."""
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
+
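+ # Drop a leading <unk> (id 0) or <s>/BOS (id 1) token before decoding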
+ if output_token[0] == 0:
+ output_token = output_token[1:]
+ if output_token[0] == 1:
+ output_token = output_token[1:]
+ output_text = tokenizer.decode(output_token,
+ add_special_tokens=False) # noqa
+ output_text = self._extract_key_words(output_text)
+ return output_text
+
+ def _extract_key_words(self, output_text: str) -> str:
+
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+ output_text = output_text.strip('</s><s>')
+ output_text = output_text.strip('</Img>')
+ output_text = output_text.strip()
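+ # If the answer contains an option marker such as 'A.', keep only the letter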
+ pattern = re.compile(r'([A-Z]\.)')
+ res = pattern.findall(output_text)
+ if len(res) > 0:
+ output_text = res[0][:-1]
+ return output_text
diff --git a/opencompass/multimodal/models/minigpt_4/prompt_constructor.py b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
new file mode 100644
index 00000000..de07c1bf
--- /dev/null
+++ b/opencompass/multimodal/models/minigpt_4/prompt_constructor.py
@@ -0,0 +1,55 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class MiniGPT4MMBenchPromptConstructor:
+ """Prompt constructor for MiniGPT-4 on MMBench.
+
+ Args:
+ image_prompt (str): Image prompt.
+ reply_prompt (str): Reply prompt.
+ """
+
+ def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+ self.image_prompt = image_prompt
+ self.reply_prompt = reply_prompt
+
+ def __call__(self, inputs: dict) -> dict:
+ """Construct prompt.
+
+ Args:
+ inputs (dict): Input data containing image and data_samples.
+
+ Returns:
+ dict: A dict containing prompt, images and data_samples.
+ """
+ data_samples = inputs['data_samples']
+ prompt = self._process(data_samples)
+ inputs.update({'prompt': prompt})
+
+ return inputs
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ """Process data sample to prompt.
+
+ Args:
+ data_samples (List[DataSample]): A list of data_samples.
+
+ Returns:
+ str: Prompt.
+ """
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ options = [data_sample.get('options') for data_sample in data_samples]
+ contexts = [data_sample.get('context') for data_sample in data_samples]
+ question = questions[0]
+ option = options[0]
+ context = contexts[0]
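+ # Assemble: image placeholder [+ context] + question + options + reply prompt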
+ if context is not None:
+ prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
+ else:
+ prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
+ return prompt
diff --git a/opencompass/tasks/mm_infer.py b/opencompass/tasks/mm_infer.py
index 5ecc1ab3..46ef4a79 100644
--- a/opencompass/tasks/mm_infer.py
+++ b/opencompass/tasks/mm_infer.py
@@ -4,7 +4,7 @@ import os
import os.path as osp
import random
import time
-from typing import Sequence
+from typing import List, Sequence
import torch
import torch.distributed as dist
@@ -78,6 +78,22 @@ class MultimodalInferTask:
return osp.join(model_name,
f'{dataset_name}-{evaluator_name}.{file_extension}')
+ def get_output_paths(self, file_extension: str = 'json') -> List[str]:
+ """Get the path to the output file.
+
+ Args:
+ file_extension (str): The file extension of the output file.
+ Default: 'json'.
+ """
+ model_name = self.model['type']
+ dataset_name = self.dataloader['dataset']['type']
+ evaluator_name = self.evaluator[0]['type']
+
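+ # One result file per evaluator, saved under <model_name>/<dataset_name>/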
+ return [
+ osp.join(model_name, dataset_name,
+ f'{evaluator_name}.{file_extension}')
+ ]
+
def get_command(self, cfg_path, template):
"""Get the command template for the task.