[Feature]: Add Flamingo (#258)

* [Feature]: Add Openflamingo MMBench

* [Fix]: Fix import error

* [Fix]: Revert task config

* [Fix]: Fix path bug
Yuan Liu 2023-08-24 14:11:29 +08:00 committed by GitHub
parent 77745a84ea
commit 343f785b07
7 changed files with 186 additions and 2 deletions

View File

@@ -22,5 +22,5 @@ python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```sh
cd $root
python run.py configs/multimodal/tasks.py
python run.py configs/multimodal/tasks.py --mm-eval
```

View File

@@ -0,0 +1,21 @@
# OpenFlamingo
### Prepare the environment
Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) following the installation [doc](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation).
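A typical source install, following that doc, looks roughly like the sketch below; the `[multimodal]` extra is an assumption here (the Flamingo model lives in MMPretrain's multimodal package), so defer to the linked doc if it differs.
```sh
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim
# the [multimodal] extra is assumed to pull in the Flamingo dependencies
mim install -e ".[multimodal]"
```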
### Start evaluation
#### Slurm
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```
#### PyTorch
```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```

View File

@@ -0,0 +1,73 @@
# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.PILToNumpy'),
    dict(type='mmpretrain.ResizeEdge',
         scale=224,
         interpolation='bicubic',
         backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'options', 'category', 'l2-category', 'index',
             'context', 'options_dict'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

openflamingo_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    persistent_workers=True,
)

# model settings
openflamingo_model = dict(
    type='openflamingo',
    data_preprocessor=dict(
        type='mmpretrain.MultiModalDataPreprocessor',
        mean=[122.770938, 116.7460125, 104.09373615],
        std=[68.5005327, 66.6321579, 70.32316305],
        to_rgb=True,
    ),
    tokenizer=dict(type='mmpretrain.LlamaTokenizer',
                   name_or_path='decapoda-research/llama-7b-hf'),
    vision_encoder=dict(
        type='mmpretrain.VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained=  # noqa: E251
        '/path/to/vision/encoder',  # noqa
    ),
    lang_encoder=dict(
        base=dict(type='mmpretrain.AutoModelForCausalLM',
                  name_or_path='decapoda-research/llama-7b-hf',
                  local_files_only=True),
        adapter=dict(type='mmpretrain.FlamingoLMAdapter',
                     vis_hidden_size=1024,
                     cross_attn_every_n_layers=4,
                     use_media_placement_augmentation=False),
    ),
    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
)

# evaluation settings
openflamingo_evaluator = [
    dict(
        type='opencompass.DumpResults',
        save_path=  # noqa: E251
        'work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx')
]

openflamingo_load_from = '/path/to/pretrained/weights'  # noqa

View File

@@ -10,6 +10,7 @@ models = [minigpt_4_mmbench_model]
datasets = [minigpt_4_mmbench_dataloader]
evaluators = [minigpt_4_mmbench_evaluator]
load_froms = [minigpt_4_mmbench_load_from]
num_gpus = 8
num_procs = 8
launcher = 'pytorch'
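The hunk above leaves MiniGPT-4 wired in as the evaluated model. Pointing the task at the new OpenFlamingo config follows the same pattern; a minimal sketch, assuming the config shown earlier is saved as `configs/multimodal/openflamingo/openflamingo_mmbench.py` and that `tasks.py` pulls configs in via `read_base()`:

```python
from mmengine.config import read_base

with read_base():
    # assumed location of the OpenFlamingo MMBench config shown above
    from .openflamingo.openflamingo_mmbench import (openflamingo_dataloader,
                                                    openflamingo_evaluator,
                                                    openflamingo_load_from,
                                                    openflamingo_model)

models = [openflamingo_model]
datasets = [openflamingo_dataloader]
evaluators = [openflamingo_evaluator]
load_froms = [openflamingo_load_from]

num_gpus = 8
num_procs = 8
launcher = 'pytorch'
```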

View File

@@ -1,8 +1,13 @@
import os.path as osp
from opencompass.utils import satisfy_requirement

if satisfy_requirement('salesforce-lavis'):
    from .instructblip import *  # noqa: F401, F403
if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
    from .minigpt_4 import *  # noqa: F401, F403
from .llava import *  # noqa: F401, F403
from .minigpt_4 import *  # noqa: F401, F403
from .openflamingo import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403

View File

@@ -0,0 +1,3 @@
from .openflamingo import OpenFlamingoInferencer

__all__ = ['OpenFlamingoInferencer']
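Re-exporting `OpenFlamingoInferencer` here is what lets the `MM_MODELS` registry resolve the `type='openflamingo'` key used in the config above. A minimal sketch of that lookup (illustrative only; the task runner does this internally, and actually building the model requires the real checkpoint paths):

```python
from mmengine.config import Config

# Importing the models package runs the registration decorators.
import opencompass.multimodal.models  # noqa: F401
from opencompass.registry import MM_MODELS

# Assumed config path for the file shown earlier in this diff.
cfg = Config.fromfile('configs/multimodal/openflamingo/openflamingo_mmbench.py')
model = MM_MODELS.build(cfg.openflamingo_model)  # -> OpenFlamingoInferencer
```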

View File

@@ -0,0 +1,81 @@
from typing import List, Optional, Union

import mmengine
import torch
from mmpretrain.models.multimodal import Flamingo
from mmpretrain.structures import DataSample

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('openflamingo')
class OpenFlamingoInferencer(Flamingo):
"""Inference code of OpenFlamingo.
Args:
prompt_constructor (optional, dict): The config of prompt constructor.
Defaults to None.
post_processor (optional, dict): The config of post processor.
Defaults to None.
mode (str): The mode of inference. Defaults to 'generation'.
"""
def __init__(self,
prompt_constructor: Optional[dict] = None,
post_processor: Optional[dict] = None,
mode: str = 'generation',
**kwargs):
super().__init__(**kwargs)
if prompt_constructor is not None:
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.mode = mode
def preprocess_text(self, data_samples: List[DataSample],
device: torch.device) -> List[DataSample]:
"""Preprocess text in advance before fed into language model.
Args:
data_samples (List[DataSample]): The annotation
data of every samples. Defaults to None.
device (torch.device): Device for text to put on.
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = []
for sample in data_samples:
question = sample.get('question')
option = sample.get('options')
prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
if data_samples[0].get('context') is not None:
prompt = sample.get('context') + ' ' + prompt
prompts.append(prompt)
self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
return input_text
def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]:
if self.mode == 'generation':
return self.generate(batch)
else:
raise RuntimeError(f'Unsupported mode: {self.mode}')
def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]:
batch = self.data_preprocessor(batch, False)
images = batch['images']
data_samples = batch['data_samples']
return self.predict(images, data_samples)
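For reference, a hypothetical MMBench-style sample run through the prompt logic in `preprocess_text` yields a string like the one below (field values invented for illustration):

```python
# Invented sample fields, mirroring the keys packed by PackInputs in the config.
context = 'Hint: the picture was taken outdoors.'
question = 'What season is it?'
options = 'A. Spring B. Summer C. Autumn D. Winter'

prompt = '<image>' + question + ' ' + options + ' ' + 'Answer:'
if context is not None:
    prompt = context + ' ' + prompt

print(prompt)
# Hint: the picture was taken outdoors. <image>What season is it?
# A. Spring B. Summer C. Autumn D. Winter Answer:
```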