[Refactor] Refactor instructblip (#227)

* refactor instructblip

* add post processor

* add forward

* fix lint

* update

* update
Author: Yixiao Fang, 2023-08-23 15:33:59 +08:00 (committed by GitHub)
parent 02ce139bc6
commit 1034c487ef
6 changed files with 193 additions and 73 deletions

View File

@@ -6,4 +6,44 @@
git clone https://github.com/salesforce/LAVIS.git
cd ./LAVIS
pip install -e .
```
### Modify the config
Modify the InstructBlip config as needed, such as the LLM model path and the Q-Former checkpoint path (a minimal sketch of the relevant keys follows the `tasks.py` snippet below).
Then update `tasks.py` as in the following code snippet.
```python
from mmengine.config import read_base
with read_base():
from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
instruct_blip_evaluator,
instruct_blip_load_from,
instruct_blip_model)
models = [instruct_blip_model]
datasets = [instruct_blip_dataloader]
evaluators = [instruct_blip_evaluator]
load_froms = [instruct_blip_load_from]
num_gpus = 8
num_procs = 8
launcher = 'pytorch' # or 'slurm'
```
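The checkpoint paths are the fields you will most likely need to edit. Below is a minimal sketch of the relevant keys in the config module imported above (`.instructblip.instructblip_mmbench`); the paths are placeholders and should point to your local weights.
```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)

# Excerpt of the model settings in the imported config module;
# only the path-like fields usually need editing.
instruct_blip_model = dict(
    type='blip2-vicuna-instruct',
    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
    freeze_vit=True,
    low_resource=False,
    llm_model='/path/to/vicuna-7b/',  # local Vicuna-7B weights
)

# Trimmed InstructBlip checkpoint to load before evaluation.
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'
```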
### Start evaluation
#### Slurm
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```
#### PyTorch
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```

View File

@@ -1,3 +1,6 @@
from opencompass.multimodal.models.instructblip import (
InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'category', 'l2-category', 'context',
'index', 'options_dict', 'options', 'split'
'question', 'category', 'l2-category', 'context', 'index',
'options_dict', 'options', 'split'
])
]
dataset = dict(type='opencompass.MMBench',
dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
instruct_blip_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler',
shuffle=False))
# model settings
model = dict(
type='blip2-vicuna-instruct-mmbench',
instruct_blip_model = dict(
type='blip2-vicuna-instruct',
prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
post_processor=dict(type=InstructBlipMMBenchPostProcessor),
freeze_vit=True,
low_resource=False,
llm_model='/path/to/vicuna-7b/',
@@ -35,11 +41,11 @@ model = dict(
)
# evaluation settings
evaluator = [
instruct_blip_evaluator = [
dict(
type='opencompass.DumpResults',
save_path= # noqa: E251
'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
]
load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'

View File

@@ -1,3 +1,8 @@
from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
from .blip2_vicuna_instruct import InstructBlipInferencer
from .post_processor import InstructBlipMMBenchPostProcessor
from .prompt_constructor import InstructBlipMMBenchPromptConstructor
__all__ = ['Blip2VicunaInstructMMBench']
__all__ = [
'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
'InstructBlipMMBenchPostProcessor'
]

View File

@@ -1,8 +1,8 @@
"""Requires Transformer 4.28 and above, implementation may change according the
Llama implementation."""
import logging
import re
import mmengine
import torch
import torch.nn as nn
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
@@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
from opencompass.registry import MM_MODELS
@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
class Blip2VicunaInstructMMBench(Blip2Base):
@MM_MODELS.register_module('blip2-vicuna-instruct')
class InstructBlipInferencer(Blip2Base):
def __init__(
self,
vit_model='eva_clip_g',
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision='fp16',
freeze_vit=True,
num_query_token=32,
llm_model='',
sys_prompt='',
prompt='',
max_txt_len=128,
max_output_txt_len=256,
qformer_text_input=True,
low_resource=False,
prompt_constructor: dict,
post_processor: dict,
vit_model: str = 'eva_clip_g',
img_size: int = 224,
drop_path_rate: float = 0,
use_grad_checkpoint: bool = False,
vit_precision: str = 'fp16',
freeze_vit: bool = True,
num_query_token: int = 32,
llm_model: str = '',
sys_prompt: str = '',
prompt: str = '',
max_txt_len: int = 128,
max_output_txt_len: int = 256,
qformer_text_input: bool = True,
low_resource: bool = False,
mode: str = 'generation',
):
super().__init__()
self.mode = mode
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.tokenizer = self.init_tokenizer(truncation_side='left')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
self.qformer_text_input = qformer_text_input
def forward(self, batch):
if self.mode == 'generation':
return self.generate(batch)
else:
raise RuntimeError(f'Invalid mode "{self.mode}".')
def concat_text_input_output(self, input_ids, input_atts, output_ids,
output_atts):
input_part_targets_len = []
@@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
temperature=1,
):
inputs = self.pack_inputs(batch)
image = inputs.pop('image')
inputs = self.prompt_constructor(inputs)
image = inputs['image']
prompt = inputs['prompt']
data_samples = inputs['data_samples']
samples = {'image': image}
questions = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
if data_samples[0].get('context') is not None:
contexts = [
data_sample.get('context') for data_sample in data_samples
]
prompt = [
context + ' ' + question + ' ' + option for context, question,
option in zip(contexts, questions, options)
]
else:
prompt = [
question + ' ' + option
for question, option in zip(questions, options)
]
self.llm_tokenizer.padding_side = 'left'
image = samples['image']
bs = image.size(0)
if isinstance(prompt, str):
@@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
length_penalty=length_penalty,
num_return_sequences=num_captions,
)
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
output_text = self.llm_tokenizer.batch_decode(outputs,
skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
output_text = self.post_process(output_text[0])
data_sample = data_samples[0]
data_sample.pred_answer = output_text
return data_sample
def post_process(self, output_text):
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_text)
if len(res) > 0:
output_text = res[0][:-1]
return output_text
for i, data_sample in enumerate(data_samples):
output_token = outputs[i]
output_text = self.post_processor(output_token, self.llm_tokenizer)
data_sample.pred_answer = output_text
data_samples[i] = data_sample
return data_samples

View File

@@ -0,0 +1,31 @@
import re
import torch
class InstructBlipMMBenchPostProcessor:
""""Post processor for MiniGPT-4 on MMBench."""
def __init__(self) -> None:
pass
def __call__(self, output_token: torch.Tensor, tokenizer) -> str:
# convert output id 0 to 2 (eos_token_id)
output_token[output_token == 0] = 2
output_text = tokenizer.decode(output_token,
add_special_tokens=False) # noqa
output_text = self._extract_key_words(output_text.strip())
return output_text
def _extract_key_words(self, output_text: str) -> str:
output_text = output_text.split('###')[0]
output_text = output_text.split('Assistant:')[-1].strip()
output_text = output_text.strip('</s><s>')
output_text = output_text.strip('</Img>')
output_text = output_text.strip()
pattern = re.compile(r'([A-Z]\.)')
res = pattern.findall(output_text)
if len(res) > 0:
output_text = res[0][:-1]
return output_text
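For illustration, a minimal usage sketch of the extraction step on a hypothetical decoded string (the raw text below is invented; only the import path and the class come from this commit):
```python
from opencompass.multimodal.models.instructblip import \
    InstructBlipMMBenchPostProcessor

post_processor = InstructBlipMMBenchPostProcessor()
# Hypothetical decoded output: keep the text after 'Assistant:', drop the
# trailing special token, then return the first capitalized option letter.
raw = 'Assistant: B. The object in the image is a dog.</s>'
print(post_processor._extract_key_words(raw))  # -> 'B'
```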

View File

@@ -0,0 +1,55 @@
from typing import List
from mmpretrain.structures import DataSample
class InstructBlipMMBenchPromptConstructor:
"""Prompt constructor for InstructBlip on MMBench.
Args:
image_prompt (str): Image prompt.
reply_prompt (str): Reply prompt.
"""
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
self.image_prompt = image_prompt
self.reply_prompt = reply_prompt
def __call__(self, inputs: dict) -> dict:
"""Construct prompt.
Args:
inputs (dict): Input data containing image and data_samples.
Returns:
dict: A dict containing prompt, images and data_samples.
"""
data_samples = inputs['data_samples']
prompt = self._process(data_samples)
inputs.update({'prompt': prompt})
return inputs
def _process(self, data_samples: List[DataSample]) -> str:
"""Process data sample to prompt.
Args:
data_samples (List[DataSample]): A list of data_samples.
Returns:
str: Prompt.
"""
assert len(data_samples) == 1, 'Only support batch size 1.'
questions = [
data_sample.get('question') for data_sample in data_samples
]
options = [data_sample.get('options') for data_sample in data_samples]
contexts = [data_sample.get('context') for data_sample in data_samples]
question = questions[0]
option = options[0]
context = contexts[0]
if context is not None:
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
else:
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
return prompt
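A possible usage sketch of the prompt constructor on a single invented sample (how the `DataSample` fields are populated here is an assumption about the mmpretrain API; the question and options are made up):
```python
from mmpretrain.structures import DataSample

from opencompass.multimodal.models.instructblip import \
    InstructBlipMMBenchPromptConstructor

constructor = InstructBlipMMBenchPromptConstructor(reply_prompt='Answer:')
# One MMBench-style sample; the constructor asserts batch size 1.
sample = DataSample(metainfo=dict(
    question='What animal is shown in the picture?',
    options='A. cat B. dog C. bird D. fish'))
inputs = constructor({'data_samples': [sample]})
print(inputs['prompt'])
# -> ' What animal is shown in the picture? A. cat B. dog C. bird D. fish Answer:'
```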