# Mirror of https://github.com/open-compass/opencompass.git
# (synced 2025-05-30 16:03:24 +08:00; original file: 249 lines, 9.3 KiB, Python)
"""Requires Transformers 4.28 or later; the implementation may change to follow
the upstream Llama implementation."""

import logging
from typing import Optional

import mmengine
import torch
import torch.nn as nn
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
from mmengine.device import get_device
from transformers import LlamaForCausalLM, LlamaTokenizer

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('blip2-vicuna-instruct')
class InstructBlipInferencer(Blip2Base):
    """InstructBLIP inferencer with a LLaMA/Vicuna language model.

    Pipeline used by :meth:`generate`:

    1. Each image is encoded by the vision encoder and ``ln_vision``.
    2. A Q-Former turns those features into ``num_query_token`` query
       outputs; when ``qformer_text_input`` is True, the instruction text is
       fed to the Q-Former as well.
    3. ``llm_proj`` maps the query outputs into the LLM embedding space.
    4. The projected outputs are prepended to the embedded text prompt.
    5. The (frozen) ``LlamaForCausalLM`` generates the answer.

    Args:
        prompt_constructor (dict): Config built via ``MM_MODELS``. The built
            object is called with ``{'image', 'data_samples'}`` and must
            return a dict that additionally holds ``'prompt'`` (a single
            string or one string per sample).
        post_processor (dict): Config built via ``MM_MODELS``. The built
            object is called with one generated token sequence and the LLM
            tokenizer, and returns the final text.
        vit_model (str): Vision encoder name for ``init_vision_encoder``.
        img_size (int): Image size for the vision encoder.
        drop_path_rate (float): Drop-path rate for the vision encoder.
        use_grad_checkpoint (bool): Forwarded to ``init_vision_encoder``.
        vit_precision (str): Vision encoder precision, e.g. ``'fp16'``.
        freeze_vit (bool): Freeze the vision encoder and put it in eval mode.
        num_query_token (int): Number of Q-Former query tokens.
        llm_model (str): Path or hub id of the LLaMA/Vicuna checkpoint.
        sys_prompt (str): Text prepended to every formatted LLM prompt.
        prompt (str): Stored as ``self.prompt``; not read by this class.
        max_txt_len (int): Truncation length for the Q-Former text input.
        max_output_txt_len (int): Default ``max_length`` for LLM generation.
        qformer_text_input (bool): Condition the Q-Former on the instruction
            text. When False, the Q-Former's word/position embeddings and
            per-layer ``intermediate``/``output`` modules are removed.
        low_resource (bool): Load the LLM with ``load_in_8bit=True`` on
            device 0.
        mode (str): Only ``'generation'`` is supported by :meth:`forward`.
        is_caption_task (bool): Store results in ``pred_caption`` instead of
            ``pred_answer``.
    """

    def __init__(
        self,
        prompt_constructor: dict,
        post_processor: dict,
        vit_model: str = 'eva_clip_g',
        img_size: int = 224,
        drop_path_rate: float = 0,
        use_grad_checkpoint: bool = False,
        vit_precision: str = 'fp16',
        freeze_vit: bool = True,
        num_query_token: int = 32,
        llm_model: str = '',
        sys_prompt: str = '',
        prompt: str = '',
        max_txt_len: int = 128,
        max_output_txt_len: int = 256,
        qformer_text_input: bool = True,
        low_resource: bool = False,
        mode: str = 'generation',
        is_caption_task: bool = False,
    ):
        super().__init__()
        self.mode = mode
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        # Q-Former tokenizer; over-long prompts lose tokens from the left.
        self.tokenizer = self.init_tokenizer(truncation_side='left')

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint,
            vit_precision)
        if freeze_vit:
            for param in self.visual_encoder.parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            # ``disabled_train`` (from LAVIS) presumably turns ``.train()``
            # into a no-op so the frozen encoder stays in eval mode —
            # confirm against LAVIS.
            self.visual_encoder.train = disabled_train
            logging.info('freeze vision encoder')

        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features)

        if not qformer_text_input:
            # Without text input these modules are assumed to be used only
            # by text tokens, so they are dropped (TODO confirm vs LAVIS).
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
        else:
            # Keep the Q-Former vocabulary in sync with ``self.tokenizer``.
            self.Qformer.resize_token_embeddings(len(self.tokenizer))
        # ``cls`` (presumably the Q-Former prediction head) is never used
        # by this class.
        self.Qformer.cls = None

        self.llm_tokenizer = LlamaTokenizer.from_pretrained(
            llm_model, use_fast=False, truncation_side='left')

        if low_resource:
            self.llm_model = LlamaForCausalLM.from_pretrained(
                llm_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': 0})
        else:
            self.llm_model = LlamaForCausalLM.from_pretrained(
                llm_model, torch_dtype=torch.float16)
        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})

        # Adding '[PAD]' can grow the vocabulary, so resize the embeddings.
        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))

        # The language model is kept frozen.
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Projects Q-Former hidden states into the LLM embedding space.
        self.llm_proj = nn.Linear(self.Qformer.config.hidden_size,
                                  self.llm_model.config.hidden_size)

        self.max_txt_len = max_txt_len
        self.max_output_txt_len = max_output_txt_len
        self.sys_prompt = sys_prompt
        self.prompt = prompt
        self.is_caption_task = is_caption_task

        self._lemmatizer = None

        self.qformer_text_input = qformer_text_input

    def forward(self, batch):
        """Run the model in ``self.mode``.

        Only ``'generation'`` is supported; it delegates to :meth:`generate`.

        Raises:
            RuntimeError: If ``self.mode`` is not ``'generation'``.
        """
        if self.mode != 'generation':
            raise RuntimeError(f'Invalid mode "{self.mode}".')
        return self.generate(batch)

    def concat_text_input_output(self, input_ids, input_atts, output_ids,
                                 output_atts):
        """Splice each output sequence in right after its real input tokens.

        For row ``i``, the first ``n = input_atts[i].sum()`` input positions
        are kept. They are followed by ``output_ids[i][1:]`` (the first
        output token is dropped) and then by the remaining input positions.
        Attention masks are spliced the same way. This assumes the inputs
        are right-padded (mask ones before zeros) — TODO confirm with
        callers. The method is not called within this class.

        Returns:
            tuple: ``(llm_tokens, input_part_targets_len)``. ``llm_tokens``
            is a dict with stacked ``'input_ids'`` and ``'attention_mask'``
            tensors. ``input_part_targets_len`` lists each row's ``n`` as a
            0-d tensor.
        """
        input_part_targets_len = []
        spliced_ids, spliced_atts = [], []
        for i in range(input_ids.size(0)):
            n_real = input_atts[i].sum()
            input_part_targets_len.append(n_real)
            spliced_ids.append(
                torch.cat([
                    input_ids[i][:n_real], output_ids[i][1:],
                    input_ids[i][n_real:]
                ]))
            spliced_atts.append(
                torch.cat([
                    input_atts[i][:n_real], output_atts[i][1:],
                    input_atts[i][n_real:]
                ]))
        llm_tokens = {
            'input_ids': torch.stack(spliced_ids),
            'attention_mask': torch.stack(spliced_atts),
        }
        return llm_tokens, input_part_targets_len

    def pack_inputs(self, batch):
        """Batch the images and collect the data samples.

        ``batch['inputs']`` is an iterable of equally-shaped image tensors
        (presumably ``(C, H, W)`` each — TODO confirm). They are stacked
        along a new leading dimension and moved to ``get_device()``.

        Returns:
            dict: ``{'image': Tensor, 'data_samples': list}``.
        """
        images = torch.stack(list(batch['inputs']), dim=0).to(get_device())
        data_samples = list(batch['data_samples'])
        return {'image': images, 'data_samples': data_samples}

    @torch.no_grad()
    def generate(
        self,
        batch,
        use_nucleus_sampling: bool = False,
        num_beams: int = 5,
        max_length: Optional[int] = None,
        min_length: int = 1,
        top_p: float = 0.9,
        repetition_penalty: float = 1.5,
        length_penalty: float = 1,
        num_captions: int = 1,
        temperature: float = 1,
    ):
        """Generate an answer (or caption) for every sample in ``batch``.

        Args:
            batch (dict): Holds ``'inputs'`` (image tensors) and
                ``'data_samples'``; see :meth:`pack_inputs`.
            use_nucleus_sampling (bool): Passed to the LLM as ``do_sample``.
            num_beams (int): Beam count for the LLM ``generate`` call.
            max_length (int, optional): ``max_length`` for the LLM
                ``generate`` call. ``None`` means use
                ``self.max_output_txt_len``.
            min_length (int): ``min_length`` for the LLM ``generate`` call.
            top_p (float): Nucleus sampling threshold.
            repetition_penalty (float): Passed to the LLM ``generate`` call.
            length_penalty (float): Passed to the LLM ``generate`` call.
            num_captions (int): ``num_return_sequences`` for the LLM. Only
                the first returned sequence per sample is post-processed.
            temperature (float): Sampling temperature.

        Returns:
            list: The data samples, each with ``pred_caption`` set (when
            ``self.is_caption_task``) or ``pred_answer`` set.
        """
        if max_length is None:
            max_length = self.max_output_txt_len

        inputs = self.prompt_constructor(self.pack_inputs(batch))
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

        bs = image.size(0)
        if isinstance(prompt, str):
            prompt = [prompt] * bs
        else:
            assert len(
                prompt
            ) == bs, 'The number of prompts must be equal to the batch size.'

        # ---- Visual features ------------------------------------------
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long,
                                device=image.device)

        # ---- Q-Former: distil image features into query outputs -------
        query_tokens = self.query_tokens.expand(bs, -1, -1)
        if self.qformer_text_input:
            text_Qformer = self.tokenizer(
                prompt,
                padding='longest',
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors='pt',
            ).to(image.device)
            query_atts = torch.ones(query_tokens.size()[:-1],
                                    dtype=torch.long,
                                    device=image.device)
            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],
                                     dim=1)
            query_output = self.Qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        # Keep only the query positions; any text positions come after them.
        inputs_llm = self.llm_proj(
            query_output.last_hidden_state[:, :query_tokens.size(1), :])
        atts_llm = torch.ones(inputs_llm.size()[:-1],
                              dtype=torch.long,
                              device=image.device)

        # ---- LLM prompt -----------------------------------------------
        llm_prompt = [
            self.sys_prompt + '###Human: ' + p + '###Assistant:'
            for p in prompt
        ]
        # Left padding so each text prompt ends at the final position,
        # directly before the generated tokens.
        self.llm_tokenizer.padding_side = 'left'
        llm_tokens = self.llm_tokenizer(llm_prompt,
                                        padding='longest',
                                        return_tensors='pt').to(image.device)

        # ---- Generation -----------------------------------------------
        with self.maybe_autocast():
            text_embeds = self.llm_model.get_input_embeddings()(
                llm_tokens.input_ids)
            # Visual query embeddings are prepended to the text prompt.
            inputs_embeds = torch.cat([inputs_llm, text_embeds], dim=1)
            attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask],
                                       dim=1)

            outputs = self.llm_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )

        # ``generate`` returns ``num_captions`` consecutive rows per sample,
        # so sample i's first sequence sits at row ``i * num_captions``.
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(outputs[i * num_captions],
                                              self.llm_tokenizer)
            if self.is_caption_task:
                data_sample.pred_caption = output_text
            else:
                data_sample.pred_answer = output_text
        return data_samples