From 343f785b07a90c4ef4df5621544b210641f56f81 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Thu, 24 Aug 2023 14:11:29 +0800
Subject: [PATCH] [Feature]: Add Flamingo (#258)

* [Feature]: Add Openflamingo MMBench

* [Fix]: Fix import error

* [Fix]: Revert task config

* [Fix]: Fix path bug
---
 configs/multimodal/minigpt_4/README.md        |  2 +-
 configs/multimodal/openflamingo/README.md     | 21 +++++
 .../openflamingo/openflamingo_mmbench.py      | 73 +++++++++++++++++
 configs/multimodal/tasks.py                   |  1 +
 opencompass/multimodal/models/__init__.py     |  7 +-
 .../models/openflamingo/__init__.py           |  3 +
 .../models/openflamingo/openflamingo.py       | 81 +++++++++++++++++++
 7 files changed, 186 insertions(+), 2 deletions(-)
 create mode 100644 configs/multimodal/openflamingo/README.md
 create mode 100644 configs/multimodal/openflamingo/openflamingo_mmbench.py
 create mode 100644 opencompass/multimodal/models/openflamingo/__init__.py
 create mode 100644 opencompass/multimodal/models/openflamingo/openflamingo.py

diff --git a/configs/multimodal/minigpt_4/README.md b/configs/multimodal/minigpt_4/README.md
index aa434d0e..c7a06d34 100644
--- a/configs/multimodal/minigpt_4/README.md
+++ b/configs/multimodal/minigpt_4/README.md
@@ -22,5 +22,5 @@ python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
 
 ```sh
 cd $root
-python run.py configs/multimodal/tasks.py
+python run.py configs/multimodal/tasks.py --mm-eval
 ```
\ No newline at end of file

diff --git a/configs/multimodal/openflamingo/README.md b/configs/multimodal/openflamingo/README.md
new file mode 100644
index 00000000..c8b62736
--- /dev/null
+++ b/configs/multimodal/openflamingo/README.md
@@ -0,0 +1,21 @@
+# OpenFlamingo
+
+### Prepare the environment
+
+Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) following its [installation guide](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation).
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
\ No newline at end of file

diff --git a/configs/multimodal/openflamingo/openflamingo_mmbench.py b/configs/multimodal/openflamingo/openflamingo_mmbench.py
new file mode 100644
index 00000000..8327fb09
--- /dev/null
+++ b/configs/multimodal/openflamingo/openflamingo_mmbench.py
@@ -0,0 +1,73 @@
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.PILToNumpy'),
+    dict(type='mmpretrain.ResizeEdge',
+         scale=224,
+         interpolation='bicubic',
+         backend='pillow'),
+    dict(type='CenterCrop', crop_size=(224, 224)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'options', 'category', 'l2-category', 'index',
+             'context', 'options_dict'
+         ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+openflamingo_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    dataset=dataset,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    collate_fn=dict(type='default_collate'),
+    persistent_workers=True,
+)
+
+# model settings
+openflamingo_model = dict(
+    type='openflamingo',
+    data_preprocessor=dict(
+        type='mmpretrain.MultiModalDataPreprocessor',
+        mean=[122.770938, 116.7460125, 104.09373615],
+        std=[68.5005327, 66.6321579, 70.32316305],
+        to_rgb=True,
+    ),
+    tokenizer=dict(type='mmpretrain.LlamaTokenizer',
+                   name_or_path='decapoda-research/llama-7b-hf'),
+    vision_encoder=dict(
+        type='mmpretrain.VisionTransformer',
+        arch='l',
+        patch_size=14,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained=  # noqa: E251
+        '/path/to/vision/encoder',  # noqa
+    ),
+    lang_encoder=dict(
+        base=dict(type='mmpretrain.AutoModelForCausalLM',
+                  name_or_path='decapoda-research/llama-7b-hf',
+                  local_files_only=True),
+        adapter=dict(type='mmpretrain.FlamingoLMAdapter',
+                     vis_hidden_size=1024,
+                     cross_attn_every_n_layers=4,
+                     use_media_placement_augmentation=False),
+    ),
+    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
+)
+
+# evaluation settings
+openflamingo_evaluator = [
+    dict(
+        type='opencompass.DumpResults',
+        save_path=  # noqa: E251
+        'work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx')
+]
+
+openflamingo_load_from = '/path/to/pretrained/weights'  # noqa
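For reference (not part of the patch): to run this task, `configs/multimodal/tasks.py` has to point at the OpenFlamingo settings defined above. A minimal sketch, mirroring the import style of the existing MiniGPT-4 entries; the `read_base()` usage and the relative module path are assumptions, not confirmed by the diff:

```python
from mmengine.config import read_base

with read_base():
    # Assumed relative module path to the config file added above.
    from .openflamingo.openflamingo_mmbench import (
        openflamingo_dataloader, openflamingo_evaluator,
        openflamingo_load_from, openflamingo_model)

models = [openflamingo_model]
datasets = [openflamingo_dataloader]
evaluators = [openflamingo_evaluator]
load_froms = [openflamingo_load_from]
```

The actual tasks.py diff below keeps the MiniGPT-4 wiring unchanged, which matches the "[Fix]: Revert task config" entry in the commit message.
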
diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py
index e03a1ed2..ef6bd417 100644
--- a/configs/multimodal/tasks.py
+++ b/configs/multimodal/tasks.py
@@ -10,6 +10,7 @@ models = [minigpt_4_mmbench_model]
 datasets = [minigpt_4_mmbench_dataloader]
 evaluators = [minigpt_4_mmbench_evaluator]
 load_froms = [minigpt_4_mmbench_load_from]
+
 num_gpus = 8
 num_procs = 8
 launcher = 'pytorch'
\ No newline at end of file

diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py
index 72465706..b61e20f0 100644
--- a/opencompass/multimodal/models/__init__.py
+++ b/opencompass/multimodal/models/__init__.py
@@ -1,8 +1,13 @@
+import os.path as osp
+
 from opencompass.utils import satisfy_requirement
 
 if satisfy_requirement('salesforce-lavis'):
     from .instructblip import *  # noqa: F401, F403
 
+if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
+    from .minigpt_4 import *  # noqa: F401, F403
+
 from .llava import *  # noqa: F401, F403
-from .minigpt_4 import *  # noqa: F401, F403
+from .openflamingo import *  # noqa: F401, F403
 from .visualglm import *  # noqa: F401, F403

diff --git a/opencompass/multimodal/models/openflamingo/__init__.py b/opencompass/multimodal/models/openflamingo/__init__.py
new file mode 100644
index 00000000..a6707eaf
--- /dev/null
+++ b/opencompass/multimodal/models/openflamingo/__init__.py
@@ -0,0 +1,3 @@
+from .openflamingo import OpenFlamingoInferencer
+
+__all__ = ['OpenFlamingoInferencer']
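Side note (not part of the patch): the package-level import above is what makes the string `type='openflamingo'` in the config resolvable, because importing the subpackage executes the `MM_MODELS.register_module` decorator in the module below. A quick sketch of the lookup, assuming MMEngine's standard `Registry.get` API:

```python
import opencompass.multimodal.models  # noqa: F401  # runs the register_module decorators
from opencompass.registry import MM_MODELS

# The registry maps the config string to the class registered below.
inferencer_cls = MM_MODELS.get('openflamingo')
print(inferencer_cls.__name__)  # OpenFlamingoInferencer
```
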
+ """ + + def __init__(self, + prompt_constructor: Optional[dict] = None, + post_processor: Optional[dict] = None, + mode: str = 'generation', + **kwargs): + super().__init__(**kwargs) + if prompt_constructor is not None: + self.prompt_constructor = mmengine.registry.build_from_cfg( + prompt_constructor, MM_MODELS) + if post_processor is not None: + self.post_processor = mmengine.registry.build_from_cfg( + post_processor, MM_MODELS) + self.mode = mode + + def preprocess_text(self, data_samples: List[DataSample], + device: torch.device) -> List[DataSample]: + """Preprocess text in advance before fed into language model. + + Args: + data_samples (List[DataSample]): The annotation + data of every samples. Defaults to None. + device (torch.device): Device for text to put on. + + Returns: + List[DataSample]: Return list of data samples. + """ + prompts = [] + for sample in data_samples: + question = sample.get('question') + option = sample.get('options') + + prompt = '' + question + ' ' + option + ' ' + 'Answer:' + if data_samples[0].get('context') is not None: + prompt = sample.get('context') + ' ' + prompt + + prompts.append(prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]: + + if self.mode == 'generation': + return self.generate(batch) + else: + raise RuntimeError(f'Unsupported mode: {self.mode}') + + def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]: + batch = self.data_preprocessor(batch, False) + images = batch['images'] + data_samples = batch['data_samples'] + return self.predict(images, data_samples)