mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feature]: Add Flamingo (#258)
* [Feature]: Add Openflamingo MMBench * [Fix]: Fix import error * [Fix]: Revert task config * [Fix]: Fix path bug
This commit is contained in:
parent
77745a84ea
commit
343f785b07
@ -22,5 +22,5 @@ python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
cd $root
|
cd $root
|
||||||
python run.py configs/multimodal/tasks.py
|
python run.py configs/multimodal/tasks.py --mm-eval
|
||||||
```
|
```
|
21
configs/multimodal/openflamingo/README.md
Normal file
21
configs/multimodal/openflamingo/README.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# OpenFlamingo
|
||||||
|
|
||||||
|
### Prepare the environment
|
||||||
|
|
||||||
|
Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) by following its [installation guide](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation).
|
||||||
|
|
||||||
|
### Start evaluation
|
||||||
|
|
||||||
|
#### Slurm
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PyTorch
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval
|
||||||
|
```
|
73
configs/multimodal/openflamingo/openflamingo_mmbench.py
Normal file
73
configs/multimodal/openflamingo/openflamingo_mmbench.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# OpenFlamingo evaluation config for MMBench.
#
# This module only defines plain dict/list settings. The names
# `openflamingo_dataloader`, `openflamingo_model`, `openflamingo_evaluator`
# and `openflamingo_load_from` are the ones meant to be picked up by the
# task config; `val_pipeline` and `dataset` are building blocks for the
# dataloader.

# dataloader settings
# Image pipeline: convert to numpy, resize the short edge to 224 with
# bicubic interpolation, center-crop to 224x224, then pack the image together
# with the per-sample annotation fields listed in `algorithm_keys`.
val_pipeline = [
    dict(type='mmpretrain.PILToNumpy'),
    dict(type='mmpretrain.ResizeEdge',
         scale=224,
         interpolation='bicubic',
         backend='pillow'),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(type='mmpretrain.PackInputs',
         algorithm_keys=[
             'question', 'options', 'category', 'l2-category', 'index',
             'context', 'options_dict'
         ])
]

dataset = dict(type='opencompass.MMBenchDataset',
               data_file='data/mmbench/mmbench_test_20230712.tsv',
               pipeline=val_pipeline)

# Evaluation is run sample by sample (batch_size=1) and in file order
# (shuffle=False).
openflamingo_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    persistent_workers=True,
)

# model settings
openflamingo_model = dict(
    type='openflamingo',
    data_preprocessor=dict(
        type='mmpretrain.MultiModalDataPreprocessor',
        # NOTE(review): these look like the CLIP image statistics scaled to
        # the 0-255 pixel range -- confirm against the vision encoder used.
        mean=[122.770938, 116.7460125, 104.09373615],
        std=[68.5005327, 66.6321579, 70.32316305],
        to_rgb=True,
    ),
    tokenizer=dict(type='mmpretrain.LlamaTokenizer',
                   name_or_path='decapoda-research/llama-7b-hf'),
    vision_encoder=dict(
        type='mmpretrain.VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        # Placeholder: replace with the real vision-encoder checkpoint path.
        pretrained=  # noqa: E251
        '/path/to/vision/encoder',  # noqa
    ),
    lang_encoder=dict(
        # `local_files_only=True` means the language model weights must
        # already be available locally; nothing is downloaded at run time.
        base=dict(type='mmpretrain.AutoModelForCausalLM',
                  name_or_path='decapoda-research/llama-7b-hf',
                  local_files_only=True),
        adapter=dict(type='mmpretrain.FlamingoLMAdapter',
                     vis_hidden_size=1024,
                     cross_attn_every_n_layers=4,
                     use_media_placement_augmentation=False),
    ),
    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
)

# evaluation settings
# Predictions are dumped to a spreadsheet rather than scored in-process.
openflamingo_evaluator = [
    dict(
        type='opencompass.DumpResults',
        save_path=  # noqa: E251
        'work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx')
]

# Placeholder: replace with the real pretrained OpenFlamingo weights path.
openflamingo_load_from = '/path/to/pretrained/weights'  # noqa
|
@ -10,6 +10,7 @@ models = [minigpt_4_mmbench_model]
|
|||||||
datasets = [minigpt_4_mmbench_dataloader]
|
datasets = [minigpt_4_mmbench_dataloader]
|
||||||
evaluators = [minigpt_4_mmbench_evaluator]
|
evaluators = [minigpt_4_mmbench_evaluator]
|
||||||
load_froms = [minigpt_4_mmbench_load_from]
|
load_froms = [minigpt_4_mmbench_load_from]
|
||||||
|
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
num_procs = 8
|
num_procs = 8
|
||||||
launcher = 'pytorch'
|
launcher = 'pytorch'
|
@ -1,8 +1,13 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
from opencompass.utils import satisfy_requirement
|
from opencompass.utils import satisfy_requirement
|
||||||
|
|
||||||
if satisfy_requirement('salesforce-lavis'):
|
if satisfy_requirement('salesforce-lavis'):
|
||||||
from .instructblip import * # noqa: F401, F403
|
from .instructblip import * # noqa: F401, F403
|
||||||
|
|
||||||
from .llava import * # noqa: F401, F403
|
if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
|
||||||
from .minigpt_4 import * # noqa: F401, F403
|
from .minigpt_4 import * # noqa: F401, F403
|
||||||
|
|
||||||
|
from .llava import * # noqa: F401, F403
|
||||||
|
from .openflamingo import * # noqa: F401, F403
|
||||||
from .visualglm import * # noqa: F401, F403
|
from .visualglm import * # noqa: F401, F403
|
||||||
|
3
opencompass/multimodal/models/openflamingo/__init__.py
Normal file
3
opencompass/multimodal/models/openflamingo/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .openflamingo import OpenFlamingoInferencer
|
||||||
|
|
||||||
|
__all__ = ['OpenFlamingoInferencer']
|
81
opencompass/multimodal/models/openflamingo/openflamingo.py
Normal file
81
opencompass/multimodal/models/openflamingo/openflamingo.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import mmengine
|
||||||
|
import torch
|
||||||
|
from mmpretrain.models.multimodal import Flamingo
|
||||||
|
from mmpretrain.structures import DataSample
|
||||||
|
|
||||||
|
from opencompass.registry import MM_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MM_MODELS.register_module('openflamingo')
class OpenFlamingoInferencer(Flamingo):
    """Inference code of OpenFlamingo.

    Thin wrapper around ``Flamingo`` that builds an MMBench-style text prompt
    for every sample and dispatches ``forward`` according to ``mode``.

    Args:
        prompt_constructor (optional, dict): The config of prompt constructor.
            Defaults to None.
        post_processor (optional, dict): The config of post processor.
            Defaults to None.
        mode (str): The mode of inference. Defaults to 'generation'.
        **kwargs: Forwarded unchanged to ``Flamingo.__init__``.
    """

    def __init__(self,
                 prompt_constructor: Optional[dict] = None,
                 post_processor: Optional[dict] = None,
                 mode: str = 'generation',
                 **kwargs):
        super().__init__(**kwargs)
        # The helper attributes are only created when a config is supplied.
        if prompt_constructor is not None:
            self.prompt_constructor = mmengine.registry.build_from_cfg(
                prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

    def preprocess_text(self, data_samples: List[DataSample],
                        device: torch.device) -> List[DataSample]:
        """Build one prompt per sample and tokenize the whole batch.

        Each prompt has the form
        ``[<context> ]<image><question> <options> Answer:``; the context
        prefix is added only for samples that carry a ``context`` value.

        Args:
            data_samples (List[DataSample]): The annotation data of every
                sample. ``question`` and ``options`` are read from each one
                and concatenated as strings.
            device (torch.device): Device for text to put on.

        Returns:
            The result of calling ``self.tokenizer`` on all prompts
            (left-padded to the longest prompt, truncated to 2000 tokens,
            PyTorch tensors), moved to ``device``.
            NOTE(review): the return annotation says ``List[DataSample]``,
            which does not match the tokenizer output; it is presumably kept
            to mirror the parent signature -- confirm before changing it.
        """
        prompts = []
        for sample in data_samples:
            question = sample.get('question')
            option = sample.get('options')

            prompt = '<image>' + question + ' ' + option + ' ' + 'Answer:'
            # Check the context of *this* sample. Looking only at the first
            # sample of the batch would crash on a later sample without
            # context, or drop the context of a later sample that has one.
            context = sample.get('context')
            if context is not None:
                prompt = context + ' ' + prompt

            prompts.append(prompt)

        # Left padding keeps the end of every prompt adjacent to the tokens
        # that will be generated.
        self.tokenizer.padding_side = 'left'
        input_text = self.tokenizer(
            prompts,
            padding='longest',
            truncation=True,
            return_tensors='pt',
            max_length=2000,
        ).to(device)
        return input_text

    def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]:
        """Run inference on ``batch`` according to ``self.mode``.

        Args:
            batch (dict): A batch as produced by the dataloader.

        Returns:
            Whatever ``generate`` returns for the batch.

        Raises:
            RuntimeError: If ``self.mode`` is not ``'generation'``.
        """
        if self.mode == 'generation':
            return self.generate(batch)
        else:
            raise RuntimeError(f'Unsupported mode: {self.mode}')

    def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]:
        """Preprocess ``batch`` and return the model predictions.

        Args:
            batch (dict): A raw batch; after ``self.data_preprocessor`` it
                must provide the keys ``'images'`` and ``'data_samples'``.

        Returns:
            The result of ``self.predict(images, data_samples)``.
        """
        # Second positional argument is passed as False, as in the original
        # call (presumably the `training` flag -- confirm).
        batch = self.data_preprocessor(batch, False)
        images = batch['images']
        data_samples = batch['data_samples']
        return self.predict(images, data_samples)
|
Loading…
Reference in New Issue
Block a user