Mirror of https://github.com/open-compass/opencompass.git (synced 2025-05-30 16:03:24 +08:00)
[Refactor] Refactor instructblip (#227)
* refactor instructblip
* add post processor
* add forward
* fix lint
* update
* update
This commit is contained in:
parent 02ce139bc6
commit 1034c487ef
@@ -7,3 +7,43 @@ git clone https://github.com/salesforce/LAVIS.git

```sh
cd ./LAVIS
pip install -e .
```

### Modify the config

Modify the config of InstructBlip, such as the model paths of the LLM and the Q-Former. A minimal sketch of the fields that usually need editing is shown below.
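The full `instructblip_mmbench` config appears later in this commit; the excerpt below is only a sketch with placeholder paths, pulling out the checkpoint-related fields that typically have to point at local weights.

```python
# Sketch only: checkpoint-related fields from the instructblip_mmbench config.
# Replace the placeholder paths with your local weights.
instruct_blip_model = dict(
    type='blip2-vicuna-instruct',
    llm_model='/path/to/vicuna-7b/',  # local Vicuna-7B weights used as the LLM
    freeze_vit=True,
    low_resource=False,
)

# Trimmed InstructBlip checkpoint, loaded via `load_froms` before evaluation.
instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'  # noqa
```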
Then update `tasks.py` like the following code snippet.

```python
from mmengine.config import read_base

with read_base():
    from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
                                                    instruct_blip_evaluator,
                                                    instruct_blip_load_from,
                                                    instruct_blip_model)

models = [instruct_blip_model]
datasets = [instruct_blip_dataloader]
evaluators = [instruct_blip_evaluator]
load_froms = [instruct_blip_load_from]
num_gpus = 8
num_procs = 8
launcher = 'pytorch'  # or 'slurm'
```

### Start evaluation

#### Slurm

```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```

#### PyTorch

```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```
The InstructBlip MMBench config (`instructblip_mmbench.py`) now registers the prompt constructor and post processor and prefixes its top-level variables with `instruct_blip_` so they can be imported from `tasks.py`:

```diff
@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.instructblip import (
+    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
+
 # dataloader settings
 val_pipeline = [
     dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
          std=(0.26862954, 0.26130258, 0.27577711)),
     dict(type='mmpretrain.PackInputs',
          algorithm_keys=[
-             'question', 'category', 'l2-category', 'context',
-             'index', 'options_dict', 'options', 'split'
+             'question', 'category', 'l2-category', 'context', 'index',
+             'options_dict', 'options', 'split'
          ])
 ]

-dataset = dict(type='opencompass.MMBench',
+dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

-dataloader = dict(batch_size=1,
-                  num_workers=4,
-                  dataset=dataset,
-                  collate_fn=dict(type='pseudo_collate'),
-                  sampler=dict(type='DefaultSampler', shuffle=False))
+instruct_blip_dataloader = dict(batch_size=1,
+                                num_workers=4,
+                                dataset=dataset,
+                                collate_fn=dict(type='pseudo_collate'),
+                                sampler=dict(type='DefaultSampler',
+                                             shuffle=False))

 # model settings
-model = dict(
-    type='blip2-vicuna-instruct-mmbench',
+instruct_blip_model = dict(
+    type='blip2-vicuna-instruct',
+    prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor),
+    post_processor=dict(type=InstructBlipMMBenchPostProcessor),
     freeze_vit=True,
     low_resource=False,
     llm_model='/path/to/vicuna-7b/',
@@ -35,11 +41,11 @@ model = dict(
 )

 # evaluation settings
-evaluator = [
+instruct_blip_evaluator = [
     dict(
         type='opencompass.DumpResults',
         save_path=  # noqa: E251
         'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx')
 ]

-load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth'  # noqa
+instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed'
```
`opencompass/multimodal/models/instructblip/__init__.py`:

```diff
@@ -1,3 +1,8 @@
-from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench
+from .blip2_vicuna_instruct import InstructBlipInferencer
+from .post_processor import InstructBlipMMBenchPostProcessor
+from .prompt_constructor import InstructBlipMMBenchPromptConstructor

-__all__ = ['Blip2VicunaInstructMMBench']
+__all__ = [
+    'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor',
+    'InstructBlipMMBenchPostProcessor'
+]
```
`opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py` is refactored: the class is renamed to `InstructBlipInferencer`, registered as `blip2-vicuna-instruct`, and the hard-coded prompt building and answer post-processing are replaced by the configurable prompt constructor and post processor.

```diff
@@ -1,8 +1,8 @@
 """Requires Transformer 4.28 and above, implementation may change according the
 Llama implementation."""
 import logging
-import re

+import mmengine
 import torch
 import torch.nn as nn
 from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
@@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
 from opencompass.registry import MM_MODELS


-@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench')
-class Blip2VicunaInstructMMBench(Blip2Base):
+@MM_MODELS.register_module('blip2-vicuna-instruct')
+class InstructBlipInferencer(Blip2Base):

     def __init__(
         self,
-        vit_model='eva_clip_g',
-        img_size=224,
-        drop_path_rate=0,
-        use_grad_checkpoint=False,
-        vit_precision='fp16',
-        freeze_vit=True,
-        num_query_token=32,
-        llm_model='',
-        sys_prompt='',
-        prompt='',
-        max_txt_len=128,
-        max_output_txt_len=256,
-        qformer_text_input=True,
-        low_resource=False,
+        prompt_constructor: dict,
+        post_processor: dict,
+        vit_model: str = 'eva_clip_g',
+        img_size: int = 224,
+        drop_path_rate: float = 0,
+        use_grad_checkpoint: bool = False,
+        vit_precision: str = 'fp16',
+        freeze_vit: bool = True,
+        num_query_token: int = 32,
+        llm_model: str = '',
+        sys_prompt: str = '',
+        prompt: str = '',
+        max_txt_len: int = 128,
+        max_output_txt_len: int = 256,
+        qformer_text_input: bool = True,
+        low_resource: bool = False,
+        mode: str = 'generation',
     ):
         super().__init__()
+        self.mode = mode
+        self.prompt_constructor = mmengine.registry.build_from_cfg(
+            prompt_constructor, MM_MODELS)
+        self.post_processor = mmengine.registry.build_from_cfg(
+            post_processor, MM_MODELS)

         self.tokenizer = self.init_tokenizer(truncation_side='left')

         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):

         self.qformer_text_input = qformer_text_input

+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.generate(batch)
+        else:
+            raise RuntimeError(f'Invalid mode "{self.mode}".')
+
     def concat_text_input_output(self, input_ids, input_atts, output_ids,
                                  output_atts):
         input_part_targets_len = []
@@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
         temperature=1,
     ):
         inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
         data_samples = inputs['data_samples']
-        samples = {'image': image}
-        questions = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        if data_samples[0].get('context') is not None:
-            contexts = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            prompt = [
-                context + ' ' + question + ' ' + option for context, question,
-                option in zip(contexts, questions, options)
-            ]
-        else:
-            prompt = [
-                question + ' ' + option
-                for question, option in zip(questions, options)
-            ]

         self.llm_tokenizer.padding_side = 'left'

-        image = samples['image']
-
         bs = image.size(0)

         if isinstance(prompt, str):
@@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
             length_penalty=length_penalty,
             num_return_sequences=num_captions,
         )

-        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
-        output_text = self.llm_tokenizer.batch_decode(outputs,
-                                                      skip_special_tokens=True)
-        output_text = [text.strip() for text in output_text]
-        output_text = self.post_process(output_text[0])
-        data_sample = data_samples[0]
-        data_sample.pred_answer = output_text
-        return data_sample
-
-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('</s><s>')
-        output_text = output_text.strip('</Img>')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token, self.llm_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples
```
`opencompass/multimodal/models/instructblip/post_processor.py` (new file, 31 lines):

```python
import re

import torch


class InstructBlipMMBenchPostProcessor:
    """Post processor for InstructBlip on MMBench."""

    def __init__(self) -> None:
        pass

    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
        # convert output id 0 to 2 (eos_token_id)
        output_token[output_token == 0] = 2
        output_text = tokenizer.decode(output_token,
                                       add_special_tokens=False)  # noqa
        output_text = self._extract_key_words(output_text.strip())
        return output_text

    def _extract_key_words(self, output_text: str) -> str:
        output_text = output_text.split('###')[0]
        output_text = output_text.split('Assistant:')[-1].strip()
        output_text = output_text.strip('</s><s>')
        output_text = output_text.strip('</Img>')
        output_text = output_text.strip()
        pattern = re.compile(r'([A-Z]\.)')
        res = pattern.findall(output_text)
        if len(res) > 0:
            output_text = res[0][:-1]
        return output_text
```
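A quick sketch of what the extraction step does; the raw completion string below is made up for illustration and the tokenizer-decoding step is skipped:

```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPostProcessor)

processor = InstructBlipMMBenchPostProcessor()
# A typical raw Vicuna completion: chat markers plus a lettered answer.
raw = '</Img> Assistant: B. The cat is on the table. ### Human:'
print(processor._extract_key_words(raw.strip()))  # prints: B
```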
`opencompass/multimodal/models/instructblip/prompt_constructor.py` (new file, 55 lines):

```python
from typing import List

from mmpretrain.structures import DataSample


class InstructBlipMMBenchPromptConstructor:
    """Prompt constructor for InstructBlip on MMBench.

    Args:
        image_prompt (str): Image prompt.
        reply_prompt (str): Reply prompt.
    """

    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
        self.image_prompt = image_prompt
        self.reply_prompt = reply_prompt

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data containing image and data_samples.

        Returns:
            dict: A dict containing prompt, images and data_samples.
        """
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})

        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list of data_samples.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        questions = [
            data_sample.get('question') for data_sample in data_samples
        ]
        options = [data_sample.get('options') for data_sample in data_samples]
        contexts = [data_sample.get('context') for data_sample in data_samples]
        question = questions[0]
        option = options[0]
        context = contexts[0]
        if context is not None:
            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        else:
            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
        return prompt
```
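A small usage sketch: the question, options, and reply prompt below are made up. The real pipeline passes mmpretrain `DataSample` objects, but any object with a `.get()` method, such as a plain dict, behaves the same here.

```python
from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor)

constructor = InstructBlipMMBenchPromptConstructor(
    reply_prompt='Answer with the option letter.')
sample = dict(question='What is shown in the image?',
              options='A. cat B. dog',
              context=None)
outputs = constructor({'data_samples': [sample]})
# The leading space comes from the empty default image_prompt.
print(outputs['prompt'])
# prints:  What is shown in the image? A. cat B. dog Answer with the option letter.
```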