Mirror of https://github.com/open-compass/opencompass.git
[Feature]: Add Flamingo (#258)

* [Feature]: Add OpenFlamingo MMBench
* [Fix]: Fix import error
* [Fix]: Revert task config
* [Fix]: Fix path bug

parent 77745a84ea
commit 343f785b07
@@ -22,5 +22,5 @@ python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION

```sh
cd $root
python run.py configs/multimodal/tasks.py
python run.py configs/multimodal/tasks.py --mm-eval
```
configs/multimodal/openflamingo/README.md (new file, 21 lines)
@@ -0,0 +1,21 @@

# OpenFlamingo

### Prepare the environment

Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) according to this [doc](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation).

### Start evaluation

#### Slurm

```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
```

#### PyTorch

```sh
cd $root
python run.py configs/multimodal/tasks.py --mm-eval
```
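The linked doc is the authoritative install guide; as a quick reference, one common route is via OpenMMLab's `mim` tool (a minimal sketch, assuming a working PyTorch environment):

```sh
# Install the OpenMMLab package manager, then MMPretrain itself.
pip install -U openmim
mim install mmpretrain
```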
configs/multimodal/openflamingo/openflamingo_mmbench.py (new file, 73 lines)
@@ -0,0 +1,73 @@

```python
# dataloader settings
val_pipeline = [
    dict(type='mmpretrain.PILToNumpy'),
    dict(type='mmpretrain.ResizeEdge',
         scale=224,
         interpolation='bicubic',
         backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'options', 'category', 'l2-category', 'index',
             'context', 'options_dict'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

openflamingo_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    persistent_workers=True,
)

# model settings
openflamingo_model = dict(
    type='openflamingo',
    data_preprocessor=dict(
        type='mmpretrain.MultiModalDataPreprocessor',
        mean=[122.770938, 116.7460125, 104.09373615],
        std=[68.5005327, 66.6321579, 70.32316305],
        to_rgb=True,
    ),
    tokenizer=dict(type='mmpretrain.LlamaTokenizer',
                   name_or_path='decapoda-research/llama-7b-hf'),
    vision_encoder=dict(
        type='mmpretrain.VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained='/path/to/vision/encoder',
    ),
    lang_encoder=dict(
        base=dict(type='mmpretrain.AutoModelForCausalLM',
                  name_or_path='decapoda-research/llama-7b-hf',
                  local_files_only=True),
        adapter=dict(type='mmpretrain.FlamingoLMAdapter',
                     vis_hidden_size=1024,
                     cross_attn_every_n_layers=4,
                     use_media_placement_augmentation=False),
    ),
    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
)

# evaluation settings
openflamingo_evaluator = [
    dict(type='opencompass.DumpResults',
         save_path='work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx')
]

openflamingo_load_from = '/path/to/pretrained/weights'
```
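These are standard MMEngine-style config dicts: nothing is instantiated until a registry builds objects from them. A minimal sketch of that pattern (assuming `opencompass.multimodal.models` has been imported so that `openflamingo` is registered; the actual wiring inside the mm-eval task may differ):

```python
from mmengine.config import Config
from mmengine.runner import Runner

import opencompass.multimodal.models  # noqa: F401  (registers 'openflamingo')
from opencompass.registry import MM_MODELS

cfg = Config.fromfile(
    'configs/multimodal/openflamingo/openflamingo_mmbench.py')

# The 'type' key selects the registered class; the remaining keys become
# constructor kwargs.
model = MM_MODELS.build(cfg.openflamingo_model)

# MMEngine can likewise turn the dataloader dict into a torch DataLoader.
dataloader = Runner.build_dataloader(cfg.openflamingo_dataloader)
```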
configs/multimodal/tasks.py

@@ -10,6 +10,7 @@ models = [minigpt_4_mmbench_model]

```python
datasets = [minigpt_4_mmbench_dataloader]
evaluators = [minigpt_4_mmbench_evaluator]
load_froms = [minigpt_4_mmbench_load_from]

num_gpus = 8
num_procs = 8
launcher = 'pytorch'
```
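Per the commit note "[Fix]: Revert task config", tasks.py stays pointed at MiniGPT-4. To evaluate OpenFlamingo instead, the same slots would presumably be filled with the names defined in openflamingo_mmbench.py (a hedged sketch, not part of this diff):

```python
from mmengine.config import read_base

# This file is parsed by MMEngine's Config, which makes the relative
# config import below work; it is not meant to be executed directly.
with read_base():
    from .openflamingo.openflamingo_mmbench import (openflamingo_dataloader,
                                                    openflamingo_evaluator,
                                                    openflamingo_load_from,
                                                    openflamingo_model)

models = [openflamingo_model]
datasets = [openflamingo_dataloader]
evaluators = [openflamingo_evaluator]
load_froms = [openflamingo_load_from]

num_gpus = 8
num_procs = 8
launcher = 'pytorch'
```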
opencompass/multimodal/models/__init__.py

@@ -1,8 +1,13 @@

```python
import os.path as osp

from opencompass.utils import satisfy_requirement

if satisfy_requirement('salesforce-lavis'):
    from .instructblip import *  # noqa: F401, F403

if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
    from .minigpt_4 import *  # noqa: F401, F403

from .llava import *  # noqa: F401, F403
from .openflamingo import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403
```
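`satisfy_requirement` gates the optional InstructBLIP import on whether `salesforce-lavis` is installed. The real helper lives in `opencompass.utils` and may also check version specifiers; a hypothetical re-implementation of the basic idea:

```python
from importlib.metadata import PackageNotFoundError, version


def satisfy_requirement(package: str) -> bool:
    """Return True if ``package`` is installed in the current environment."""
    try:
        version(package)  # raises if the distribution is absent
        return True
    except PackageNotFoundError:
        return False
```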
opencompass/multimodal/models/openflamingo/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@

```python
from .openflamingo import OpenFlamingoInferencer

__all__ = ['OpenFlamingoInferencer']
```
opencompass/multimodal/models/openflamingo/openflamingo.py (new file, 81 lines)
@@ -0,0 +1,81 @@

```python
from typing import List, Optional, Union

import mmengine
import torch
from mmpretrain.models.multimodal import Flamingo
from mmpretrain.structures import DataSample

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('openflamingo')
class OpenFlamingoInferencer(Flamingo):
    """Inference code of OpenFlamingo.

    Args:
        prompt_constructor (dict, optional): The config of the prompt
            constructor. Defaults to None.
        post_processor (dict, optional): The config of the post processor.
            Defaults to None.
        mode (str): The mode of inference. Defaults to 'generation'.
    """

    def __init__(self,
                 prompt_constructor: Optional[dict] = None,
                 post_processor: Optional[dict] = None,
                 mode: str = 'generation',
                 **kwargs):
        super().__init__(**kwargs)
        if prompt_constructor is not None:
            self.prompt_constructor = mmengine.registry.build_from_cfg(
                prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

    def preprocess_text(self, data_samples: List[DataSample],
                        device: torch.device) -> dict:
        """Preprocess text before it is fed into the language model.

        Args:
            data_samples (List[DataSample]): The annotation data of every
                sample.
            device (torch.device): Device to put the tokenized text on.

        Returns:
            dict: The tokenized prompts as tensors on ``device``.
        """
        prompts = []
        for sample in data_samples:
            question = sample.get('question')
            option = sample.get('options')

            prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
            if sample.get('context') is not None:
                prompt = sample.get('context') + ' ' + prompt

            prompts.append(prompt)

        # Left-padding keeps every prompt right-aligned so generated tokens
        # directly follow it in batched generation.
        self.tokenizer.padding_side = 'left'
        input_text = self.tokenizer(
            prompts,
            padding='longest',
            truncation=True,
            return_tensors='pt',
            max_length=2000,
        ).to(device)
        return input_text

    def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]:
        if self.mode == 'generation':
            return self.generate(batch)
        else:
            raise RuntimeError(f'Unsupported mode: {self.mode}')

    def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]:
        batch = self.data_preprocessor(batch, False)
        images = batch['images']
        data_samples = batch['data_samples']
        return self.predict(images, data_samples)
```
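For reference, the prompt that `preprocess_text` assembles for a single MMBench sample looks like this (sample values are made up):

```python
# Hypothetical field values, mirroring the string assembly above.
question = 'What is shown in the image?'
options = 'A. a cat B. a dog'
context = 'The photo was taken outdoors.'

prompt = '<image>' + question + ' ' + options + ' ' + 'Answer:'
prompt = context + ' ' + prompt
print(prompt)
# The photo was taken outdoors. <image>What is shown in the image? A. a cat B. a dog Answer:
```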