From 1034c487efe3cc0854791c7d6115418f71714fb7 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Wed, 23 Aug 2023 15:33:59 +0800
Subject: [PATCH] [Refactor] Refactor instructblip (#227)
* refactor instructblip
* add post processor
* add forward
* fix lint
* update
* update
---
configs/multimodal/instructblip/README.md | 42 +++++++-
...lip-mmbench.py => instructblip_mmbench.py} | 30 +++---
.../models/instructblip/__init__.py | 9 +-
.../instructblip/blip2_vicuna_instruct.py | 99 ++++++++-----------
.../models/instructblip/post_processor.py | 31 ++++++
.../models/instructblip/prompt_constructor.py | 55 +++++++++++
6 files changed, 193 insertions(+), 73 deletions(-)
rename configs/multimodal/instructblip/{instructblip-mmbench.py => instructblip_mmbench.py} (51%)
create mode 100644 opencompass/multimodal/models/instructblip/post_processor.py
create mode 100644 opencompass/multimodal/models/instructblip/prompt_constructor.py
diff --git a/configs/multimodal/instructblip/README.md b/configs/multimodal/instructblip/README.md
index 1002328d..1b5ea393 100644
--- a/configs/multimodal/instructblip/README.md
+++ b/configs/multimodal/instructblip/README.md
@@ -6,4 +6,44 @@
git clone https://github.com/salesforce/LAVIS.git
cd ./LAVIS
pip install -e .
-```
\ No newline at end of file
+```
+
+### Modify the config
+
+Modify the config of InstructBlip, such as the paths to the LLM weights and the pretrained Q-Former checkpoint.
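+
+For reference, these are the path-related fields in `instructblip_mmbench.py`; the snippet below is an excerpt of the shipped config (other model fields are omitted), and the paths are placeholders that must point to your local checkpoints:
+
+```python
+from opencompass.multimodal.models.instructblip import (
+    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
+
+instruct_blip_model = dict(
+    type='blip2-vicuna-instruct',
+    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
+    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
+    freeze_vit=True,
+    low_resource=False,
+    llm_model='/path/to/vicuna-7b/',  # local Vicuna weights for the LLM
+    # ... keep the remaining fields from the shipped config
+)
+# trimmed InstructBlip checkpoint, loaded through `load_froms` in tasks.py
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'
+```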
+
+Then update `tasks.py` as shown in the following code snippet.
+
+```python
+from mmengine.config import read_base
+
+with read_base():
+ from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
+ instruct_blip_evaluator,
+ instruct_blip_load_from,
+ instruct_blip_model)
+
+models = [instruct_blip_model]
+datasets = [instruct_blip_dataloader]
+evaluators = [instruct_blip_evaluator]
+load_froms = [instruct_blip_load_from]
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch' # or 'slurm'
+```
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
diff --git a/configs/multimodal/instructblip/instructblip-mmbench.py b/configs/multimodal/instructblip/instructblip_mmbench.py
similarity index 51%
rename from configs/multimodal/instructblip/instructblip-mmbench.py
rename to configs/multimodal/instructblip/instructblip_mmbench.py
index 2ae74009..b7113e69 100644
--- a/configs/multimodal/instructblip/instructblip-mmbench.py
+++ b/configs/multimodal/instructblip/instructblip_mmbench.py
@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.instructblip import (
+ InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
+
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
- 'question', 'category', 'l2-category', 'context',
- 'index', 'options_dict', 'options', 'split'
+ 'question', 'category', 'l2-category', 'context', 'index',
+ 'options_dict', 'options', 'split'
])
]
-dataset = dict(type='opencompass.MMBench',
+dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
-dataloader = dict(batch_size=1,
- num_workers=4,
- dataset=dataset,
- collate_fn=dict(type='pseudo_collate'),
- sampler=dict(type='DefaultSampler', shuffle=False))
+instruct_blip_dataloader = dict(batch_size=1,
+ num_workers=4,
+ dataset=dataset,
+ collate_fn=dict(type='pseudo_collate'),
+ sampler=dict(type='DefaultSampler',
+ shuffle=False))
# model settings
-model = dict(
- type='blip2-vicuna-instruct-mmbench',
+instruct_blip_model = dict(
+ type='blip2-vicuna-instruct',
+ prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
+ post_processor=dict(type=InstructBlipMMBenchPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
@@ -35,11 +41,11 @@ model = dict(
)
# evaluation settings
-evaluator = [
+instruct_blip_evaluator = [
dict(
type='opencompass.DumpResults',
save_path= # noqa: E251
'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
]
-load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'
diff --git a/opencompass/multimodal/models/instructblip/__init__.py b/opencompass/multimodal/models/instructblip/__init__.py
index 1aa1c98b..af926280 100644
--- a/opencompass/multimodal/models/instructblip/__init__.py
+++ b/opencompass/multimodal/models/instructblip/__init__.py
@@ -1,3 +1,8 @@
-from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
+from .blip2_vicuna_instruct import InstructBlipInferencer
+from .post_processor import InstructBlipMMBenchPostProcessor
+from .prompt_constructor import InstructBlipMMBenchPromptConstructor
-__all__ = ['Blip2VicunaInstructMMBench']
+__all__ = [
+ 'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
+ 'InstructBlipMMBenchPostProcessor'
+]
diff --git a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
index 6595df10..bc08a31d 100644
--- a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
+++ b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py
@@ -1,8 +1,8 @@
"""Requires Transformer 4.28 and above, implementation may change according the
Llama implementation."""
import logging
-import re
+import mmengine
import torch
import torch.nn as nn
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
@@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
from opencompass.registry import MM_MODELS
-@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
-class Blip2VicunaInstructMMBench(Blip2Base):
+@MM_MODELS.register_module('blip2-vicuna-instruct')
+class InstructBlipInferencer(Blip2Base):
def __init__(
self,
- vit_model='eva_clip_g',
- img_size=224,
- drop_path_rate=0,
- use_grad_checkpoint=False,
- vit_precision='fp16',
- freeze_vit=True,
- num_query_token=32,
- llm_model='',
- sys_prompt='',
- prompt='',
- max_txt_len=128,
- max_output_txt_len=256,
- qformer_text_input=True,
- low_resource=False,
+ prompt_constructor: dict,
+ post_processor: dict,
+ vit_model: str = 'eva_clip_g',
+ img_size: int = 224,
+ drop_path_rate: float = 0,
+ use_grad_checkpoint: bool = False,
+ vit_precision: str = 'fp16',
+ freeze_vit: bool = True,
+ num_query_token: int = 32,
+ llm_model: str = '',
+ sys_prompt: str = '',
+ prompt: str = '',
+ max_txt_len: int = 128,
+ max_output_txt_len: int = 256,
+ qformer_text_input: bool = True,
+ low_resource: bool = False,
+ mode: str = 'generation',
):
super().__init__()
+ self.mode = mode
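+        # build the configurable prompt constructor and post processor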
+ self.prompt_constructor = mmengine.registry.build_from_cfg(
+ prompt_constructor, MM_MODELS)
+ self.post_processor = mmengine.registry.build_from_cfg(
+ post_processor, MM_MODELS)
+
self.tokenizer = self.init_tokenizer(truncation_side='left')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
self.qformer_text_input = qformer_text_input
+ def forward(self, batch):
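+        """Dispatch the batch to the configured inference mode."""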
+ if self.mode == 'generation':
+ return self.generate(batch)
+ else:
+ raise RuntimeError(f'Invalid mode "{self.mode}".')
+
def concat_text_input_output(self, input_ids, input_atts, output_ids,
output_atts):
input_part_targets_len = []
@@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
temperature=1,
):
inputs = self.pack_inputs(batch)
- image = inputs.pop('image')
+ inputs = self.prompt_constructor(inputs)
+ image = inputs['image']
+ prompt = inputs['prompt']
data_samples = inputs['data_samples']
- samples = {'image': image}
- questions = [
- data_sample.get('question') for data_sample in data_samples
- ]
- options = [data_sample.get('options') for data_sample in data_samples]
- if data_samples[0].get('context') is not None:
- contexts = [
- data_sample.get('context') for data_sample in data_samples
- ]
- prompt = [
- context + ' ' + question + ' ' + option for context, question,
- option in zip(contexts, questions, options)
- ]
- else:
- prompt = [
- question + ' ' + option
- for question, option in zip(questions, options)
- ]
self.llm_tokenizer.padding_side = 'left'
- image = samples['image']
-
bs = image.size(0)
if isinstance(prompt, str):
@@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
length_penalty=length_penalty,
num_return_sequences=num_captions,
)
- outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
- output_text = self.llm_tokenizer.batch_decode(outputs,
- skip_special_tokens=True)
- output_text = [text.strip() for text in output_text]
- output_text = self.post_process(output_text[0])
- data_sample = data_samples[0]
- data_sample.pred_answer = output_text
- return data_sample
-
- def post_process(self, output_text):
- output_text = output_text.split('###')[0]
- output_text = output_text.split('Assistant:')[-1].strip()
- output_text = output_text.strip('')
- output_text = output_text.strip('')
- output_text = output_text.strip()
- pattern = re.compile(r'([A-Z]\.)')
- res = pattern.findall(output_text)
- if len(res) > 0:
- output_text = res[0][:-1]
- return output_text
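+        # decode each output and attach the prediction to its data sample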
+ for i, data_sample in enumerate(data_samples):
+ output_token = outputs[i]
+ output_text = self.post_processor(output_token, self.llm_tokenizer)
+ data_sample.pred_answer = output_text
+ data_samples[i] = data_sample
+ return data_samples
diff --git a/opencompass/multimodal/models/instructblip/post_processor.py b/opencompass/multimodal/models/instructblip/post_processor.py
new file mode 100644
index 00000000..0b124a6f
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/post_processor.py
@@ -0,0 +1,31 @@
+import re
+
+import torch
+
+
+class InstructBlipMMBenchPostProcessor:
+    """Post processor for InstructBlip on MMBench."""
+
+ def __init__(self) -> None:
+ pass
+
+    def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
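+        """Decode the generated token ids and extract the answer option."""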
+ # convert output id 0 to 2 (eos_token_id)
+ output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       skip_special_tokens=True)
+ output_text = self._extract_key_words(output_text.strip())
+ return output_text
+
+ def _extract_key_words(self, output_text: str) -> str:
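+        """Strip chat markers and extract the option letter if present."""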
+
+ output_text = output_text.split('###')[0]
+ output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip()
+ pattern = re.compile(r'([A-Z]\.)')
+ res = pattern.findall(output_text)
+ if len(res) > 0:
+ output_text = res[0][:-1]
+ return output_text
diff --git a/opencompass/multimodal/models/instructblip/prompt_constructor.py b/opencompass/multimodal/models/instructblip/prompt_constructor.py
new file mode 100644
index 00000000..f617e929
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/prompt_constructor.py
@@ -0,0 +1,55 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class InstructBlipMMBenchPromptConstructor:
+ """Prompt constructor for InstructBlip on MMBench.
+
+ Args:
+ image_prompt (str): Image prompt.
+ reply_prompt (str): Reply prompt.
+ """
+
+ def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+ self.image_prompt = image_prompt
+ self.reply_prompt = reply_prompt
+
+ def __call__(self, inputs: dict) -> dict:
+ """Construct prompt.
+
+ Args:
+ inputs (dict): Input data containing image and data_samples.
+
+ Returns:
+ dict: A dict containing prompt, images and data_samples.
+ """
+ data_samples = inputs['data_samples']
+ prompt = self._process(data_samples)
+ inputs.update({'prompt': prompt})
+
+ return inputs
+
+ def _process(self, data_samples: List[DataSample]) -> str:
+ """Process data sample to prompt.
+
+ Args:
+ data_samples (List[DataSample]): A list of data_samples.
+
+ Returns:
+ str: Prompt.
+ """
+ assert len(data_samples) == 1, 'Only support batch size 1.'
+ questions = [
+ data_sample.get('question') for data_sample in data_samples
+ ]
+ options = [data_sample.get('options') for data_sample in data_samples]
+ contexts = [data_sample.get('context') for data_sample in data_samples]
+ question = questions[0]
+ option = options[0]
+ context = contexts[0]
+ if context is not None:
+ prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
+ else:
+ prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
+ return prompt