[Refactor] Refactor instructblip (#227)

* refactor instructblip

* add post processor

* add forward

* fix lint

* update

* update
Authored by Yixiao Fang on 2023-08-23 15:33:59 +08:00; committed by GitHub
parent 02ce139bc6
commit 1034c487ef
6 changed files with 193 additions and 73 deletions

View File

@@ -7,3 +7,43 @@ git clone https://github.com/salesforce/LAVIS.git
cd ./LAVIS
pip install -e .
```
### Modify the config
Modify the config of InstructBlip, such as the model paths of the LLM and the Q-Former.
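For reference, the checkpoint-related fields look roughly like this; this is an excerpt of the MMBench config changed in this PR, assuming it lives at `configs/multimodal/instructblip/instructblip_mmbench.py` as the import in `tasks.py` suggests, and the paths are placeholders for your local checkpoints:

```python
# configs/multimodal/instructblip/instructblip_mmbench.py (assumed location; excerpt)
instruct_blip_model = dict(
    type='blip2-vicuna-instruct',
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',  # local path to the Vicuna-7B weights
    # ... remaining model settings as in the full config ...
)
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'  # trimmed InstructBlip checkpoint
```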
Then update `tasks.py` as in the following code snippet.
```python
from mmengine.config import read_base
with read_base():
    from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
                                                    instruct_blip_evaluator,
                                                    instruct_blip_load_from,
                                                    instruct_blip_model)
models = [instruct_blip_model]
datasets = [instruct_blip_dataloader]
evaluators = [instruct_blip_evaluator]
load_froms = [instruct_blip_load_from]
num_gpus = 8
num_procs = 8
launcher = 'pytorch' # or 'slurm'
```
### Start evaluation
#### Slurm
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```
#### PyTorch
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```

View File

@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.instructblip import (
+    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
+
# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
         std=(0.26862954, 0.26130258, 0.27577711)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
-            'question', 'category', 'l2-category', 'context',
-            'index', 'options_dict', 'options', 'split'
+            'question', 'category', 'l2-category', 'context', 'index',
+            'options_dict', 'options', 'split'
         ])
]

-dataset = dict(type='opencompass.MMBench',
+dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

-dataloader = dict(batch_size=1,
+instruct_blip_dataloader = dict(batch_size=1,
                  num_workers=4,
                  dataset=dataset,
                  collate_fn=dict(type='pseudo_collate'),
-                 sampler=dict(type='DefaultSampler', shuffle=False))
+                 sampler=dict(type='DefaultSampler',
+                              shuffle=False))

# model settings
-model = dict(
-    type='blip2-vicuna-instruct-mmbench',
+instruct_blip_model = dict(
+    type='blip2-vicuna-instruct',
+    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
+    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',
@@ -35,11 +41,11 @@ model = dict(
)

# evaluation settings
-evaluator = [
+instruct_blip_evaluator = [
    dict(
        type='opencompass.DumpResults',
        save_path=  # noqa: E251
        'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
]

-load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'  # noqa
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'

View File

@@ -1,3 +1,8 @@
-from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
+from .blip2_vicuna_instruct import InstructBlipInferencer
+from .post_processor import InstructBlipMMBenchPostProcessor
+from .prompt_constructor import InstructBlipMMBenchPromptConstructor

-__all__ = ['Blip2VicunaInstructMMBench']
+__all__ = [
+    'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
+    'InstructBlipMMBenchPostProcessor'
+]
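For context, these exports are what the MMBench config above pulls in; a minimal sketch of how they are meant to be wired together (restating the config diff above, with the remaining settings elided):

```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPostProcessor, InstructBlipMMBenchPromptConstructor)

# The inferencer is referenced by its registry name rather than imported directly;
# the prompt constructor and post processor are passed as config dicts and built
# by the inferencer via the MM_MODELS registry (see the model diff below).
instruct_blip_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
    # ... vision/LLM settings as in the config above ...
)
```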

View File

@@ -1,8 +1,8 @@
"""Requires Transformer 4.28 and above, implementation may change according the
Llama implementation."""
import logging
-import re

+import mmengine
import torch
import torch.nn as nn
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
@@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
from opencompass.registry import MM_MODELS


-@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
-class Blip2VicunaInstructMMBench(Blip2Base):
+@MM_MODELS.register_module('blip2-vicuna-instruct')
+class InstructBlipInferencer(Blip2Base):

    def __init__(
        self,
-        vit_model='eva_clip_g',
-        img_size=224,
-        drop_path_rate=0,
-        use_grad_checkpoint=False,
-        vit_precision='fp16',
-        freeze_vit=True,
-        num_query_token=32,
-        llm_model='',
-        sys_prompt='',
-        prompt='',
-        max_txt_len=128,
-        max_output_txt_len=256,
-        qformer_text_input=True,
-        low_resource=False,
+        prompt_constructor: dict,
+        post_processor: dict,
+        vit_model: str = 'eva_clip_g',
+        img_size: int = 224,
+        drop_path_rate: float = 0,
+        use_grad_checkpoint: bool = False,
+        vit_precision: str = 'fp16',
+        freeze_vit: bool = True,
+        num_query_token: int = 32,
+        llm_model: str = '',
+        sys_prompt: str = '',
+        prompt: str = '',
+        max_txt_len: int = 128,
+        max_output_txt_len: int = 256,
+        qformer_text_input: bool = True,
+        low_resource: bool = False,
+        mode: str = 'generation',
    ):
        super().__init__()
+        self.mode = mode
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        self.post_processor = mmengine.registry.build_from_cfg(
+            post_processor, MM_MODELS)

        self.tokenizer = self.init_tokenizer(truncation_side='left')

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
        self.qformer_text_input = qformer_text_input

+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.generate(batch)
+        else:
+            raise RuntimeError(f'Invalid mode "{self.mode}".')
+
    def concat_text_input_output(self, input_ids, input_atts, output_ids,
                                 output_atts):
        input_part_targets_len = []
@@ -136,31 +151,13 @@
        temperature=1,
    ):
        inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
-        samples = {'image': image}
-        questions = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        if data_samples[0].get('context') is not None:
-            contexts = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            prompt = [
-                context + ' ' + question + ' ' + option for context, question,
-                option in zip(contexts, questions, options)
-            ]
-        else:
-            prompt = [
-                question + ' ' + option
-                for question, option in zip(questions, options)
-            ]

        self.llm_tokenizer.padding_side = 'left'

-        image = samples['image']
        bs = image.size(0)

        if isinstance(prompt, str):
@@ -237,24 +234,10 @@
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
-        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
-        output_text = self.llm_tokenizer.batch_decode(outputs,
-                                                      skip_special_tokens=True)
-        output_text = [text.strip() for text in output_text]
-        output_text = self.post_process(output_text[0])
-        data_sample = data_samples[0]
-        data_sample.pred_answer = output_text
-        return data_sample
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token, self.llm_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples

-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('</s><s>')
-        output_text = output_text.strip('</Img>')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text

View File

@@ -0,0 +1,31 @@
import re

import torch


class InstructBlipMMBenchPostProcessor:
    """Post processor for InstructBlip on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
        # convert output id 0 to 2 (eos_token_id)
        output_token[output_token == 0] = 2
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = self._extract_key_words(output_text.strip())
        return output_text

    def _extract_key_words(self, output_text: str) -> str:
        output_text = output_text.split('###')[0]
        output_text = output_text.split('Assistant:')[-1].strip()
        output_text = output_text.strip('</s><s>')
        output_text = output_text.strip('</Img>')
        output_text = output_text.strip()
        pattern = re.compile(r'([A-Z]\.)')
        res = pattern.findall(output_text)
        if len(res) > 0:
            output_text = res[0][:-1]
        return output_text
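As a quick illustration of the key-word extraction, here is a hypothetical decoded output (the string is made up, not taken from the PR):

```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPostProcessor)

processor = InstructBlipMMBenchPostProcessor()

# The 'Assistant:' framing and special tokens are stripped, then the leading
# option letter is kept when the answer looks like 'A. ...'.
print(processor._extract_key_words('Assistant: A. cat</s>'))  # prints: A
```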

View File

@@ -0,0 +1,55 @@
from typing import List

from mmpretrain.structures import DataSample


class InstructBlipMMBenchPromptConstructor:
    """Prompt constructor for InstructBlip on MMBench.

    Args:
        image_prompt (str): Image prompt.
        reply_prompt (str): Reply prompt.
    """

    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
        self.image_prompt = image_prompt
        self.reply_prompt = reply_prompt

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})
        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        options = [data_sample.get('options') for data_sample in data_samples]
        contexts = [data_sample.get('context') for data_sample in data_samples]
        question = questions[0]
        option = options[0]
        context = contexts[0]
        if context is not None:
            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        else:
            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        return prompt
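Similarly, a small sanity check of the constructed prompt (the sample values are hypothetical; with the default empty image prompt the pieces are simply joined with spaces):

```python
from mmpretrain.structures import DataSample

from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor)

constructor = InstructBlipMMBenchPromptConstructor(reply_prompt='Answer:')

sample = DataSample()
sample.question = 'What animal is in the picture?'
sample.options = 'A. cat B. dog'

# No 'context' field is set, so the else-branch is taken:
# '<image_prompt> <question> <options> <reply_prompt>'
print(constructor._process([sample]))
# -> ' What animal is in the picture? A. cat B. dog Answer:'
```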