[Feat] Support Qwen-VL-Chat on MMBench. (#312)

* [Feat] Support Qwen-VL base.

* [Feat] Support Qwen-VL-Chat on MMBench.

* [Fix] Add postprocessor and fix format.

* [Fix] Add type hints and remove redundant code.

* [Fix] Fix bugs in postprocessor.

* [Fix] Use given commit id.
Yike Yuan 2023-09-06 18:42:19 +08:00 committed by GitHub
parent ddb8197212
commit b885ec84df
8 changed files with 752 additions and 0 deletions


@@ -0,0 +1,41 @@
from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor, QwenVLBasePostProcessor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
size=(448, 448),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'options', 'category', 'l2-category', 'context',
'index', 'options_dict'
])
]
dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
qwen_mmbench_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
qwen_model = dict(
type='qwen-vl-base',
pretrained_path='Qwen/Qwen-VL', # or Huggingface repo id
    prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor),
post_processor=dict(type=QwenVLBasePostProcessor)
)
# evaluation settings
qwen_mmbench_evaluator = [
dict(type='opencompass.DumpResults',
save_path='work_dirs/qwenvl-base-7b-mmbench.xlsx')
]
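
A minimal sketch (not part of the diff) of how this config is consumed: OpenCompass resolves the qwen_model dict through the mmengine-based MM_MODELS registry, which maps type='qwen-vl-base' to the QwenVLBase wrapper added in qwen.py below. Only the standalone usage is illustrative; the registry call itself is existing API.

from opencompass.registry import MM_MODELS

# Build the wrapper from the config dict above; this loads the pinned
# Qwen/Qwen-VL revision declared in QwenVLBase.__init__.
model = MM_MODELS.build(qwen_model)
model.eval()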


@@ -0,0 +1,40 @@
from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor
# dataloader settings
val_pipeline = [
dict(type='mmpretrain.torchvision/Resize',
size=(448, 448),
interpolation=3),
dict(type='mmpretrain.torchvision/ToTensor'),
dict(type='mmpretrain.torchvision/Normalize',
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)),
dict(type='mmpretrain.PackInputs',
algorithm_keys=[
'question', 'options', 'category', 'l2-category', 'context',
'index', 'options_dict'
])
]
dataset = dict(type='opencompass.MMBenchDataset',
data_file='data/mmbench/mmbench_test_20230712.tsv',
pipeline=val_pipeline)
qwen_mmbench_dataloader = dict(batch_size=1,
num_workers=4,
dataset=dataset,
collate_fn=dict(type='pseudo_collate'),
sampler=dict(type='DefaultSampler', shuffle=False))
# model settings
qwen_model = dict(
type='qwen-vl-chat',
pretrained_path='Qwen/Qwen-VL-Chat', # or Huggingface repo id
prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor)
)
# evaluation settings
qwen_mmbench_evaluator = [
dict(type='opencompass.DumpResults',
save_path='work_dirs/qwenvl-chat-7b-mmbench.xlsx')
]
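
Note that, unlike the base config, no post_processor entry is needed here. A rough illustration (assuming chat_format == 'chatml'; the tokenizer instance and image path are hypothetical) of how QwenVLChat turns the constructor output into its final prompt:

format_input = [{'image': 'path/to/image.jpg'}, {'text': 'question ... options ...'}]
query = tokenizer.from_list_format(format_input)
# query resembles: 'Picture 1: <img>path/to/image.jpg</img>\nquestion ... options ...'
# make_context() (see generation_utils.py below) then wraps it as ChatML:
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>
#   <|im_start|>user\n{query}<|im_end|>
#   <|im_start|>assistant\n
# and decode_tokens() strips the prompt from the generation, so the raw output
# needs no separate post processor.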


@@ -23,4 +23,5 @@ from .openflamingo import *  # noqa: F401, F403
if osp.exists('opencompass/multimodal/models/otter/Otter'):
    from .otter import *  # noqa: F401, F403
from .qwen import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403


@@ -0,0 +1,8 @@
from .post_processor import QwenVLBasePostProcessor
from .prompt_constructor import QwenVLMMBenchPromptConstructor
from .qwen import QwenVLBase, QwenVLChat
__all__ = [
'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor',
'QwenVLMMBenchPromptConstructor'
]


@@ -0,0 +1,293 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import List, Tuple, Union
import torch
from transformers import PreTrainedTokenizer
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int,
seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data: torch.Tensor,
eod_token: int,
reset_position_ids: bool,
reset_attention_mask: bool,
eod_mask_loss: bool,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length),
device=data.device)).view(att_mask_batch, 1, seq_length,
seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length,
dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format: str, tokenizer: PreTrainedTokenizer):
if chat_format == 'raw':
stop_words_ids = [tokenizer.encode('Human:'), [tokenizer.eod_id]]
elif chat_format == 'chatml':
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = '',
max_window_size: int = 6144,
chat_format: str = 'chatml',
):
if history is None:
history = []
if chat_format == 'chatml':
im_start, im_end = '<|im_start|>', '<|im_end|>'
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode('\n')
def _tokenize_str(role, content):
return f'{role}\n{content}', tokenizer.encode(
role, allowed_special=set(
tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode(
content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str('system', system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ''
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str('user', turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
'assistant', turn_response)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens # noqa
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens # noqa
prev_chat = (
f'\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}' # noqa
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f'\n{im_start}{query_text}{im_end}\n'
current_context_size = (len(system_tokens) +
len(next_context_tokens) +
len(context_tokens))
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f'{im_start}{system_text}{im_end}' + raw_text
context_tokens += (nl_tokens + im_start_tokens +
_tokenize_str('user', query)[1] + im_end_tokens +
nl_tokens + im_start_tokens +
tokenizer.encode('assistant') + nl_tokens)
raw_text += f'\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n'
elif chat_format == 'raw':
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
return raw_text, context_tokens
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str = 'replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print('\nRaw Generate: ', trim_decode_tokens)
end_reason = f'Gen length {len(tokens)}'
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f'Gen {eod_word!r}'
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print('\nEnd Reason:', end_reason)
print('\nGenerate: ', trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str = 'replace'):
end_reason = f'Gen length {len(tokens)}'
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f'Gen {tokenizer.decode([tokens[eod_token_idx]])!r}'
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx],
errors=errors)[raw_text_len:]
if verbose:
print('\nRaw Generate w/o EOD:',
tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print('\nRaw Generate:', trim_decode_tokens)
print('\nEnd Reason:', end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print('\nGenerate:', trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str = 'replace',
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == 'chatml':
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == 'raw':
return _decode_default(
tokens,
stop_words=['<|endoftext|>'],
eod_words=['<|endoftext|>'],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f'Unknown chat format {chat_format!r}')
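
A tiny worked example (illustrative only) of the mask helper above: for a single length-4 sequence with no EOD resets, the function returns the boolean complement of a lower-triangular matrix (True marks positions that must not be attended to), an all-ones loss mask, and plain 0..3 position ids.

import torch

data = torch.tensor([[11, 12, 13, 14]])
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data, eod_token=0, reset_position_ids=False,
    reset_attention_mask=False, eod_mask_loss=False)
# attention_mask[0, 0]:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]
# position_ids -> tensor([[0, 1, 2, 3]]); loss_mask -> all ones.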


@@ -0,0 +1,16 @@
from typing import Any
import torch
class QwenVLBasePostProcessor:
"""Post processor for Qwen-VL-Base."""
def __init__(self) -> None:
pass
    def __call__(self, pred: torch.Tensor, tokenizer: Any,
                 input_len: int) -> str:
        response = tokenizer.decode(pred)[input_len:]
response = response.replace('<|endoftext|>', '').strip()
return response
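
Hedged usage sketch: pred, tokenizer and query below are placeholders for the generated id tensor, the Qwen-VL tokenizer and the decoded prompt text produced in QwenVLBase.generate (qwen.py below); input_len is the character length of the prompt, so the slice keeps only the newly generated text.

post_processor = QwenVLBasePostProcessor()
answer = post_processor(pred.cpu()[0], tokenizer, input_len=len(query))
print(answer)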


@@ -0,0 +1,29 @@
class QwenVLMMBenchPromptConstructor:
"""MMBench prompt constructor for Qwen-VL.
    The output is a list of dicts in the input format expected by the
    Qwen-VL tokenizer's from_list_format method.
"""
def __init__(self) -> None:
pass
    def __call__(self, inputs: dict) -> list:
data_samples = inputs['data_samples']
assert len(data_samples) == 1
data_sample = data_samples[0]
question = data_sample.get('question')
options = data_sample.get('options')
context = data_sample.get('context')
if context is not None:
prompt = context + ' ' + question + ' ' + options
else:
prompt = question + ' ' + options
format_input = [
{
'image': 'This_is_path_to_an_image.'
}, # Just placeholder for Image Tokens
{
'text': prompt
},
]
return format_input
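
For illustration, a hypothetical MMBench sample (a plain dict stands in for the mmengine data sample, since only .get() is used):

constructor = QwenVLMMBenchPromptConstructor()
sample = dict(data_samples=[dict(
    question='What is shown in the image?',
    options='A. cat B. dog C. bird D. fish',
    context=None)])
print(constructor(sample))
# [{'image': 'This_is_path_to_an_image.'},
#  {'text': 'What is shown in the image? A. cat B. dog C. bird D. fish'}]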


@@ -0,0 +1,324 @@
import types
from typing import Optional, Tuple
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from opencompass.registry import MM_MODELS
from .generation_utils import decode_tokens, make_context
@MM_MODELS.register_module('qwen-vl-base')
class QwenVLBase(nn.Module):
"""Inference code of Qwen-VL.
    We load the Qwen model via Hugging Face Transformers.
Args:
pretrained_path (str): Path to Qwen checkpoint or repo id.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
is_caption_task (bool): Whether the task is caption task.
Defaults to False.
        commit_id (str): Revision (commit id) of the Qwen-VL repo to load.
            Warning: the latest revision may be incompatible with this wrapper;
            the pinned default revision is recommended.
"""
def __init__(
self,
pretrained_path: str,
prompt_constructor: dict = None,
post_processor: dict = None,
is_caption_task: bool = False,
commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
trust_remote_code=True,
revision=commit_id)
self.model = AutoModelForCausalLM.from_pretrained(
pretrained_path,
device_map=get_device(),
trust_remote_code=True,
revision=commit_id)
self.model.generation_config = GenerationConfig.from_pretrained(
pretrained_path, trust_remote_code=True, revision=commit_id)
if prompt_constructor is not None:
self.prompt_constructor = mmengine.registry.build_from_cfg(
prompt_constructor, MM_MODELS)
if post_processor is not None:
self.post_processor = mmengine.registry.build_from_cfg(
post_processor, MM_MODELS)
self.is_caption_task = is_caption_task
self.model.transformer.forward = types.MethodType(
forward_hack, self.model.transformer)
def _build_embeds(self, images, input_ids):
# encode image
images = self.model.transformer.visual(images)
# compute image position
bos_pos = torch.where(input_ids == self.model.transformer.config.
visual['image_start_id'])
eos_pos = torch.where(
input_ids ==
self.model.transformer.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
# embed words
inputs_embeds = self.model.transformer.wte(input_ids)
# embed image tokens
for idx, (i, a, b) in enumerate(img_pos):
inputs_embeds[i][a + 1:b] = images[idx]
return inputs_embeds
def generate(self, batch):
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
format_input = self.prompt_constructor(batch)
query = self.tokenizer.from_list_format(format_input)
inputs = self.tokenizer(query, return_tensors='pt')
inputs = inputs.to(get_device())
input_ids, token_type_ids, attention_mask = inputs[
'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
inputs_embeds = self._build_embeds(images, input_ids)
pred = self.model.generate(input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
        # strip the decoded prompt from the generation (see QwenVLBasePostProcessor)
        response = self.post_processor(pred.cpu()[0], self.tokenizer,
                                       len(query))
data_sample = batch['data_samples'][0]
if self.is_caption_task:
data_sample.pred_caption = response
else:
data_sample.pred_answer = response
return data_sample
def forward(self, batch):
return self.generate(batch)
@MM_MODELS.register_module('qwen-vl-chat')
class QwenVLChat(QwenVLBase):
"""Inference code of Qwen-VL-Chat.
    We load the Qwen model via Hugging Face Transformers.
Args:
pretrained_path (str): Path to Qwen checkpoint or repo id.
prompt_constructor (dict): The config of prompt constructor.
post_processor (dict): The config of post processor.
is_caption_task (bool): Whether the task is caption task.
Defaults to False.
"""
def __init__(self,
pretrained_path: str,
prompt_constructor: dict = None,
post_processor: dict = None,
is_caption_task: bool = False) -> None:
super().__init__(pretrained_path, prompt_constructor, post_processor,
is_caption_task)
def generate(self, batch):
images = batch.pop('inputs')
images = torch.stack(images, dim=0)
format_input = self.prompt_constructor(batch)
query = self.tokenizer.from_list_format(format_input)
raw_text, context_tokens = make_context(
self.tokenizer,
query,
system='You are a helpful assistant.',
chat_format=self.model.generation_config.chat_format,
)
input_ids = torch.tensor([context_tokens]).to(get_device())
inputs_embeds = self._build_embeds(images, input_ids)
pred = self.model.generate(input_ids=input_ids,
inputs_embeds=inputs_embeds)
response = decode_tokens(
pred[0],
self.tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=self.model.generation_config.chat_format,
verbose=False,
errors='replace')
data_sample = batch['data_samples'][0]
if self.is_caption_task:
data_sample.pred_caption = response
else:
data_sample.pred_answer = response
return data_sample
def forward_hack(self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None):
if past_key_values is None and input_ids is not None and torch.any(
input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(
input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(
input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1:b - 1].tolist()
image = image[:image.index(self.config.visual['image_start_id'] +
2)]
images.append(bytes(image).decode('utf-8'))
images = self.visual.encode(images)
assert images.shape[0] == len(images)
else:
images = None
output_attentions = (output_attentions if output_attentions is not None
else self.config.output_attentions)
output_hidden_states = (output_hidden_states if output_hidden_states
is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict
if return_dict is not None else self.config.use_return_dict)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both input_ids and inputs_embeds at the same time' # noqa
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError(
'You have to specify either input_ids or inputs_embeds')
device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
raise ValueError('batch_size has to be defined and > 0')
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length)
hidden_states = inputs_embeds
hidden_states = self.drop(hidden_states)
if images is not None:
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1:b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1), )
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[1], )
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
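
End-to-end sketch (illustrative, not part of the diff): running one MMBench sample through the chat wrapper. The dataloader variable is hypothetical and stands for a loader built from qwen_mmbench_dataloader above; it yields a batch with a list of 448x448 image tensors under 'inputs' and one data sample carrying the question/options fields.

import torch

from opencompass.multimodal.models.qwen import (QwenVLChat,
                                                QwenVLMMBenchPromptConstructor)

model = QwenVLChat(
    pretrained_path='Qwen/Qwen-VL-Chat',
    prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor)).eval()
batch = next(iter(dataloader))  # hypothetical dataloader from the config above
with torch.no_grad():
    data_sample = model.generate(batch)
print(data_sample.pred_answer)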