From 1034c487efe3cc0854791c7d6115418f71714fb7 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Wed, 23 Aug 2023 15:33:59 +0800
Subject: [PATCH] [Refactor] Refactor instructblip (#227)

* refactor instructblip

* add post processor

* add forward

* fix lint

* update

* update
---
 configs/multimodal/instructblip/README.md     | 42 +++++++-
 ...lip-mmbench.py => instructblip_mmbench.py} | 30 +++---
 .../models/instructblip/__init__.py           |  9 +-
 .../instructblip/blip2_vicuna_instruct.py     | 99 ++++++++-----------
 .../models/instructblip/post_processor.py     | 31 ++++++
 .../models/instructblip/prompt_constructor.py | 55 +++++++++++
 6 files changed, 193 insertions(+), 73 deletions(-)
 rename configs/multimodal/instructblip/{instructblip-mmbench.py => instructblip_mmbench.py} (51%)
 create mode 100644 opencompass/multimodal/models/instructblip/post_processor.py
 create mode 100644 opencompass/multimodal/models/instructblip/prompt_constructor.py

diff --git a/configs/multimodal/instructblip/README.md b/configs/multimodal/instructblip/README.md
index 1002328d..1b5ea393 100644
--- a/configs/multimodal/instructblip/README.md
+++ b/configs/multimodal/instructblip/README.md
@@ -6,4 +6,44 @@
 git clone https://github.com/salesforce/LAVIS.git
 cd ./LAVIS
 pip install -e .
-```
\ No newline at end of file
+```
+
+### Modify the config
+
+Modify the config of InstructBlip, such as the model paths of the LLM and the Qformer.
+
+Then update `tasks.py` as in the following code snippet.
+
+```python
+from mmengine.config import read_base
+
+with read_base():
+    from .instructblip.instructblip_mmbench import (instruct_blip_dataloader,
+                                                    instruct_blip_evaluator,
+                                                    instruct_blip_load_from,
+                                                    instruct_blip_model)
+
+models = [instruct_blip_model]
+datasets = [instruct_blip_dataloader]
+evaluators = [instruct_blip_evaluator]
+load_froms = [instruct_blip_load_from]
+num_gpus = 8
+num_procs = 8
+launcher = 'pytorch'  # or 'slurm'
+```
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
diff --git a/configs/multimodal/instructblip/instructblip-mmbench.py b/configs/multimodal/instructblip/instructblip_mmbench.py
similarity index 51%
rename from configs/multimodal/instructblip/instructblip-mmbench.py
rename to configs/multimodal/instructblip/instructblip_mmbench.py
index 2ae74009..b7113e69 100644
--- a/configs/multimodal/instructblip/instructblip-mmbench.py
+++ b/configs/multimodal/instructblip/instructblip_mmbench.py
@@ -1,3 +1,6 @@
+from opencompass.multimodal.models.instructblip import (
+    InstructBlipMMBenchPromptConstructor, InstructBlipMMBenchPostProcessor)
+
 # dataloader settings
 val_pipeline = [
     dict(type='mmpretrain.torchvision/Resize',
@@ -9,24 +12,27 @@ val_pipeline = [
          std=(0.26862954, 0.26130258, 0.27577711)),
     dict(type='mmpretrain.PackInputs',
          algorithm_keys=[
-             'question', 'category', 'l2-category', 'context',
-             'index', 'options_dict', 'options', 'split'
+             'question', 'category', 'l2-category', 'context', 'index',
+             'options_dict', 'options', 'split'
         ])
 ]
 
-dataset = dict(type='opencompass.MMBench',
+dataset = dict(type='opencompass.MMBenchDataset',
                data_file='data/mmbench/mmbench_test_20230712.tsv',
                pipeline=val_pipeline)
 
-dataloader = dict(batch_size=1,
-                  num_workers=4,
-                  dataset=dataset,
-                  collate_fn=dict(type='pseudo_collate'),
-                  sampler=dict(type='DefaultSampler', shuffle=False))
+instruct_blip_dataloader = 
dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', + shuffle=False)) # model settings -model = dict( - type='blip2-vicuna-instruct-mmbench', +instruct_blip_model = dict( + type='blip2-vicuna-instruct', + prompt_constructor=dict(type=InstructBlipMMBenchPromptConstructor), + post_processor=dict(type=InstructBlipMMBenchPostProcessor), freeze_vit=True, low_resource=False, llm_model='/path/to/vicuna-7b/', @@ -35,11 +41,11 @@ model = dict( ) # evaluation settings -evaluator = [ +instruct_blip_evaluator = [ dict( type='opencompass.DumpResults', save_path= # noqa: E251 'work_dirs/instructblip_vicuna7b/instructblipvicuna_mmbench.xlsx') ] -load_from = '/path/to/instruct_blip_vicuna7b_trimmed.pth' # noqa +instruct_blip_load_from = '/path/to/instruct_blip_vicuna7b_trimmed' diff --git a/opencompass/multimodal/models/instructblip/__init__.py b/opencompass/multimodal/models/instructblip/__init__.py index 1aa1c98b..af926280 100644 --- a/opencompass/multimodal/models/instructblip/__init__.py +++ b/opencompass/multimodal/models/instructblip/__init__.py @@ -1,3 +1,8 @@ -from .blip2_vicuna_instruct import Blip2VicunaInstructMMBench +from .blip2_vicuna_instruct import InstructBlipInferencer +from .post_processor import InstructBlipMMBenchPostProcessor +from .prompt_constructor import InstructBlipMMBenchPromptConstructor -__all__ = ['Blip2VicunaInstructMMBench'] +__all__ = [ + 'InstructBlipInferencer', 'InstructBlipMMBenchPromptConstructor', + 'InstructBlipMMBenchPostProcessor' +] diff --git a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py index 6595df10..bc08a31d 100644 --- a/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py +++ b/opencompass/multimodal/models/instructblip/blip2_vicuna_instruct.py @@ -1,8 +1,8 @@ """Requires Transformer 4.28 and above, implementation may change according the Llama implementation.""" import logging -import re +import mmengine import torch import torch.nn as nn from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train @@ -12,27 +12,36 @@ from transformers import LlamaForCausalLM, LlamaTokenizer from opencompass.registry import MM_MODELS -@MM_MODELS.register_module('blip2-vicuna-instruct-mmbench') -class Blip2VicunaInstructMMBench(Blip2Base): +@MM_MODELS.register_module('blip2-vicuna-instruct') +class InstructBlipInferencer(Blip2Base): def __init__( self, - vit_model='eva_clip_g', - img_size=224, - drop_path_rate=0, - use_grad_checkpoint=False, - vit_precision='fp16', - freeze_vit=True, - num_query_token=32, - llm_model='', - sys_prompt='', - prompt='', - max_txt_len=128, - max_output_txt_len=256, - qformer_text_input=True, - low_resource=False, + prompt_constructor: dict, + post_processor: dict, + vit_model: str = 'eva_clip_g', + img_size: int = 224, + drop_path_rate: float = 0, + use_grad_checkpoint: bool = False, + vit_precision: str = 'fp16', + freeze_vit: bool = True, + num_query_token: int = 32, + llm_model: str = '', + sys_prompt: str = '', + prompt: str = '', + max_txt_len: int = 128, + max_output_txt_len: int = 256, + qformer_text_input: bool = True, + low_resource: bool = False, + mode: str = 'generation', ): super().__init__() + self.mode = mode + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) + self.tokenizer = 
self.init_tokenizer(truncation_side='left')
 
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -92,6 +101,12 @@ class Blip2VicunaInstructMMBench(Blip2Base):
 
         self.qformer_text_input = qformer_text_input
 
+    def forward(self, batch):
+        if self.mode == 'generation':
+            return self.generate(batch)
+        else:
+            raise RuntimeError(f'Invalid mode "{self.mode}".')
+
     def concat_text_input_output(self, input_ids, input_atts, output_ids,
                                  output_atts):
         input_part_targets_len = []
@@ -136,31 +151,13 @@ class Blip2VicunaInstructMMBench(Blip2Base):
         temperature=1,
     ):
         inputs = self.pack_inputs(batch)
-        image = inputs.pop('image')
+        inputs = self.prompt_constructor(inputs)
+        image = inputs['image']
+        prompt = inputs['prompt']
         data_samples = inputs['data_samples']
-        samples = {'image': image}
-        questions = [
-            data_sample.get('question') for data_sample in data_samples
-        ]
-        options = [data_sample.get('options') for data_sample in data_samples]
-        if data_samples[0].get('context') is not None:
-            contexts = [
-                data_sample.get('context') for data_sample in data_samples
-            ]
-            prompt = [
-                context + ' ' + question + ' ' + option for context, question,
-                option in zip(contexts, questions, options)
-            ]
-        else:
-            prompt = [
-                question + ' ' + option
-                for question, option in zip(questions, options)
-            ]
 
         self.llm_tokenizer.padding_side = 'left'
 
-        image = samples['image']
-
         bs = image.size(0)
 
         if isinstance(prompt, str):
@@ -237,24 +234,10 @@ class Blip2VicunaInstructMMBench(Blip2Base):
             length_penalty=length_penalty,
             num_return_sequences=num_captions,
         )
-        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
-        output_text = self.llm_tokenizer.batch_decode(outputs,
-                                                      skip_special_tokens=True)
-        output_text = [text.strip() for text in output_text]
-        output_text = self.post_process(output_text[0])
-        data_sample = data_samples[0]
-        data_sample.pred_answer = output_text
-        return data_sample
-
-    def post_process(self, output_text):
-        output_text = output_text.split('###')[0]
-        output_text = output_text.split('Assistant:')[-1].strip()
-        output_text = output_text.strip('')
-        output_text = output_text.strip('')
-        output_text = output_text.strip()
-        pattern = re.compile(r'([A-Z]\.)')
-        res = pattern.findall(output_text)
-        if len(res) > 0:
-            output_text = res[0][:-1]
-        return output_text
+        for i, data_sample in enumerate(data_samples):
+            output_token = outputs[i]
+            output_text = self.post_processor(output_token, self.llm_tokenizer)
+            data_sample.pred_answer = output_text
+            data_samples[i] = data_sample
+        return data_samples
diff --git a/opencompass/multimodal/models/instructblip/post_processor.py b/opencompass/multimodal/models/instructblip/post_processor.py
new file mode 100644
index 00000000..0b124a6f
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/post_processor.py
@@ -0,0 +1,31 @@
+import re
+
+import torch
+
+
+class InstructBlipMMBenchPostProcessor:
+    """Post processor for InstructBlip on MMBench."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, output_token: torch.tensor, tokenizer) -> str:
+        # convert output id 0 to 2 (eos_token_id)
+        output_token[output_token == 0] = 2
+        output_text = tokenizer.decode(output_token,
+                                       add_special_tokens=False)  # noqa
+        output_text = self._extract_key_words(output_text.strip())
+        return output_text
+
+    def _extract_key_words(self, output_text: str) -> str:
+
+        output_text = output_text.split('###')[0]
+        output_text = output_text.split('Assistant:')[-1].strip()
+        output_text = output_text.strip('')
+        output_text = output_text.strip('')
+        output_text = output_text.strip()
+        pattern = re.compile(r'([A-Z]\.)')
+        res = pattern.findall(output_text)
+        if len(res) > 0:
+            output_text = res[0][:-1]
+        return output_text
diff --git a/opencompass/multimodal/models/instructblip/prompt_constructor.py b/opencompass/multimodal/models/instructblip/prompt_constructor.py
new file mode 100644
index 00000000..f617e929
--- /dev/null
+++ b/opencompass/multimodal/models/instructblip/prompt_constructor.py
@@ -0,0 +1,55 @@
+from typing import List
+
+from mmpretrain.structures import DataSample
+
+
+class InstructBlipMMBenchPromptConstructor:
+    """Prompt constructor for InstructBlip on MMBench.
+
+    Args:
+        image_prompt (str): Image prompt.
+        reply_prompt (str): Reply prompt.
+    """
+
+    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
+        self.image_prompt = image_prompt
+        self.reply_prompt = reply_prompt
+
+    def __call__(self, inputs: dict) -> dict:
+        """Construct prompt.
+
+        Args:
+            inputs (dict): Input data containing image and data_samples.
+
+        Returns:
+            dict: A dict containing prompt, images and data_samples.
+        """
+        data_samples = inputs['data_samples']
+        prompt = self._process(data_samples)
+        inputs.update({'prompt': prompt})
+
+        return inputs
+
+    def _process(self, data_samples: List[DataSample]) -> str:
+        """Process data sample to prompt.
+
+        Args:
+            data_samples (List[DataSample]): A list of data_samples.
+
+        Returns:
+            str: Prompt.
+        """
+        assert len(data_samples) == 1, 'Only support batch size 1.'
+        questions = [
+            data_sample.get('question') for data_sample in data_samples
+        ]
+        options = [data_sample.get('options') for data_sample in data_samples]
+        contexts = [data_sample.get('context') for data_sample in data_samples]
+        question = questions[0]
+        option = options[0]
+        context = contexts[0]
+        if context is not None:
+            prompt = self.image_prompt + ' ' + context + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        else:
+            prompt = self.image_prompt + ' ' + question + ' ' + option + ' ' + self.reply_prompt  # noqa
+        return prompt
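
For reference, the snippet below is a minimal sketch of how the refactored prompt constructor can be exercised on its own, assuming OpenCompass with its multimodal dependencies (e.g. `mmpretrain`) is installed and this patch is applied. The question and options values are invented purely for illustration.

```python
from mmpretrain.structures import DataSample

from opencompass.multimodal.models.instructblip import (
    InstructBlipMMBenchPromptConstructor)

# Build a single MMBench-style sample; the constructor asserts batch size 1.
# The field values below are hypothetical.
sample = DataSample()
sample.set_metainfo(dict(question='What is shown in the image?',
                         options='A. a cat B. a dog'))

constructor = InstructBlipMMBenchPromptConstructor()
inputs = constructor({'data_samples': [sample]})

# With the default empty image/reply prompts, the prompt is the question
# followed by the options, joined (and padded) by single spaces.
print(inputs['prompt'])
```

Keeping prompt construction (`InstructBlipMMBenchPromptConstructor`) and answer post-processing (`InstructBlipMMBenchPostProcessor`) in standalone classes built from the config lets `InstructBlipInferencer.generate` stay focused on tokenization and generation, and makes dataset-specific behaviour swappable via the `prompt_constructor` and `post_processor` fields.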