mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Refactor] Refactor instructblip (#227)
* refactor instructblip * add post processor * add forward * fix lint * update * update
This commit is contained in:
parent
02ce139bc6
commit
1034c487ef
@ -6,4 +6,44 @@
|
||||
git clone https://github.com/salesforce/LAVIS.git
|
||||
cd ./LAVIS
|
||||
pip install -e .
|
||||
```
|
||||
```
|
||||
|
||||
### Modify the config
|
||||
|
||||
Modify the config of InstructBlip, like model path of LLM and Qformer.
|
||||
|
||||
Then update `tasks.py` like the following code snippet.
|
||||
|
||||
```python
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
|
||||
instruct_blip_evaluator,
|
||||
instruct_blip_load_from,
|
||||
instruct_blip_model)
|
||||
|
||||
models = [instruct_blip_model]
|
||||
datasets = [instruct_blip_dataloader]
|
||||
evaluators = [instruct_blip_evaluator]
|
||||
load_froms = [instruct_blip_load_from]
|
||||
num_gpus = 8
|
||||
num_procs = 8
|
||||
launcher = 'pytorch' # or 'slurm'
|
||||
```
|
||||
|
||||
### Start evaluation
|
||||
|
||||
#### Slurm
|
||||
|
||||
```sh
|
||||
cd $root
|
||||
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
|
||||
```
|
||||
|
||||
#### PyTorch
|
||||
|
||||
```sh
|
||||
cd $root
|
||||
python run.py configs/multimodal/tasks.py --mm-eval
|
||||
```
|
||||
|
@ -1,3 +1,6 @@
|
||||
from opencompass.multimodal.models.instructblip import (
|
||||
InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
|
||||
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
@ -9,24 +12,27 @@ val_pipeline = [
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict', 'options', 'split'
|
||||
'question', 'category', 'l2-category', 'context', 'index',
|
||||
'options_dict', 'options', 'split'
|
||||
])
|
||||
]
|
||||
|
||||
dataset = dict(type='opencompass.MMBench',
|
||||
dataset = dict(type='opencompass.MMBenchDataset',
|
||||
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||
pipeline=val_pipeline)
|
||||
|
||||
dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||
instruct_blip_dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler',
|
||||
shuffle=False))
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='blip2-vicuna-instruct-mmbench',
|
||||
instruct_blip_model = dict(
|
||||
type='blip2-vicuna-instruct',
|
||||
prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
|
||||
post_processor=dict(type=InstructBlipMMBenchPostProcessor),
|
||||
freeze_vit=True,
|
||||
low_resource=False,
|
||||
llm_model='/path/to/vicuna-7b/',
|
||||
@ -35,11 +41,11 @@ model = dict(
|
||||
)
|
||||
|
||||
# evaluation settings
|
||||
evaluator = [
|
||||
instruct_blip_evaluator = [
|
||||
dict(
|
||||
type='opencompass.DumpResults',
|
||||
save_path= # noqa: E251
|
||||
'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
|
||||
]
|
||||
|
||||
load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa
|
||||
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'
|
@ -1,3 +1,8 @@
|
||||
from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
|
||||
from .blip2_vicuna_instruct import InstructBlipInferencer
|
||||
from .post_processor import InstructBlipMMBenchPostProcessor
|
||||
from .prompt_constructor import InstructBlipMMBenchPromptConstructor
|
||||
|
||||
__all__ = ['Blip2VicunaInstructMMBench']
|
||||
__all__ = [
|
||||
'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
|
||||
'InstructBlipMMBenchPostProcessor'
|
||||
]
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""Requires Transformer 4.28 and above, implementation may change according the
|
||||
Llama implementation."""
|
||||
import logging
|
||||
import re
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
|
||||
@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
from opencompass.registry import MM_MODELS
|
||||
|
||||
|
||||
@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
|
||||
class Blip2VicunaInstructMMBench(Blip2Base):
|
||||
@MM_MODELS.register_module('blip2-vicuna-instruct')
|
||||
class InstructBlipInferencer(Blip2Base):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model='eva_clip_g',
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision='fp16',
|
||||
freeze_vit=True,
|
||||
num_query_token=32,
|
||||
llm_model='',
|
||||
sys_prompt='',
|
||||
prompt='',
|
||||
max_txt_len=128,
|
||||
max_output_txt_len=256,
|
||||
qformer_text_input=True,
|
||||
low_resource=False,
|
||||
prompt_constructor: dict,
|
||||
post_processor: dict,
|
||||
vit_model: str = 'eva_clip_g',
|
||||
img_size: int = 224,
|
||||
drop_path_rate: float = 0,
|
||||
use_grad_checkpoint: bool = False,
|
||||
vit_precision: str = 'fp16',
|
||||
freeze_vit: bool = True,
|
||||
num_query_token: int = 32,
|
||||
llm_model: str = '',
|
||||
sys_prompt: str = '',
|
||||
prompt: str = '',
|
||||
max_txt_len: int = 128,
|
||||
max_output_txt_len: int = 256,
|
||||
qformer_text_input: bool = True,
|
||||
low_resource: bool = False,
|
||||
mode: str = 'generation',
|
||||
):
|
||||
super().__init__()
|
||||
self.mode = mode
|
||||
self.prompt_constructor = mmengine.registry.build_from_cfg(
|
||||
prompt_constructor, MM_MODELS)
|
||||
self.post_processor = mmengine.registry.build_from_cfg(
|
||||
post_processor, MM_MODELS)
|
||||
|
||||
self.tokenizer = self.init_tokenizer(truncation_side='left')
|
||||
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
|
||||
|
||||
self.qformer_text_input = qformer_text_input
|
||||
|
||||
def forward(self, batch):
|
||||
if self.mode == 'generation':
|
||||
return self.generate(batch)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{self.mode}".')
|
||||
|
||||
def concat_text_input_output(self, input_ids, input_atts, output_ids,
|
||||
output_atts):
|
||||
input_part_targets_len = []
|
||||
@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
|
||||
temperature=1,
|
||||
):
|
||||
inputs = self.pack_inputs(batch)
|
||||
image = inputs.pop('image')
|
||||
inputs = self.prompt_constructor(inputs)
|
||||
image = inputs['image']
|
||||
prompt = inputs['prompt']
|
||||
data_samples = inputs['data_samples']
|
||||
samples = {'image': image}
|
||||
questions = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
if data_samples[0].get('context') is not None:
|
||||
contexts = [
|
||||
data_sample.get('context') for data_sample in data_samples
|
||||
]
|
||||
prompt = [
|
||||
context + ' ' + question + ' ' + option for context, question,
|
||||
option in zip(contexts, questions, options)
|
||||
]
|
||||
else:
|
||||
prompt = [
|
||||
question + ' ' + option
|
||||
for question, option in zip(questions, options)
|
||||
]
|
||||
|
||||
self.llm_tokenizer.padding_side = 'left'
|
||||
|
||||
image = samples['image']
|
||||
|
||||
bs = image.size(0)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
|
||||
length_penalty=length_penalty,
|
||||
num_return_sequences=num_captions,
|
||||
)
|
||||
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
|
||||
output_text = self.llm_tokenizer.batch_decode(outputs,
|
||||
skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
output_text = self.post_process(output_text[0])
|
||||
data_sample = data_samples[0]
|
||||
data_sample.pred_answer = output_text
|
||||
|
||||
return data_sample
|
||||
|
||||
def post_process(self, output_text):
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
||||
for i, data_sample in enumerate(data_samples):
|
||||
output_token = outputs[i]
|
||||
output_text = self.post_processor(output_token, self.llm_tokenizer)
|
||||
data_sample.pred_answer = output_text
|
||||
data_samples[i] = data_sample
|
||||
return data_samples
|
||||
|
31
opencompass/multimodal/models/instructblip/post_processor.py
Normal file
31
opencompass/multimodal/models/instructblip/post_processor.py
Normal file
@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class InstructBlipMMBenchPostProcessor:
|
||||
""""Post processor for MiniGPT-4 on MMBench."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
|
||||
# convert output id 0 to 2 (eos_token_id)
|
||||
output_token[output_token == 0] = 2
|
||||
output_text = tokenizer.decode(output_token,
|
||||
add_special_tokens=False) # noqa
|
||||
output_text = self._extract_key_words(output_text.strip())
|
||||
return output_text
|
||||
|
||||
def _extract_key_words(self, output_text: str) -> str:
|
||||
|
||||
output_text = output_text.split('###')[0]
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
output_text = output_text.strip('</s><s>')
|
||||
output_text = output_text.strip('</Img>')
|
||||
output_text = output_text.strip()
|
||||
pattern = re.compile(r'([A-Z]\.)')
|
||||
res = pattern.findall(output_text)
|
||||
if len(res) > 0:
|
||||
output_text = res[0][:-1]
|
||||
return output_text
|
@ -0,0 +1,55 @@
|
||||
from typing import List
|
||||
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class InstructBlipMMBenchPromptConstructor:
|
||||
"""Prompt constructor for InstructBlip on MMBench.
|
||||
|
||||
Args:
|
||||
image_prompt (str): Image prompt.
|
||||
reply_prompt (str): Reply prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
|
||||
self.image_prompt = image_prompt
|
||||
self.reply_prompt = reply_prompt
|
||||
|
||||
def __call__(self, inputs: dict) -> dict:
|
||||
"""Construct prompt.
|
||||
|
||||
Args:
|
||||
inputs (dict): Input data containing image and data_samples.
|
||||
|
||||
Returns:
|
||||
dict: A dict containing prompt, images and data_samples.
|
||||
"""
|
||||
data_samples = inputs['data_samples']
|
||||
prompt = self._process(data_samples)
|
||||
inputs.update({'prompt': prompt})
|
||||
|
||||
return inputs
|
||||
|
||||
def _process(self, data_samples: List[DataSample]) -> str:
|
||||
"""Process data sample to prompt.
|
||||
|
||||
Args:
|
||||
data_samples (List[DataSample]): A list of data_samples.
|
||||
|
||||
Returns:
|
||||
str: Prompt.
|
||||
"""
|
||||
assert len(data_samples) == 1, 'Only support batch size 1.'
|
||||
questions = [
|
||||
data_sample.get('question') for data_sample in data_samples
|
||||
]
|
||||
options = [data_sample.get('options') for data_sample in data_samples]
|
||||
contexts = [data_sample.get('context') for data_sample in data_samples]
|
||||
question = questions[0]
|
||||
option = options[0]
|
||||
context = contexts[0]
|
||||
if context is not None:
|
||||
prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
|
||||
else:
|
||||
prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt # noqa
|
||||
return prompt
|
Loading…
Reference in New Issue
Block a user