OpenCompass/opencompass/multimodal/models/llama_adapter_v2_multimodal/llama_adapter.py
Yuanhan Zhang f2dd98ca7a
[Feat] Support LLaVA and mPLUG-Owl (#331)
* refine gitignore

* [Feature]: Add minigpt-4

* [Feature]: Add mm local runner

* [Feature]: Add instructblip

* add otter and llama-adapter

* add owl

* add llama2-adapter and owl

* lint

* [Feature]: Add minigpt-4

* [Feature]: Add instructblip

* add otter and llama-adapter

* add owl

* add llama2-adapter and owl

* lint

* lint

* update

* lint

* lint

* add __init__.py

* update

* update

* update

---------

Co-authored-by: liuyuan <3463423099@qq.com>
2023-09-01 23:32:05 +08:00

307 lines
11 KiB
Python

import json
import os
from pathlib import Path
import clip
import mmengine
import torch
import torch.nn as nn
from llama_adapter_v2_multimodal7b.llama.llama import ModelArgs, Transformer
from llama_adapter_v2_multimodal7b.llama.tokenizer import Tokenizer
from llama_adapter_v2_multimodal7b.llama.utils import sample_top_p
from mmengine.device import get_device
from timm.models.vision_transformer import Block
from opencompass.registry import MM_MODELS
class LLaMA_adapter(nn.Module):
    """LLaMA-Adapter V2 multimodal model (CLIP image encoder + LLaMA).

    CLIP image tokens are projected, mixed with learnable visual query tokens
    by a stack of ViT ``Block``s, and projected to the LLaMA hidden size. The
    resulting visual queries are added to learnable per-layer adapter prompts
    that are fed to the last ``query_layer`` LLaMA layers.

    Args:
        llama_ckpt_dir (str): Directory with ``params.json`` and the
            ``*.pth`` LLaMA weight files.
        llama_tokenizer (str): Path passed to ``Tokenizer(model_path=...)``.
        max_seq_len (int): Forwarded to ``ModelArgs``; also caps the total
            prompt + generation length in :meth:`generate`.
        max_batch_size (int): Forwarded to ``ModelArgs``.
        clip_model (str): Model name passed to ``clip.load``.
        v_embed_dim (int): Width of the visual query tokens / ViT blocks.
        v_depth (int): Number of ViT blocks applied to the visual queries.
        v_num_heads (int): Attention heads per ViT block.
        v_mlp_ratio (float): MLP ratio per ViT block.
        query_len (int): Number of visual / adapter query tokens.
        query_layer (int): Number of final LLaMA layers that get adapters.
        w_bias (bool): Forwarded to ``ModelArgs.w_bias``.
        w_lora (bool): Forwarded to ``ModelArgs.w_lora``.
        lora_rank (int): Forwarded to ``ModelArgs.lora_rank``.
        prompt_constructor (dict): Config built through ``MM_MODELS``.
        post_processor (dict, optional): Config built through ``MM_MODELS``.
            When ``None``, :meth:`generate` returns the decoded text as is.
    """

    def __init__(self,
                 llama_ckpt_dir,
                 llama_tokenizer,
                 max_seq_len=512,
                 max_batch_size=1,
                 clip_model='ViT-L/14',
                 v_embed_dim=768,
                 v_depth=8,
                 v_num_heads=16,
                 v_mlp_ratio=4.0,
                 query_len=10,
                 query_layer=31,
                 w_bias=False,
                 w_lora=False,
                 lora_rank=16,
                 prompt_constructor=None,
                 post_processor=None):
        super().__init__()
        self.device = get_device()

        # Load the LLaMA hyper-parameters stored next to the weights.
        with open(os.path.join(llama_ckpt_dir, 'params.json'), 'r') as f:
            params = json.load(f)
        model_args = ModelArgs(max_seq_len=max_seq_len,
                               max_batch_size=max_batch_size,
                               **params)

        # 1. CLIP image encoder and its projector.
        self.clip, self.clip_transform = clip.load(clip_model)
        clip_dim = self.clip.visual.proj.shape[1]
        self.clip_proj = nn.Linear(clip_dim, v_embed_dim)
        self.clip_proj_norm = nn.LayerNorm(v_embed_dim)

        self.query_len = query_len
        self.query_layer = query_layer

        # 2. Visual query tokens, ViT blocks and projector to the LLaMA dim.
        self.visual_query = nn.Embedding(query_len, v_embed_dim)
        self.visual_blocks = nn.ModuleList([
            Block(v_embed_dim, v_num_heads, v_mlp_ratio, qkv_bias=True)
            for _ in range(v_depth)
        ])
        self.visual_proj = nn.Linear(v_embed_dim, model_args.dim)
        self.visual_proj_norm = nn.LayerNorm(model_args.dim)

        # 3. Adapter prompts: `query_len` tokens for each adapted layer.
        self.adapter_query = nn.Embedding(query_len * query_layer,
                                          model_args.dim)

        # 4. Tokenizer.
        self.tokenizer = Tokenizer(model_path=llama_tokenizer)

        # 5. LLaMA, constructed with fp16 CUDA tensors as the default type.
        model_args.vocab_size = self.tokenizer.n_words
        model_args.w_bias = w_bias
        model_args.w_lora = w_lora
        model_args.lora_rank = lora_rank
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        try:
            self.llama = Transformer(model_args)
        finally:
            # Always restore the process-wide default, even if construction
            # fails, so later allocations are not silently fp16-on-CUDA.
            torch.set_default_tensor_type(torch.FloatTensor)
        for ckpt_path in sorted(Path(llama_ckpt_dir).glob('*.pth')):
            state_dict = torch.load(ckpt_path, map_location='cpu')
            # strict=False: each file may only hold part of the weights.
            self.llama.load_state_dict(state_dict, strict=False)

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        # Always define the attribute; `generate` skips it when None.
        self.post_processor = (mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS) if post_processor is not None else None)

    def clip_encode_image(self, x):
        """Encode images with CLIP's visual tower, keeping every token.

        Adapted from CLIP's visual forward pass, but ``ln_post`` and ``proj``
        are applied to all tokens (class token + patches) rather than only to
        the class token.
        """
        visual = self.clip.visual
        x = visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_token = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_token, x], dim=1)  # [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        # Preserve all spatial tokens.
        x = visual.ln_post(x)
        if visual.proj is not None:
            x = x @ visual.proj
        return x

    def forward_visual(self, imgs):
        """Turn images into ``query_len`` visual tokens of the LLaMA width.

        Returns:
            Tensor: shape ``(len(imgs), query_len, llama_dim)``.
        """
        clip_feats = self.clip_encode_image(imgs)
        # The projector was created with float32 weights, so cast up first.
        clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.float()))
        visual_query = self.visual_query.weight.unsqueeze(0).repeat(
            len(imgs), 1, 1)
        visual_query = torch.cat([visual_query, clip_feats], dim=1)
        for block in self.visual_blocks:
            visual_query = block(visual_query)
        # Keep only the learnable query positions, drop the CLIP tokens.
        visual_query = visual_query[:, :self.query_len, :]
        visual_query = self.visual_proj(visual_query)
        visual_query = self.visual_proj_norm(visual_query)
        return visual_query

    @torch.inference_mode()
    def forward(self, visual_query, tokens, start_pos: int):
        """Run LLaMA on ``tokens`` and return logits of the last position.

        Args:
            visual_query (Tensor): Output of :meth:`forward_visual`; added
                to the adapter prompts of every adapted layer.
            tokens (Tensor): Token ids of shape ``(bsz, seqlen)``.
            start_pos (int): Position of ``tokens[:, 0]`` in the full
                sequence; :meth:`generate` passes ``> 0`` when it feeds only
                the newest token.

        Returns:
            Tensor: float32 logits of shape ``(bsz, vocab_size)``.
        """
        bsz, seqlen = tokens.shape
        h = self.llama.tok_embeddings(tokens)
        freqs_cis = self.llama.freqs_cis.to(h.device)
        freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
        mask = torch.full((1, 1, seqlen, seqlen),
                          float('-inf'),
                          device=h.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        # The first layers run unchanged; the last `query_layer` layers each
        # receive their own adapter prompt plus the visual queries.
        num_plain = max(len(self.llama.layers) - self.query_layer, 0)
        for layer in self.llama.layers[:num_plain]:
            h = layer(h, start_pos, freqs_cis, mask)

        adapter = self.adapter_query.weight.reshape(self.query_layer,
                                                    self.query_len,
                                                    -1).unsqueeze(1)
        for layer, layer_adapter in zip(self.llama.layers[num_plain:],
                                        adapter):
            dynamic_adapter = layer_adapter.repeat(bsz, 1, 1) + visual_query
            h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)

        h = self.llama.norm(h)
        output = self.llama.output(h[:, -1, :])
        return output.float()

    def pack_inputs(self, batch):
        """Stack ``batch['inputs']`` into one image tensor on ``self.device``.

        Returns:
            dict: ``{'image': Tensor, 'data_samples': list}``.
        """
        images = torch.stack(list(batch['inputs']), dim=0).to(self.device)
        data_samples = list(batch['data_samples'])
        return {'image': images, 'data_samples': data_samples}

    @torch.inference_mode()
    def generate(self, batch, max_gen_len=256, temperature=0.1, top_p=0.75):
        """Generate an answer for the first sample of ``batch``.

        Args:
            batch (dict): Raw batch with ``'inputs'`` and ``'data_samples'``.
            max_gen_len (int): Maximum number of generated tokens.
            temperature (float): Softmax temperature; ``<= 0`` means greedy.
            top_p (float): Threshold passed to ``sample_top_p``.

        Returns:
            The first data sample, with ``pred_answer`` set to the
            (optionally post-processed) generated text.
        """
        inputs = self.prompt_constructor(self.pack_inputs(batch))
        imgs = inputs['image']
        # One prompt (and one data sample) per call.
        prompts = [inputs['prompt']]
        data_sample = inputs['data_samples'][0]
        bsz = len(imgs)
        device = imgs.device
        params = self.llama.params

        with torch.autocast('cuda'):
            visual_query = self.forward_visual(imgs)

        if isinstance(prompts[0], str):
            prompts = [
                self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
            ]

        min_prompt_size = min(len(t) for t in prompts)
        max_prompt_size = max(len(t) for t in prompts)
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len),
                            self.tokenizer.pad_id,
                            dtype=torch.long,
                            device=device)
        for k, t in enumerate(prompts):
            # Prompts longer than the sequence budget are truncated.
            t = t[:total_len]
            tokens[k, :len(t)] = torch.tensor(t,
                                              dtype=torch.long,
                                              device=device)
        input_text_mask = tokens != self.tokenizer.pad_id

        prev_pos = 0
        for cur_pos in range(min_prompt_size, total_len):
            with torch.autocast('cuda'):
                logits = self.forward(visual_query,
                                      tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # Where the prompt already provides a token, keep it.
            next_token = torch.where(input_text_mask[:, cur_pos],
                                     tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            # Early stop is only safe when there is a single sequence.
            if bsz == 1 and next_token[0] == self.tokenizer.eos_id:
                break
            prev_pos = cur_pos

        # Decode only the generated part of the first row, cut at EOS.
        prompt_len = len(prompts[0])
        generated = tokens[0].tolist()[prompt_len:prompt_len + max_gen_len]
        if self.tokenizer.eos_id in generated:
            generated = generated[:generated.index(self.tokenizer.eos_id)]
        output_text = self.tokenizer.decode(generated)
        if self.post_processor is not None:
            output_text = self.post_processor(output_text)
        data_sample.pred_answer = output_text
        return data_sample
@MM_MODELS.register_module('LLaMA-adapter-v2')
class LLaMA_adapter_v2(nn.Module):
    """OpenCompass wrapper around the LLaMA-Adapter V2 ``BIAS-7B`` model.

    Args:
        llama_dir (str): Directory containing the LLaMA weights (in a ``7B``
            sub-directory), ``tokenizer.model`` and the ``BIAS-7B`` adapter
            checkpoint.
        prompt_constructor (dict): Prompt constructor config.
        post_processor (dict): Post processor config.
        mode (str): Inference mode; only ``'generation'`` is supported.
        device (str): Currently unused; accepted for config compatibility.
        download_root (str): Currently unused; accepted for config
            compatibility.
    """

    def __init__(self,
                 llama_dir,
                 prompt_constructor: dict,
                 post_processor: dict,
                 mode: str = 'generation',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 download_root='ckpts'):
        super().__init__()
        name = 'BIAS-7B'
        # BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
        llama_type = name.split('.')[0].split('-')[-1]
        llama_ckpt_dir = os.path.join(llama_dir, llama_type)
        llama_tokenizer_path = os.path.join(llama_dir, 'tokenizer.model')

        # Load the adapter weights and (optional) model config.
        print(f'Loading LLaMA-Adapter from {llama_dir}')
        ckpt_path = os.path.join(
            llama_dir,
            '7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth',  # noqa: E501
        )
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model_cfg = ckpt.get('config', {})

        self.model = LLaMA_adapter(
            llama_ckpt_dir,
            llama_tokenizer_path,
            max_seq_len=512,
            max_batch_size=1,
            clip_model='ViT-L/14',
            v_embed_dim=768,
            v_depth=8,
            v_num_heads=16,
            v_mlp_ratio=4.0,
            query_len=10,
            query_layer=31,
            w_bias=model_cfg.get('w_bias', False),
            w_lora=model_cfg.get('w_lora', False),
            lora_rank=model_cfg.get('lora_rank', 16),
            prompt_constructor=prompt_constructor,
            post_processor=post_processor,
        )
        # strict=False: the adapter checkpoint only holds part of the weights.
        self.model.load_state_dict(ckpt['model'], strict=False)
        self.mode = mode

    def forward(self, batch):
        """Dispatch ``batch`` according to ``self.mode``.

        Raises:
            NotImplementedError: If ``self.mode`` is not ``'generation'``
                (previously this silently returned ``None``).
        """
        if self.mode == 'generation':
            return self.model.generate(batch)
        raise NotImplementedError(
            f'LLaMA-adapter-v2 does not support mode {self.mode!r}')