mirror of
https://github.com/open-compass/opencompass.git
synced 2025-05-30 16:03:24 +08:00
[Feat] Support Qwen-VL-Chat on MMBench. (#312)
* [Feat] Support Qwen-VL base. * [Feat] Support Qwen-VL-Chat on MMBench. * [Fix] Add postprocessor and fix format. * [Fix] Add type hint and remove redundant codes. * [Fix] fix bugs in postprocessor. * [Fix] Use given commit id.
This commit is contained in:
parent
ddb8197212
commit
b885ec84df
41
configs/multimodal/qwen/qwenvl_base_7b_mmbench.py
Normal file
41
configs/multimodal/qwen/qwenvl_base_7b_mmbench.py
Normal file
@ -0,0 +1,41 @@
|
||||
from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor, QwenVLBasePostProcessor
|
||||
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
size=(448, 448),
|
||||
interpolation=3),
|
||||
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||
dict(type='mmpretrain.torchvision/Normalize',
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'options', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict'
|
||||
])
|
||||
]
|
||||
|
||||
dataset = dict(type='opencompass.MMBenchDataset',
|
||||
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||
pipeline=val_pipeline)
|
||||
|
||||
qwen_mmbench_dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||
|
||||
# model settings
|
||||
qwen_model = dict(
|
||||
type='qwen-vl-base',
|
||||
pretrained_path='Qwen/Qwen-VL', # or Huggingface repo id
|
||||
prompt_constructor=dict(type=QwenMMBenchPromptConstructor),
|
||||
post_processor=dict(type=QwenVLBasePostProcessor)
|
||||
)
|
||||
|
||||
# evaluation settings
|
||||
qwen_mmbench_evaluator = [
|
||||
dict(type='opencompass.DumpResults',
|
||||
save_path='work_dirs/qwenvl-base-7b-mmbench.xlsx')
|
||||
]
|
40
configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py
Normal file
40
configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py
Normal file
@ -0,0 +1,40 @@
|
||||
from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor
|
||||
|
||||
# dataloader settings
|
||||
val_pipeline = [
|
||||
dict(type='mmpretrain.torchvision/Resize',
|
||||
size=(448, 448),
|
||||
interpolation=3),
|
||||
dict(type='mmpretrain.torchvision/ToTensor'),
|
||||
dict(type='mmpretrain.torchvision/Normalize',
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711)),
|
||||
dict(type='mmpretrain.PackInputs',
|
||||
algorithm_keys=[
|
||||
'question', 'options', 'category', 'l2-category', 'context',
|
||||
'index', 'options_dict'
|
||||
])
|
||||
]
|
||||
|
||||
dataset = dict(type='opencompass.MMBenchDataset',
|
||||
data_file='data/mmbench/mmbench_test_20230712.tsv',
|
||||
pipeline=val_pipeline)
|
||||
|
||||
qwen_mmbench_dataloader = dict(batch_size=1,
|
||||
num_workers=4,
|
||||
dataset=dataset,
|
||||
collate_fn=dict(type='pseudo_collate'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False))
|
||||
|
||||
# model settings
|
||||
qwen_model = dict(
|
||||
type='qwen-vl-chat',
|
||||
pretrained_path='Qwen/Qwen-VL-Chat', # or Huggingface repo id
|
||||
prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor)
|
||||
)
|
||||
|
||||
# evaluation settings
|
||||
qwen_mmbench_evaluator = [
|
||||
dict(type='opencompass.DumpResults',
|
||||
save_path='work_dirs/qwenvl-chat-7b-mmbench.xlsx')
|
||||
]
|
@ -23,4 +23,5 @@ from .openflamingo import * # noqa: F401, F403
|
||||
if osp.exists('opencompass/multimodal/models/otter/Otter'):
|
||||
from .otter import * # noqa: F401, F403
|
||||
|
||||
from .qwen import * # noqa: F401, F403
|
||||
from .visualglm import * # noqa: F401, F403
|
||||
|
8
opencompass/multimodal/models/qwen/__init__.py
Normal file
8
opencompass/multimodal/models/qwen/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .post_processor import QwenVLBasePostProcessor
|
||||
from .prompt_constructor import QwenVLMMBenchPromptConstructor
|
||||
from .qwen import QwenVLBase, QwenVLChat
|
||||
|
||||
__all__ = [
|
||||
'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor',
|
||||
'QwenVLMMBenchPromptConstructor'
|
||||
]
|
293
opencompass/multimodal/models/qwen/generation_utils.py
Normal file
293
opencompass/multimodal/models/qwen/generation_utils.py
Normal file
@ -0,0 +1,293 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Generation support."""
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Types.
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
TokensType = List[int]
|
||||
BatchTokensType = List[List[int]]
|
||||
|
||||
|
||||
def pad_batch(batch: BatchTokensType, pad_id: int,
|
||||
seq_length: int) -> BatchTokensType:
|
||||
for tokens in batch:
|
||||
context_length = len(tokens)
|
||||
if context_length < seq_length:
|
||||
tokens.extend([pad_id] * (seq_length - context_length))
|
||||
return batch
|
||||
|
||||
|
||||
def get_ltor_masks_and_position_ids(
|
||||
data: torch.Tensor,
|
||||
eod_token: int,
|
||||
reset_position_ids: bool,
|
||||
reset_attention_mask: bool,
|
||||
eod_mask_loss: bool,
|
||||
):
|
||||
"""Build masks and position id for left to right model."""
|
||||
|
||||
# Extract batch size and sequence length.
|
||||
micro_batch_size, seq_length = data.size()
|
||||
|
||||
# Attention mask (lower triangular).
|
||||
if reset_attention_mask:
|
||||
att_mask_batch = micro_batch_size
|
||||
else:
|
||||
att_mask_batch = 1
|
||||
attention_mask = torch.tril(
|
||||
torch.ones((att_mask_batch, seq_length, seq_length),
|
||||
device=data.device)).view(att_mask_batch, 1, seq_length,
|
||||
seq_length)
|
||||
|
||||
# Loss mask.
|
||||
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
|
||||
if eod_mask_loss:
|
||||
loss_mask[data == eod_token] = 0.0
|
||||
|
||||
# Position ids.
|
||||
position_ids = torch.arange(seq_length,
|
||||
dtype=torch.long,
|
||||
device=data.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(data)
|
||||
# We need to clone as the ids will be modified based on batch index.
|
||||
if reset_position_ids:
|
||||
position_ids = position_ids.clone()
|
||||
|
||||
if reset_position_ids or reset_attention_mask:
|
||||
# Loop through the batches:
|
||||
for b in range(micro_batch_size):
|
||||
|
||||
# Find indices where EOD token is.
|
||||
eod_index = position_ids[b, data[b] == eod_token]
|
||||
# Detach indices from positions if going to modify positions.
|
||||
if reset_position_ids:
|
||||
eod_index = eod_index.clone()
|
||||
|
||||
# Loop through EOD indices:
|
||||
prev_index = 0
|
||||
for j in range(eod_index.size()[0]):
|
||||
i = eod_index[j]
|
||||
# Mask attention loss.
|
||||
if reset_attention_mask:
|
||||
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
|
||||
# Reset positions.
|
||||
if reset_position_ids:
|
||||
position_ids[b, (i + 1):] -= i + 1 - prev_index
|
||||
prev_index = i + 1
|
||||
|
||||
# Convert attention mask to binary:
|
||||
attention_mask = attention_mask < 0.5
|
||||
|
||||
return attention_mask, loss_mask, position_ids
|
||||
|
||||
|
||||
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
|
||||
"""Generate batch from context tokens."""
|
||||
# Move to GPU.
|
||||
tokens = context_tokens.contiguous().to(context_tokens.device)
|
||||
# Get the attention mask and position ids.
|
||||
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
eod_id,
|
||||
reset_position_ids=False,
|
||||
reset_attention_mask=False,
|
||||
eod_mask_loss=False,
|
||||
)
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
|
||||
def get_stop_words_ids(chat_format: str, tokenizer: PreTrainedTokenizer):
|
||||
if chat_format == 'raw':
|
||||
stop_words_ids = [tokenizer.encode('Human:'), [tokenizer.eod_id]]
|
||||
elif chat_format == 'chatml':
|
||||
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
|
||||
return stop_words_ids
|
||||
|
||||
|
||||
def make_context(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
system: str = '',
|
||||
max_window_size: int = 6144,
|
||||
chat_format: str = 'chatml',
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
if chat_format == 'chatml':
|
||||
im_start, im_end = '<|im_start|>', '<|im_end|>'
|
||||
im_start_tokens = [tokenizer.im_start_id]
|
||||
im_end_tokens = [tokenizer.im_end_id]
|
||||
nl_tokens = tokenizer.encode('\n')
|
||||
|
||||
def _tokenize_str(role, content):
|
||||
return f'{role}\n{content}', tokenizer.encode(
|
||||
role, allowed_special=set(
|
||||
tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode(
|
||||
content, allowed_special=set(tokenizer.IMAGE_ST))
|
||||
|
||||
system_text, system_tokens_part = _tokenize_str('system', system)
|
||||
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
|
||||
|
||||
raw_text = ''
|
||||
context_tokens = []
|
||||
|
||||
for turn_query, turn_response in reversed(history):
|
||||
query_text, query_tokens_part = _tokenize_str('user', turn_query)
|
||||
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
|
||||
if turn_response is not None:
|
||||
response_text, response_tokens_part = _tokenize_str(
|
||||
'assistant', turn_response)
|
||||
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens # noqa
|
||||
|
||||
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens # noqa
|
||||
prev_chat = (
|
||||
f'\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}' # noqa
|
||||
)
|
||||
else:
|
||||
next_context_tokens = nl_tokens + query_tokens + nl_tokens
|
||||
prev_chat = f'\n{im_start}{query_text}{im_end}\n'
|
||||
|
||||
current_context_size = (len(system_tokens) +
|
||||
len(next_context_tokens) +
|
||||
len(context_tokens))
|
||||
if current_context_size < max_window_size:
|
||||
context_tokens = next_context_tokens + context_tokens
|
||||
raw_text = prev_chat + raw_text
|
||||
else:
|
||||
break
|
||||
|
||||
context_tokens = system_tokens + context_tokens
|
||||
raw_text = f'{im_start}{system_text}{im_end}' + raw_text
|
||||
context_tokens += (nl_tokens + im_start_tokens +
|
||||
_tokenize_str('user', query)[1] + im_end_tokens +
|
||||
nl_tokens + im_start_tokens +
|
||||
tokenizer.encode('assistant') + nl_tokens)
|
||||
raw_text += f'\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n'
|
||||
|
||||
elif chat_format == 'raw':
|
||||
raw_text = query
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
|
||||
|
||||
return raw_text, context_tokens
|
||||
|
||||
|
||||
def _decode_default(
|
||||
tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_words: List[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str = 'replace',
|
||||
):
|
||||
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
|
||||
if verbose:
|
||||
print('\nRaw Generate: ', trim_decode_tokens)
|
||||
|
||||
end_reason = f'Gen length {len(tokens)}'
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
|
||||
for eod_word in eod_words:
|
||||
if eod_word in trim_decode_tokens:
|
||||
end_reason = f'Gen {eod_word!r}'
|
||||
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print('\nEnd Reason:', end_reason)
|
||||
print('\nGenerate: ', trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def _decode_chatml(tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_token_ids: List[int],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str = 'replace'):
|
||||
end_reason = f'Gen length {len(tokens)}'
|
||||
eod_token_idx = context_length
|
||||
for eod_token_idx in range(context_length, len(tokens)):
|
||||
if tokens[eod_token_idx] in eod_token_ids:
|
||||
end_reason = f'Gen {tokenizer.decode([tokens[eod_token_idx]])!r}'
|
||||
break
|
||||
|
||||
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx],
|
||||
errors=errors)[raw_text_len:]
|
||||
if verbose:
|
||||
print('\nRaw Generate w/o EOD:',
|
||||
tokenizer.decode(tokens, errors=errors)[raw_text_len:])
|
||||
print('\nRaw Generate:', trim_decode_tokens)
|
||||
print('\nEnd Reason:', end_reason)
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print('\nGenerate:', trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def decode_tokens(
|
||||
tokens: Union[torch.LongTensor, TokensType],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
chat_format: str,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str = 'replace',
|
||||
) -> str:
|
||||
if torch.is_tensor(tokens):
|
||||
tokens = tokens.cpu().numpy().tolist()
|
||||
|
||||
if chat_format == 'chatml':
|
||||
return _decode_chatml(
|
||||
tokens,
|
||||
stop_words=[],
|
||||
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
|
||||
tokenizer=tokenizer,
|
||||
raw_text_len=raw_text_len,
|
||||
context_length=context_length,
|
||||
verbose=verbose,
|
||||
return_end_reason=return_end_reason,
|
||||
errors=errors,
|
||||
)
|
||||
elif chat_format == 'raw':
|
||||
return _decode_default(
|
||||
tokens,
|
||||
stop_words=['<|endoftext|>'],
|
||||
eod_words=['<|endoftext|>'],
|
||||
tokenizer=tokenizer,
|
||||
raw_text_len=raw_text_len,
|
||||
verbose=verbose,
|
||||
return_end_reason=return_end_reason,
|
||||
errors=errors,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
|
16
opencompass/multimodal/models/qwen/post_processor.py
Normal file
16
opencompass/multimodal/models/qwen/post_processor.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QwenVLBasePostProcessor:
|
||||
"""Post processor for Qwen-VL-Base."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, pred: torch.tensor, tokenizer: Any,
|
||||
input_len: int) -> str:
|
||||
response = self.tokenizer.decode(pred)[input_len:]
|
||||
response = response.replace('<|endoftext|>', '').strip()
|
||||
return response
|
29
opencompass/multimodal/models/qwen/prompt_constructor.py
Normal file
29
opencompass/multimodal/models/qwen/prompt_constructor.py
Normal file
@ -0,0 +1,29 @@
|
||||
class QwenVLMMBenchPromptConstructor:
|
||||
"""MMBench prompt constructor for Qwen-VL.
|
||||
|
||||
The output is a dict following the input format of Qwen-VL tokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, inputs: dict) -> str:
|
||||
data_samples = inputs['data_samples']
|
||||
assert len(data_samples) == 1
|
||||
data_sample = data_samples[0]
|
||||
question = data_sample.get('question')
|
||||
options = data_sample.get('options')
|
||||
context = data_sample.get('context')
|
||||
if context is not None:
|
||||
prompt = context + ' ' + question + ' ' + options
|
||||
else:
|
||||
prompt = question + ' ' + options
|
||||
format_input = [
|
||||
{
|
||||
'image': 'This_is_path_to_an_image.'
|
||||
}, # Just placeholder for Image Tokens
|
||||
{
|
||||
'text': prompt
|
||||
},
|
||||
]
|
||||
return format_input
|
324
opencompass/multimodal/models/qwen/qwen.py
Normal file
324
opencompass/multimodal/models/qwen/qwen.py
Normal file
@ -0,0 +1,324 @@
|
||||
import types
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.device import get_device
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation import GenerationConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
||||
from opencompass.registry import MM_MODELS
|
||||
|
||||
from .generation_utils import decode_tokens, make_context
|
||||
|
||||
|
||||
@MM_MODELS.register_module('qwen-vl-base')
|
||||
class QwenVLBase(nn.Module):
|
||||
"""Inference code of Qwen-VL.
|
||||
|
||||
We load the Qwen model via Huggingface.
|
||||
Args:
|
||||
pretrained_path (str): Path to Qwen checkpoint or repo id.
|
||||
prompt_constructor (dict): The config of prompt constructor.
|
||||
post_processor (dict): The config of post processor.
|
||||
is_caption_task (bool): Whether the task is caption task.
|
||||
Defaults to False.
|
||||
commit_id (str): Use given version of Qwen-VL.
|
||||
Warning: the latest version may have some conflicts.
|
||||
Recommend to use the given default version.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_path: str,
|
||||
prompt_constructor: dict = None,
|
||||
post_processor: dict = None,
|
||||
is_caption_task: bool = False,
|
||||
commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
|
||||
trust_remote_code=True,
|
||||
revision=commit_id)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
pretrained_path,
|
||||
device_map=get_device(),
|
||||
trust_remote_code=True,
|
||||
revision=commit_id)
|
||||
self.model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_path, trust_remote_code=True, revision=commit_id)
|
||||
if prompt_constructor is not None:
|
||||
self.prompt_constructor = mmengine.registry.build_from_cfg(
|
||||
prompt_constructor, MM_MODELS)
|
||||
if post_processor is not None:
|
||||
self.post_processor = mmengine.registry.build_from_cfg(
|
||||
post_processor, MM_MODELS)
|
||||
self.is_caption_task = is_caption_task
|
||||
self.model.transformer.forward = types.MethodType(
|
||||
forward_hack, self.model.transformer)
|
||||
|
||||
def _build_embeds(self, images, input_ids):
|
||||
# encode image
|
||||
images = self.model.transformer.visual(images)
|
||||
# compute image position
|
||||
bos_pos = torch.where(input_ids == self.model.transformer.config.
|
||||
visual['image_start_id'])
|
||||
eos_pos = torch.where(
|
||||
input_ids ==
|
||||
self.model.transformer.config.visual['image_start_id'] + 1)
|
||||
assert (bos_pos[0] == eos_pos[0]).all()
|
||||
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
|
||||
# embed words
|
||||
inputs_embeds = self.model.transformer.wte(input_ids)
|
||||
# embed image tokens
|
||||
for idx, (i, a, b) in enumerate(img_pos):
|
||||
inputs_embeds[i][a + 1:b] = images[idx]
|
||||
return inputs_embeds
|
||||
|
||||
def generate(self, batch):
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
format_input = self.prompt_constructor(batch)
|
||||
query = self.tokenizer.from_list_format(format_input)
|
||||
|
||||
inputs = self.tokenizer(query, return_tensors='pt')
|
||||
inputs = inputs.to(get_device())
|
||||
input_ids, token_type_ids, attention_mask = inputs[
|
||||
'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
|
||||
inputs_embeds = self._build_embeds(images, input_ids)
|
||||
pred = self.model.generate(input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids)
|
||||
response = self.post_processor(pred.cpu()[0])
|
||||
|
||||
data_sample = batch['data_samples'][0]
|
||||
if self.is_caption_task:
|
||||
data_sample.pred_caption = response
|
||||
else:
|
||||
data_sample.pred_answer = response
|
||||
return data_sample
|
||||
|
||||
def forward(self, batch):
|
||||
return self.generate(batch)
|
||||
|
||||
|
||||
@MM_MODELS.register_module('qwen-vl-chat')
|
||||
class QwenVLChat(QwenVLBase):
|
||||
"""Inference code of Qwen-VL-Chat.
|
||||
|
||||
We load the Qwen model via Huggingface.
|
||||
Args:
|
||||
pretrained_path (str): Path to Qwen checkpoint or repo id.
|
||||
prompt_constructor (dict): The config of prompt constructor.
|
||||
post_processor (dict): The config of post processor.
|
||||
is_caption_task (bool): Whether the task is caption task.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained_path: str,
|
||||
prompt_constructor: dict = None,
|
||||
post_processor: dict = None,
|
||||
is_caption_task: bool = False) -> None:
|
||||
super().__init__(pretrained_path, prompt_constructor, post_processor,
|
||||
is_caption_task)
|
||||
|
||||
def generate(self, batch):
|
||||
images = batch.pop('inputs')
|
||||
images = torch.stack(images, dim=0)
|
||||
format_input = self.prompt_constructor(batch)
|
||||
query = self.tokenizer.from_list_format(format_input)
|
||||
|
||||
raw_text, context_tokens = make_context(
|
||||
self.tokenizer,
|
||||
query,
|
||||
system='You are a helpful assistant.',
|
||||
chat_format=self.model.generation_config.chat_format,
|
||||
)
|
||||
|
||||
input_ids = torch.tensor([context_tokens]).to(get_device())
|
||||
|
||||
inputs_embeds = self._build_embeds(images, input_ids)
|
||||
pred = self.model.generate(input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
response = decode_tokens(
|
||||
pred[0],
|
||||
self.tokenizer,
|
||||
raw_text_len=len(raw_text),
|
||||
context_length=len(context_tokens),
|
||||
chat_format=self.model.generation_config.chat_format,
|
||||
verbose=False,
|
||||
errors='replace')
|
||||
|
||||
data_sample = batch['data_samples'][0]
|
||||
if self.is_caption_task:
|
||||
data_sample.pred_caption = response
|
||||
else:
|
||||
data_sample.pred_answer = response
|
||||
return data_sample
|
||||
|
||||
|
||||
def forward_hack(self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None):
|
||||
if past_key_values is None and input_ids is not None and torch.any(
|
||||
input_ids == self.config.visual['image_start_id']):
|
||||
bos_pos = torch.where(
|
||||
input_ids == self.config.visual['image_start_id'])
|
||||
eos_pos = torch.where(
|
||||
input_ids == self.config.visual['image_start_id'] + 1)
|
||||
assert (bos_pos[0] == eos_pos[0]).all()
|
||||
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
|
||||
images = []
|
||||
for i, a, b in img_pos:
|
||||
image = input_ids[i][a + 1:b - 1].tolist()
|
||||
image = image[:image.index(self.config.visual['image_start_id'] +
|
||||
2)]
|
||||
images.append(bytes(image).decode('utf-8'))
|
||||
|
||||
images = self.visual.encode(images)
|
||||
assert images.shape[0] == len(images)
|
||||
else:
|
||||
images = None
|
||||
|
||||
output_attentions = (output_attentions if output_attentions is not None
|
||||
else self.config.output_attentions)
|
||||
output_hidden_states = (output_hidden_states if output_hidden_states
|
||||
is not None else self.config.output_hidden_states)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (return_dict
|
||||
if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
'You cannot specify both input_ids and inputs_embeds at the same time' # noqa
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
'You have to specify either input_ids or inputs_embeds')
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
encoder_attention_mask = None
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError('batch_size has to be defined and > 0')
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_length)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
if images is not None:
|
||||
for idx, (i, a, b) in enumerate(img_pos):
|
||||
hidden_states[i][a + 1:b] = images[idx]
|
||||
output_shape = input_shape + (hidden_states.size(-1), )
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[2 if output_attentions else 1], )
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[1], )
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states]
|
||||
if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
Loading…
Reference in New Issue
Block a user