mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] Support LLaVA and mPLUG-Owl (#331)
* refine gitignore * [Feature]: Add minigpt-4 * [Feature]: Add mm local runner * [Feature]: Add instructblip * add otter and llama-adapter * add owl * add llama2-adapter and owl * lint * [Feature]: Add minigpt-4 * [Feature]: Add instructblip * add otter and llama-adapter * add owl * add llama2-adapter and owl * lint * lint * update * lint * lint * add __init__.py * update * update * update --------- Co-authored-by: liuyuan <3463423099@qq.com>
This commit is contained in:
parent
b95aea75ce
commit
f2dd98ca7a
24
configs/multimodal/llama_adapter_v2_multimodal/README.md
Normal file
24
configs/multimodal/llama_adapter_v2_multimodal/README.md
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Llama Adapter V2
|
||||||
|
|
||||||
|
### Prepare the environment
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd opencompass/multimodal/models/llama_adapter_v2_multimodal
|
||||||
|
git clone https://github.com/OpenGVLab/LLaMA-Adapter.git
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start evaluation
|
||||||
|
|
||||||
|
#### Slurm
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PyTorch
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval
|
||||||
|
```
|
@ -0,0 +1,45 @@
|
|||||||
|
from opencompass.multimodal.models.llama_adapter_v2_multimodal import (
|
||||||
|
LlamaAadapterMMBenchPostProcessor, LlamaAadapterMMBenchPromptConstructor)
|
||||||
|
|
||||||
|
# dataloader settings
|
||||||
|
val_pipeline = [
|
||||||
|
dict(type='mmpretrain.torchvision/Resize',
|
||||||
|
size=(224, 224),
|
||||||
|
interpolation=3),
|
||||||
|
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||||
|
dict(type='mmpretrain.torchvision/Normalize',
|
||||||
|
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||||
|
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||||
|
dict(type='mmpretrain.PackInputs',
|
||||||
|
algorithm_keys=[
|
||||||
|
'question', 'answer', 'options', 'category', 'l2-category',
|
||||||
|
'index', 'context', 'options_dict'
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = dict(type='opencompass.MMBenchDataset',
|
||||||
|
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||||
|
pipeline=val_pipeline)
|
||||||
|
|
||||||
|
llama_adapter_mmbench_dataloader = dict(batch_size=1,
|
||||||
|
num_workers=4,
|
||||||
|
dataset=dataset,
|
||||||
|
collate_fn=dict(type='pseudo_collate'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
llama_adapter_model = dict(
|
||||||
|
type='LLaMA-adapter-v2',
|
||||||
|
llama_dir= # noqa
|
||||||
|
'/llama_adapter_v2_multimodal',
|
||||||
|
prompt_constructor=dict(type=LlamaAadapterMMBenchPromptConstructor),
|
||||||
|
post_processor=dict(type=LlamaAadapterMMBenchPostProcessor))
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluation settings
|
||||||
|
llama_adapter_evaluator = [
|
||||||
|
dict(
|
||||||
|
type='opencompass.DumpResults',
|
||||||
|
save_path='work_dirs/llama-adapter-v2-multimodal-mmagibench-v0.1.0.xlsx'
|
||||||
|
)
|
||||||
|
]
|
24
configs/multimodal/mplug_owl/README.md
Normal file
24
configs/multimodal/mplug_owl/README.md
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# MplugOwl
|
||||||
|
|
||||||
|
### Prepare the environment
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd opencompass/multimodal/models/mplug_owl
|
||||||
|
git clone https://github.com/X-PLUG/mPLUG-Owl.git
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start evaluation
|
||||||
|
|
||||||
|
#### Slurm
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
|
||||||
|
```
|
||||||
|
|
||||||
|
#### PyTorch
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd $root
|
||||||
|
python run.py configs/multimodal/tasks.py --mm-eval
|
||||||
|
```
|
48
configs/multimodal/mplug_owl/mplug_owl-7b-mmbench.py
Normal file
48
configs/multimodal/mplug_owl/mplug_owl-7b-mmbench.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from opencompass.multimodal.models.mplug_owl import (
|
||||||
|
MplugOwlMMBenchPostProcessor, MplugOwlMMBenchPromptConstructor)
|
||||||
|
|
||||||
|
# dataloader settings
|
||||||
|
val_pipeline = [
|
||||||
|
dict(type='mmpretrain.torchvision/Resize',
|
||||||
|
size=(224, 224),
|
||||||
|
interpolation=3),
|
||||||
|
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||||
|
dict(
|
||||||
|
type='mmpretrain.torchvision/Normalize',
|
||||||
|
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||||
|
std=(0.26862954, 0.26130258, 0.27577711),
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
type='mmpretrain.PackInputs',
|
||||||
|
algorithm_keys=[
|
||||||
|
'question', 'answer', 'category', 'l2-category', 'context',
|
||||||
|
'index', 'options_dict', 'options'
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset = dict(type='opencompass.MMBenchDataset',
|
||||||
|
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||||
|
pipeline=val_pipeline)
|
||||||
|
|
||||||
|
mplug_owl_mmbench_dataloader = dict(
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=4,
|
||||||
|
dataset=dataset,
|
||||||
|
collate_fn=dict(type='pseudo_collate'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
mplug_owl_mmbench_model = dict(
|
||||||
|
type='mplug_owl-7b',
|
||||||
|
model_path='/mplug-owl-llama-7b-ft',
|
||||||
|
prompt_constructor=dict(type=MplugOwlMMBenchPromptConstructor),
|
||||||
|
post_processor=dict(type=MplugOwlMMBenchPostProcessor)
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
# evaluation settings
|
||||||
|
mplug_owl_mmbench_evaluator = [
|
||||||
|
dict(type='opencompass.DumpResults',
|
||||||
|
save_path='work_dirs/mplug_owl-7b-mmagibench-v0.1.0.xlsx')
|
||||||
|
]
|
@ -1,5 +1,6 @@
|
|||||||
from .mmbench import MMBenchDataset
|
from .mmbench import MMBenchDataset # noqa: F401, F403
|
||||||
from .mme import MMEDataset
|
from .mme import MMEDataset # noqa: F401, F403
|
||||||
from .seedbench import SEEDBenchDataset
|
from .seedbench import SEEDBenchDataset # noqa: F401, F403
|
||||||
|
|
||||||
__all__ = ['MMBenchDataset', 'SEEDBenchDataset', 'MMEDataset']
|
__all__ = ['MMBenchDataset'
|
||||||
|
'SEEDBenchDataset', 'MMEDataset']
|
||||||
|
@ -8,7 +8,9 @@ if satisfy_requirement('salesforce-lavis'):
|
|||||||
if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
|
if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
|
||||||
from .minigpt_4 import * # noqa: F401, F403
|
from .minigpt_4 import * # noqa: F401, F403
|
||||||
|
|
||||||
|
from .llama_adapter_v2_multimodal import * # noqa: F401, F403
|
||||||
from .llava import * # noqa: F401, F403
|
from .llava import * # noqa: F401, F403
|
||||||
|
from .mplug_owl import * # noqa: F401, F403
|
||||||
from .openflamingo import * # noqa: F401, F403
|
from .openflamingo import * # noqa: F401, F403
|
||||||
from .otter import * # noqa: F401, F403
|
from .otter import * # noqa: F401, F403
|
||||||
from .visualglm import * # noqa: F401, F403
|
from .visualglm import * # noqa: F401, F403
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
from .llama_adapter import LLaMA_adapter_v2
|
||||||
|
from .post_processor import LlamaAadapterMMBenchPostProcessor
|
||||||
|
from .prompt_constructor import LlamaAadapterMMBenchPromptConstructor # noqa
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'LLaMA_adapter_v2', 'LlamaAadapterMMBenchPostProcessor',
|
||||||
|
'LlamaAadapterMMBenchPromptConstructor'
|
||||||
|
]
|
@ -0,0 +1,306 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import clip
|
||||||
|
import mmengine
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from llama_adapter_v2_multimodal7b.llama.llama import ModelArgs, Transformer
|
||||||
|
from llama_adapter_v2_multimodal7b.llama.tokenizer import Tokenizer
|
||||||
|
from llama_adapter_v2_multimodal7b.llama.utils import sample_top_p
|
||||||
|
from mmengine.device import get_device
|
||||||
|
from timm.models.vision_transformer import Block
|
||||||
|
|
||||||
|
from opencompass.registry import MM_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
class LLaMA_adapter(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
llama_ckpt_dir,
|
||||||
|
llama_tokenizer,
|
||||||
|
max_seq_len=512,
|
||||||
|
max_batch_size=1,
|
||||||
|
clip_model='ViT-L/14',
|
||||||
|
v_embed_dim=768,
|
||||||
|
v_depth=8,
|
||||||
|
v_num_heads=16,
|
||||||
|
v_mlp_ratio=4.0,
|
||||||
|
query_len=10,
|
||||||
|
query_layer=31,
|
||||||
|
w_bias=False,
|
||||||
|
w_lora=False,
|
||||||
|
lora_rank=16,
|
||||||
|
prompt_constructor=None,
|
||||||
|
post_processor=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.device = get_device()
|
||||||
|
# load llama configs
|
||||||
|
with open(os.path.join(llama_ckpt_dir, 'params.json'), 'r') as f:
|
||||||
|
params = json.loads(f.read())
|
||||||
|
model_args = ModelArgs(max_seq_len=max_seq_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
**params)
|
||||||
|
|
||||||
|
# 1. clip and clip projector
|
||||||
|
self.clip, self.clip_transform = clip.load(clip_model)
|
||||||
|
|
||||||
|
clip_dim = self.clip.visual.proj.shape[1]
|
||||||
|
self.clip_proj = nn.Linear(clip_dim, v_embed_dim)
|
||||||
|
self.clip_proj_norm = nn.LayerNorm(v_embed_dim)
|
||||||
|
|
||||||
|
self.query_len = query_len
|
||||||
|
self.query_layer = query_layer
|
||||||
|
|
||||||
|
# 2. visual query, blocks and projector
|
||||||
|
self.visual_query = nn.Embedding(query_len, v_embed_dim)
|
||||||
|
self.visual_blocks = nn.ModuleList([
|
||||||
|
Block(v_embed_dim, v_num_heads, v_mlp_ratio, qkv_bias=True)
|
||||||
|
for _ in range(v_depth)
|
||||||
|
])
|
||||||
|
self.visual_proj = nn.Linear(v_embed_dim, model_args.dim)
|
||||||
|
self.visual_proj_norm = nn.LayerNorm(model_args.dim)
|
||||||
|
|
||||||
|
# 3. adapter query
|
||||||
|
self.adapter_query = nn.Embedding(query_len * query_layer,
|
||||||
|
model_args.dim)
|
||||||
|
|
||||||
|
# 4. tokenizer
|
||||||
|
self.tokenizer = Tokenizer(model_path=llama_tokenizer)
|
||||||
|
|
||||||
|
# 5. llama
|
||||||
|
model_args.vocab_size = self.tokenizer.n_words
|
||||||
|
model_args.w_bias = w_bias
|
||||||
|
model_args.w_lora = w_lora
|
||||||
|
model_args.lora_rank = lora_rank
|
||||||
|
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||||
|
self.llama = Transformer(model_args)
|
||||||
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
|
|
||||||
|
ckpts = sorted(Path(llama_ckpt_dir).glob('*.pth'))
|
||||||
|
for ckpt in ckpts:
|
||||||
|
ckpt = torch.load(ckpt, map_location='cpu')
|
||||||
|
self.llama.load_state_dict(ckpt, strict=False)
|
||||||
|
|
||||||
|
self.prompt_constructor = mmengine.registry.build_from_cfg(
|
||||||
|
prompt_constructor, MM_MODELS)
|
||||||
|
if post_processor is not None:
|
||||||
|
self.post_processor = mmengine.registry.build_from_cfg(
|
||||||
|
post_processor, MM_MODELS)
|
||||||
|
|
||||||
|
def clip_encode_image(self, x):
|
||||||
|
# modified from CLIP
|
||||||
|
x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid]
|
||||||
|
# shape = [*, width, grid ** 2]
|
||||||
|
x = x.reshape(x.shape[0], x.shape[1], -1)
|
||||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||||
|
x = torch.cat([
|
||||||
|
self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(
|
||||||
|
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
|
||||||
|
],
|
||||||
|
dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||||
|
x = x + self.clip.visual.positional_embedding.to(x.dtype)
|
||||||
|
x = self.clip.visual.ln_pre(x)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # NLD -> LND
|
||||||
|
x = self.clip.visual.transformer(x)
|
||||||
|
x = x.permute(1, 0, 2) # LND -> NLD
|
||||||
|
|
||||||
|
# preserve all spatial tokens
|
||||||
|
x = self.clip.visual.ln_post(x[:, :, :])
|
||||||
|
|
||||||
|
if self.clip.visual.proj is not None:
|
||||||
|
x = x @ self.clip.visual.proj
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward_visual(self, imgs):
|
||||||
|
clip_feats = self.clip_encode_image(imgs)
|
||||||
|
clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.float()))
|
||||||
|
|
||||||
|
visual_query = self.visual_query.weight.unsqueeze(0).repeat(
|
||||||
|
len(imgs), 1, 1)
|
||||||
|
visual_query = torch.cat([visual_query, clip_feats], dim=1)
|
||||||
|
for block in self.visual_blocks:
|
||||||
|
visual_query = block(visual_query)
|
||||||
|
|
||||||
|
visual_query = visual_query[:, :self.query_len, :]
|
||||||
|
visual_query = self.visual_proj(visual_query)
|
||||||
|
visual_query = self.visual_proj_norm(visual_query)
|
||||||
|
|
||||||
|
return visual_query
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, visual_query, tokens, start_pos: int):
|
||||||
|
_bsz, seqlen = tokens.shape
|
||||||
|
h = self.llama.tok_embeddings(tokens)
|
||||||
|
freqs_cis = self.llama.freqs_cis.to(h.device)
|
||||||
|
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
|
||||||
|
mask = None
|
||||||
|
mask = torch.full((1, 1, seqlen, seqlen),
|
||||||
|
float('-inf'),
|
||||||
|
device=h.device)
|
||||||
|
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
||||||
|
|
||||||
|
for layer in self.llama.layers[:-1 * self.query_layer]:
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask)
|
||||||
|
|
||||||
|
adapter = self.adapter_query.weight.reshape(self.query_layer,
|
||||||
|
self.query_len,
|
||||||
|
-1).unsqueeze(1)
|
||||||
|
adapter_index = 0
|
||||||
|
for layer in self.llama.layers[-1 * self.query_layer:]:
|
||||||
|
dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
|
||||||
|
dynamic_adapter = dynamic_adapter + visual_query
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
|
||||||
|
adapter_index = adapter_index + 1
|
||||||
|
|
||||||
|
h = self.llama.norm(h)
|
||||||
|
output = self.llama.output(h[:, -1, :])
|
||||||
|
|
||||||
|
return output.float()
|
||||||
|
|
||||||
|
def pack_inputs(self, batch):
|
||||||
|
images = [image.unsqueeze(0) for image in batch['inputs']]
|
||||||
|
data_samples = [data_sample for data_sample in batch['data_samples']]
|
||||||
|
images = torch.cat(images, dim=0).to(get_device())
|
||||||
|
inputs = {'image': images, 'data_samples': data_samples}
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(self, batch):
|
||||||
|
max_gen_len = 256
|
||||||
|
temperature = 0.1
|
||||||
|
top_p = 0.75
|
||||||
|
inputs = self.pack_inputs(batch)
|
||||||
|
inputs = self.prompt_constructor(inputs)
|
||||||
|
image = inputs['image']
|
||||||
|
prompts = inputs['prompt']
|
||||||
|
data_samples = inputs['data_samples']
|
||||||
|
|
||||||
|
data_sample = data_samples[0]
|
||||||
|
|
||||||
|
prompts = [prompts]
|
||||||
|
imgs = image
|
||||||
|
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
bsz = len(imgs)
|
||||||
|
params = self.llama.params
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
visual_query = self.forward_visual(imgs)
|
||||||
|
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
if isinstance(prompts[0], str):
|
||||||
|
prompts = [
|
||||||
|
self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
min_prompt_size = min([len(t) for t in prompts])
|
||||||
|
max_prompt_size = max([len(t) for t in prompts])
|
||||||
|
|
||||||
|
total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
|
||||||
|
|
||||||
|
tokens = torch.full((bsz, total_len),
|
||||||
|
self.tokenizer.pad_id).cuda().long()
|
||||||
|
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
for k, t in enumerate(prompts):
|
||||||
|
if len(t) <= total_len:
|
||||||
|
tokens[k, :len(t)] = torch.tensor(t).cuda().long()
|
||||||
|
else:
|
||||||
|
tokens[k, :total_len] = torch.tensor(
|
||||||
|
t[:total_len]).cuda().long()
|
||||||
|
|
||||||
|
input_text_mask = tokens != self.tokenizer.pad_id
|
||||||
|
start_pos = min_prompt_size
|
||||||
|
prev_pos = 0
|
||||||
|
for cur_pos in range(start_pos, total_len):
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
logits = self.forward(visual_query,
|
||||||
|
tokens[:, prev_pos:cur_pos], prev_pos)
|
||||||
|
if temperature > 0:
|
||||||
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
|
next_token = sample_top_p(probs, top_p)
|
||||||
|
else:
|
||||||
|
next_token = torch.argmax(logits, dim=-1)
|
||||||
|
next_token = next_token.reshape(-1)
|
||||||
|
|
||||||
|
next_token = torch.where(input_text_mask[:, cur_pos],
|
||||||
|
tokens[:, cur_pos], next_token)
|
||||||
|
tokens[:, cur_pos] = next_token
|
||||||
|
# trick: early stop if bsz==1
|
||||||
|
if bsz == 1 and next_token[0] == self.tokenizer.eos_id:
|
||||||
|
break
|
||||||
|
prev_pos = cur_pos
|
||||||
|
|
||||||
|
decoded = []
|
||||||
|
for i, t in enumerate(tokens.tolist()):
|
||||||
|
|
||||||
|
# cut to max gen len
|
||||||
|
t = t[len(prompts[i]):len(prompts[i]) + max_gen_len]
|
||||||
|
# cut to eos tok if any
|
||||||
|
try:
|
||||||
|
t = t[:t.index(self.tokenizer.eos_id)]
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
decoded.append(self.tokenizer.decode(t))
|
||||||
|
|
||||||
|
output_text = self.post_processor(decoded[0])
|
||||||
|
data_sample.pred_answer = output_text
|
||||||
|
return data_sample
|
||||||
|
|
||||||
|
|
||||||
|
@MM_MODELS.register_module('LLaMA-adapter-v2')
|
||||||
|
class LLaMA_adapter_v2(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
llama_dir,
|
||||||
|
prompt_constructor: dict,
|
||||||
|
post_processor: dict,
|
||||||
|
mode: str = 'generation',
|
||||||
|
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||||
|
download_root='ckpts'):
|
||||||
|
super().__init__()
|
||||||
|
name = 'BIAS-7B'
|
||||||
|
|
||||||
|
# BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
|
||||||
|
llama_type = name.split('.')[0].split('-')[-1]
|
||||||
|
llama_ckpt_dir = os.path.join(llama_dir, llama_type)
|
||||||
|
llama_tokenzier_path = os.path.join(llama_dir, 'tokenizer.model')
|
||||||
|
|
||||||
|
# load llama_adapter weights and model_cfg
|
||||||
|
print(f'Loading LLaMA-Adapter from {llama_dir}')
|
||||||
|
ckpt = torch.load(
|
||||||
|
f'{llama_dir}/7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth', # noqa: E501
|
||||||
|
map_location='cpu')
|
||||||
|
model_cfg = ckpt.get('config', {})
|
||||||
|
|
||||||
|
self.model = LLaMA_adapter(
|
||||||
|
llama_ckpt_dir,
|
||||||
|
llama_tokenzier_path,
|
||||||
|
max_seq_len=512,
|
||||||
|
max_batch_size=1,
|
||||||
|
clip_model='ViT-L/14',
|
||||||
|
v_embed_dim=768,
|
||||||
|
v_depth=8,
|
||||||
|
v_num_heads=16,
|
||||||
|
v_mlp_ratio=4.0,
|
||||||
|
query_len=10,
|
||||||
|
query_layer=31,
|
||||||
|
w_bias=model_cfg.get('w_bias', False),
|
||||||
|
w_lora=model_cfg.get('w_lora', False),
|
||||||
|
lora_rank=model_cfg.get('lora_rank', 16),
|
||||||
|
prompt_constructor=prompt_constructor,
|
||||||
|
post_processor=post_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model.load_state_dict(ckpt['model'], strict=False)
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
if self.mode == 'generation':
|
||||||
|
return self.model.generate(batch)
|
@ -0,0 +1,15 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaAadapterMMBenchPostProcessor:
|
||||||
|
""""Post processor for Llama Aadapter V2 on MMBench."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
|
||||||
|
|
||||||
|
if len(output_token) >= 2:
|
||||||
|
if output_token[1] == '.':
|
||||||
|
output_token = output_token[2:].strip()
|
||||||
|
return output_token
|
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from mmpretrain.structures import DataSample
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaAadapterMMBenchPromptConstructor:
|
||||||
|
"""Prompt constructor for Llama Adapter v2 on MMBench.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_prompt (str): Image prompt. Defaults to `''`.
|
||||||
|
reply_prompt (str): Reply prompt. Defaults to `''`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
|
||||||
|
self.image_prompt = image_prompt
|
||||||
|
self.reply_prompt = reply_prompt
|
||||||
|
|
||||||
|
def __call__(self, inputs: dict) -> dict:
|
||||||
|
"""Construct prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (dict): Input data containing image and data_samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict containing prompt, images and data_samples.
|
||||||
|
"""
|
||||||
|
data_samples = inputs['data_samples']
|
||||||
|
prompt = self._process(data_samples)
|
||||||
|
inputs.update({'prompt': prompt})
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _process(self, data_samples: List[DataSample]) -> str:
|
||||||
|
"""Process data sample to prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_samples (List[DataSample]): A list of data_samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Prompt.
|
||||||
|
"""
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
question = [
|
||||||
|
data_sample.get('question') for data_sample in data_samples
|
||||||
|
]
|
||||||
|
options = [data_sample.get('options') for data_sample in data_samples]
|
||||||
|
if data_samples[0].get('context') is not None:
|
||||||
|
context = [
|
||||||
|
data_sample.get('context') for data_sample in data_samples
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
context = ''
|
||||||
|
|
||||||
|
prompts = context + ' ' + question + ' ' + options # noqa
|
||||||
|
|
||||||
|
return prompts
|
8
opencompass/multimodal/models/mplug_owl/__init__.py
Normal file
8
opencompass/multimodal/models/mplug_owl/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from .mplug_owl import MplugOwl
|
||||||
|
from .post_processor import MplugOwlMMBenchPostProcessor
|
||||||
|
from .prompt_constructor import MplugOwlMMBenchPromptConstructor # noqa
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'MplugOwl', 'MplugOwlMMBenchPostProcessor',
|
||||||
|
'MplugOwlMMBenchPromptConstructor'
|
||||||
|
]
|
86
opencompass/multimodal/models/mplug_owl/mplug_owl.py
Normal file
86
opencompass/multimodal/models/mplug_owl/mplug_owl.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import mmengine
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmengine.device import get_device
|
||||||
|
# Load via Huggingface Style
|
||||||
|
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
|
||||||
|
from mplug_owl.processing_mplug_owl import (MplugOwlImageProcessor,
|
||||||
|
MplugOwlProcessor)
|
||||||
|
from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
|
||||||
|
|
||||||
|
from opencompass.registry import MM_MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MM_MODELS.register_module('mplug_owl')
|
||||||
|
class MplugOwl(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
prompt_constructor: dict,
|
||||||
|
post_processor: dict,
|
||||||
|
model_path='MAGAer13/mplug-owl-llama-7b',
|
||||||
|
mode: str = 'generation') -> None:
|
||||||
|
super().__init__()
|
||||||
|
pretrained_ckpt = model_path
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
self.model = MplugOwlForConditionalGeneration.from_pretrained(
|
||||||
|
pretrained_ckpt,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
).cuda()
|
||||||
|
self.image_processor = MplugOwlImageProcessor.from_pretrained(
|
||||||
|
pretrained_ckpt)
|
||||||
|
self.tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
|
||||||
|
self.processor = MplugOwlProcessor(self.image_processor,
|
||||||
|
self.tokenizer)
|
||||||
|
self.generate_kwargs = {
|
||||||
|
'do_sample': False,
|
||||||
|
'top_k': 5,
|
||||||
|
'max_length': 20,
|
||||||
|
'num_beams': 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.prompt_constructor = mmengine.registry.build_from_cfg(
|
||||||
|
prompt_constructor, MM_MODELS)
|
||||||
|
if post_processor is not None:
|
||||||
|
self.post_processor = mmengine.registry.build_from_cfg(
|
||||||
|
post_processor, MM_MODELS)
|
||||||
|
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
if self.mode == 'generation':
|
||||||
|
return self.generate(batch)
|
||||||
|
|
||||||
|
def generate(self, batch):
|
||||||
|
images = [image.unsqueeze(0) for image in batch['inputs']]
|
||||||
|
data_samples = [data_sample for data_sample in batch['data_samples']]
|
||||||
|
images = torch.cat(images, dim=0).to(get_device())
|
||||||
|
inputs = {'image': images, 'data_samples': data_samples}
|
||||||
|
inputs = self.prompt_constructor(inputs)
|
||||||
|
image = inputs['image']
|
||||||
|
prompt = inputs['prompt']
|
||||||
|
data_samples = inputs['data_samples']
|
||||||
|
|
||||||
|
data_sample = data_samples[0]
|
||||||
|
owl_template = """The following is a conversation
|
||||||
|
between a curious human and AI assistant.
|
||||||
|
The assistant gives helpful, detailed, and
|
||||||
|
polite answers to the user's questions.
|
||||||
|
Human: <image>
|
||||||
|
Human: {text_input}
|
||||||
|
AI: """
|
||||||
|
prompt = owl_template.format(text_input=prompt)
|
||||||
|
inputs = self.processor(text=[prompt], return_tensors='pt')
|
||||||
|
inputs['pixel_values'] = image
|
||||||
|
# inputs['pixel_values'] = torch.zeros_like(samples['image'])
|
||||||
|
inputs = {
|
||||||
|
k: v.bfloat16() if v.dtype == torch.float else v
|
||||||
|
for k, v in inputs.items()
|
||||||
|
}
|
||||||
|
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
res = self.model.generate(**inputs, **self.generate_kwargs)
|
||||||
|
output_text = self.tokenizer.decode(res.tolist()[0],
|
||||||
|
skip_special_tokens=True)
|
||||||
|
output_text = self.post_processor(output_text)
|
||||||
|
data_sample.pred_answer = output_text
|
||||||
|
return data_sample
|
17
opencompass/multimodal/models/mplug_owl/post_processor.py
Normal file
17
opencompass/multimodal/models/mplug_owl/post_processor.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class MplugOwlMMBenchPostProcessor:
|
||||||
|
""""Post processor for MplugOwl on MMBench."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, output_token: torch.tensor, tokenizer) -> str:
|
||||||
|
pattern = re.compile(r'([A-Z]\.)')
|
||||||
|
res = pattern.findall(output_token)
|
||||||
|
if len(res) > 0:
|
||||||
|
output_token = res[0][:-1]
|
||||||
|
return output_token
|
@ -0,0 +1,55 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from mmpretrain.structures import DataSample
|
||||||
|
|
||||||
|
|
||||||
|
class MplugOwlMMBenchPromptConstructor:
|
||||||
|
"""Prompt constructor for MplugOwl on MMBench.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_prompt (str): Image prompt. Defaults to `''`.
|
||||||
|
reply_prompt (str): Reply prompt. Defaults to `''`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
|
||||||
|
self.image_prompt = image_prompt
|
||||||
|
self.reply_prompt = reply_prompt
|
||||||
|
|
||||||
|
def __call__(self, inputs: dict) -> dict:
|
||||||
|
"""Construct prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (dict): Input data containing image and data_samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict containing prompt, images and data_samples.
|
||||||
|
"""
|
||||||
|
data_samples = inputs['data_samples']
|
||||||
|
prompt = self._process(data_samples)
|
||||||
|
inputs.update({'prompt': prompt})
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _process(self, data_samples: List[DataSample]) -> str:
|
||||||
|
"""Process data sample to prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_samples (List[DataSample]): A list of data_samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Prompt.
|
||||||
|
"""
|
||||||
|
question = [
|
||||||
|
data_sample.get('question') for data_sample in data_samples
|
||||||
|
]
|
||||||
|
options = [data_sample.get('options') for data_sample in data_samples]
|
||||||
|
if data_samples[0].get('context') is not None:
|
||||||
|
context = [
|
||||||
|
data_sample.get('context') for data_sample in data_samples
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
context = ''
|
||||||
|
|
||||||
|
prompts = context + ' ' + question + ' ' + options # noqa
|
||||||
|
|
||||||
|
return prompts
|
Loading…
Reference in New Issue
Block a user