* [Feat] Support Qwen-VL base.
* [Feat] Support Qwen-VL-Chat on MMBench.
* [Fix] Add postprocessor and fix format.
* [Fix] Add type hint and remove redundant codes.
* [Fix] Fix bugs in postprocessor.
* [Fix] Use given commit id.
325 lines | 12 KiB | Python
import types
from typing import Optional, Tuple

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast

from opencompass.registry import MM_MODELS

from .generation_utils import decode_tokens, make_context


@MM_MODELS.register_module('qwen-vl-base')
class QwenVLBase(nn.Module):
    """Inference code of Qwen-VL.

    We load the Qwen-VL model via Hugging Face.

    Args:
        pretrained_path (str): Path to the Qwen checkpoint or repo id.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
        commit_id (str): Use the given version of Qwen-VL.
            Warning: the latest version may have some conflicts,
            so it is recommended to stick to the default version.
    """

    def __init__(
        self,
        pretrained_path: str,
        prompt_constructor: dict = None,
        post_processor: dict = None,
        is_caption_task: bool = False,
        commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
    ) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True,
                                                       revision=commit_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_path,
            device_map=get_device(),
            trust_remote_code=True,
            revision=commit_id)
        self.model.generation_config = GenerationConfig.from_pretrained(
            pretrained_path, trust_remote_code=True, revision=commit_id)
        if prompt_constructor is not None:
            self.prompt_constructor = mmengine.registry.build_from_cfg(
                prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.is_caption_task = is_caption_task
        self.model.transformer.forward = types.MethodType(
            forward_hack, self.model.transformer)

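    # Note on the token layout that ``_build_embeds`` relies on (descriptive
    # comment; the details follow the Qwen-VL tokenizer's documented
    # behaviour and may differ between revisions): the tokenizer wraps every
    # image reference between a start marker (``image_start_id``) and an end
    # marker (``image_start_id + 1``) and reserves a run of placeholder
    # positions in between.  ``_build_embeds`` locates those marker pairs and
    # overwrites the placeholder embeddings with the features returned by the
    # visual encoder, one image per (start, end) pair.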
    def _build_embeds(self, images, input_ids):
        # encode image
        images = self.model.transformer.visual(images)
        # compute image position
        bos_pos = torch.where(input_ids == self.model.transformer.config.
                              visual['image_start_id'])
        eos_pos = torch.where(
            input_ids ==
            self.model.transformer.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        # embed words
        inputs_embeds = self.model.transformer.wte(input_ids)
        # embed image tokens
        for idx, (i, a, b) in enumerate(img_pos):
            inputs_embeds[i][a + 1:b] = images[idx]
        return inputs_embeds

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        inputs = self.tokenizer(query, return_tensors='pt')
        inputs = inputs.to(get_device())
        input_ids, token_type_ids, attention_mask = inputs[
            'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids)
        response = self.post_processor(pred.cpu()[0])

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample

    def forward(self, batch):
        return self.generate(batch)


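# Illustrative usage sketch for the classes in this file.  The ``type`` names
# below are assumptions; in practice the prompt constructor and post
# processor configs come from the OpenCompass multimodal configs and must
# name classes registered in ``MM_MODELS``.  ``QwenVLChat`` below is built
# the same way, typically from the 'Qwen/Qwen-VL-Chat' checkpoint.
#
#     model = QwenVLBase(
#         pretrained_path='Qwen/Qwen-VL',
#         prompt_constructor=dict(type='QwenVLMMBenchPromptConstructor'),
#         post_processor=dict(type='QwenVLBasePostProcessor'),
#     )
#     data_sample = model.generate(batch)  # ``batch`` from the eval loader

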
@MM_MODELS.register_module('qwen-vl-chat')
class QwenVLChat(QwenVLBase):
    """Inference code of Qwen-VL-Chat.

    We load the Qwen-VL-Chat model via Hugging Face.

    Args:
        pretrained_path (str): Path to the Qwen checkpoint or repo id.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict = None,
                 post_processor: dict = None,
                 is_caption_task: bool = False) -> None:
        super().__init__(pretrained_path, prompt_constructor, post_processor,
                         is_caption_task)

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        raw_text, context_tokens = make_context(
            self.tokenizer,
            query,
            system='You are a helpful assistant.',
            chat_format=self.model.generation_config.chat_format,
        )

        input_ids = torch.tensor([context_tokens]).to(get_device())

        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds)

        response = decode_tokens(
            pred[0],
            self.tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=self.model.generation_config.chat_format,
            verbose=False,
            errors='replace')

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample


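# ``forward_hack`` below is bound onto ``model.transformer`` in
# ``QwenVLBase.__init__`` via ``types.MethodType``, replacing the
# trust_remote_code ``QWenModel.forward``.  It closely mirrors the upstream
# Qwen-VL implementation; the practical difference is that it accepts
# ``inputs_embeds`` that were already filled with visual features by
# ``_build_embeds``, and only re-encodes images itself when it finds raw
# image byte sequences in ``input_ids`` and no cached key/values.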
def forward_hack(self,
                 input_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                 attention_mask: Optional[torch.FloatTensor] = None,
                 token_type_ids: Optional[torch.LongTensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 head_mask: Optional[torch.FloatTensor] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 encoder_hidden_states: Optional[torch.Tensor] = None,
                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None):
    if past_key_values is None and input_ids is not None and torch.any(
            input_ids == self.config.visual['image_start_id']):
        bos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'])
        eos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        images = []
        for i, a, b in img_pos:
            image = input_ids[i][a + 1:b - 1].tolist()
            image = image[:image.index(self.config.visual['image_start_id'] +
                                       2)]
            images.append(bytes(image).decode('utf-8'))

        images = self.visual.encode(images)
        assert images.shape[0] == len(images)
    else:
        images = None

    output_attentions = (output_attentions if output_attentions is not None
                         else self.config.output_attentions)
    output_hidden_states = (output_hidden_states if output_hidden_states
                            is not None else self.config.output_hidden_states)
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (return_dict if return_dict is not None else
                   self.config.use_return_dict)

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            'You cannot specify both input_ids and inputs_embeds at the same time'  # noqa
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError(
            'You have to specify either input_ids or inputs_embeds')

    device = input_ids.device if input_ids is not None else inputs_embeds.device  # noqa

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)

    if position_ids is None:
        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    encoder_attention_mask = None
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)

    if batch_size <= 0:
        raise ValueError('batch_size has to be defined and > 0')
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_length)

    hidden_states = inputs_embeds

    hidden_states = self.drop(hidden_states)
    if images is not None:
        for idx, (i, a, b) in enumerate(img_pos):
            hidden_states[i][a + 1:b] = images[idx]
    output_shape = input_shape + (hidden_states.size(-1), )

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[2 if output_attentions else 1], )

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[1], )

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)
    # Add last hidden state
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states, )

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states]
                     if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )
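

# Expected I/O contract of ``generate`` on both classes (descriptive sketch;
# the field names below are taken from the code above, everything else about
# ``batch`` is supplied by the OpenCompass evaluation pipeline):
#
#     batch = {
#         'inputs': [img_tensor, ...],         # stacked along dim 0 internally
#         'data_samples': [data_sample, ...],  # first sample gets the result
#     }
#     out = model.generate(batch)
#     out.pred_caption if model.is_caption_task else out.pred_answer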