From b885ec84df0523a2bdcd74defd7c96ee104e616d Mon Sep 17 00:00:00 2001
From: Yike Yuan <32432002+yyk-wew@users.noreply.github.com>
Date: Wed, 6 Sep 2023 18:42:19 +0800
Subject: [PATCH] [Feat] Support Qwen-VL-Chat on MMBench. (#312)

* [Feat] Support Qwen-VL base.

* [Feat] Support Qwen-VL-Chat on MMBench.

* [Fix] Add postprocessor and fix format.

* [Fix] Add type hint and remove redundant codes.

* [Fix] fix bugs in postprocessor.

* [Fix] Use given commit id.
---
 .../multimodal/qwen/qwenvl_base_7b_mmbench.py |  41 +++
 .../multimodal/qwen/qwenvl_chat_7b_mmbench.py |  40 +++
 opencompass/multimodal/models/__init__.py     |   1 +
 .../multimodal/models/qwen/__init__.py        |   8 +
 .../models/qwen/generation_utils.py           | 293 ++++++++++++++++
 .../multimodal/models/qwen/post_processor.py  |  16 +
 .../models/qwen/prompt_constructor.py         |  29 ++
 opencompass/multimodal/models/qwen/qwen.py    | 324 ++++++++++++++++++
 8 files changed, 752 insertions(+)
 create mode 100644 configs/multimodal/qwen/qwenvl_base_7b_mmbench.py
 create mode 100644 configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py
 create mode 100644 opencompass/multimodal/models/qwen/__init__.py
 create mode 100644 opencompass/multimodal/models/qwen/generation_utils.py
 create mode 100644 opencompass/multimodal/models/qwen/post_processor.py
 create mode 100644 opencompass/multimodal/models/qwen/prompt_constructor.py
 create mode 100644 opencompass/multimodal/models/qwen/qwen.py

diff --git a/configs/multimodal/qwen/qwenvl_base_7b_mmbench.py b/configs/multimodal/qwen/qwenvl_base_7b_mmbench.py
new file mode 100644
index 00000000..23cfb8e6
--- /dev/null
+++ b/configs/multimodal/qwen/qwenvl_base_7b_mmbench.py
@@ -0,0 +1,41 @@
+from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor, QwenVLBasePostProcessor
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(448, 448),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954, 0.26130258, 0.27577711)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'options', 'category', 'l2-category', 'context',
+             'index', 'options_dict'
+         ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+qwen_mmbench_dataloader = dict(batch_size=1,
+                               num_workers=4,
+                               dataset=dataset,
+                               collate_fn=dict(type='pseudo_collate'),
+                               sampler=dict(type='DefaultSampler', shuffle=False))
+
+# model settings
+qwen_model = dict(
+    type='qwen-vl-base',
+    pretrained_path='Qwen/Qwen-VL',  # or Huggingface repo id
+    prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor),
+    post_processor=dict(type=QwenVLBasePostProcessor)
+)
+
+# evaluation settings
+qwen_mmbench_evaluator = [
+    dict(type='opencompass.DumpResults',
+         save_path='work_dirs/qwenvl-base-7b-mmbench.xlsx')
+]
diff --git a/configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py b/configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py
new file mode 100644
index 00000000..de665e4c
--- /dev/null
+++ b/configs/multimodal/qwen/qwenvl_chat_7b_mmbench.py
@@ -0,0 +1,40 @@
+from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor
+
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.torchvision/Resize',
+         size=(448, 448),
+         interpolation=3),
+    dict(type='mmpretrain.torchvision/ToTensor'),
+    dict(type='mmpretrain.torchvision/Normalize',
+         mean=(0.48145466, 0.4578275, 0.40821073),
+         std=(0.26862954,
0.26130258, 0.27577711)), + dict(type='mmpretrain.PackInputs', + algorithm_keys=[ + 'question', 'options', 'category', 'l2-category', 'context', + 'index', 'options_dict' + ]) +] + +dataset = dict(type='opencompass.MMBenchDataset', + data_file='data/mmbench/mmbench_test_20230712.tsv', + pipeline=val_pipeline) + +qwen_mmbench_dataloader = dict(batch_size=1, + num_workers=4, + dataset=dataset, + collate_fn=dict(type='pseudo_collate'), + sampler=dict(type='DefaultSampler', shuffle=False)) + +# model settings +qwen_model = dict( + type='qwen-vl-chat', + pretrained_path='Qwen/Qwen-VL-Chat', # or Huggingface repo id + prompt_constructor=dict(type=QwenVLMMBenchPromptConstructor) +) + +# evaluation settings +qwen_mmbench_evaluator = [ + dict(type='opencompass.DumpResults', + save_path='work_dirs/qwenvl-chat-7b-mmbench.xlsx') +] diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py index dfe7caca..6bb6603c 100644 --- a/opencompass/multimodal/models/__init__.py +++ b/opencompass/multimodal/models/__init__.py @@ -23,4 +23,5 @@ from .openflamingo import * # noqa: F401, F403 if osp.exists('opencompass/multimodal/models/otter/Otter'): from .otter import * # noqa: F401, F403 +from .qwen import * # noqa: F401, F403 from .visualglm import * # noqa: F401, F403 diff --git a/opencompass/multimodal/models/qwen/__init__.py b/opencompass/multimodal/models/qwen/__init__.py new file mode 100644 index 00000000..94f33b6b --- /dev/null +++ b/opencompass/multimodal/models/qwen/__init__.py @@ -0,0 +1,8 @@ +from .post_processor import QwenVLBasePostProcessor +from .prompt_constructor import QwenVLMMBenchPromptConstructor +from .qwen import QwenVLBase, QwenVLChat + +__all__ = [ + 'QwenVLBase', 'QwenVLChat', 'QwenVLBasePostProcessor', + 'QwenVLMMBenchPromptConstructor' +] diff --git a/opencompass/multimodal/models/qwen/generation_utils.py b/opencompass/multimodal/models/qwen/generation_utils.py new file mode 100644 index 00000000..9bfb83a0 --- /dev/null +++ b/opencompass/multimodal/models/qwen/generation_utils.py @@ -0,0 +1,293 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Generation support.""" + +from typing import List, Tuple, Union + +import torch +from transformers import PreTrainedTokenizer + +# Types. +HistoryType = List[Tuple[str, str]] +TokensType = List[int] +BatchTokensType = List[List[int]] + + +def pad_batch(batch: BatchTokensType, pad_id: int, + seq_length: int) -> BatchTokensType: + for tokens in batch: + context_length = len(tokens) + if context_length < seq_length: + tokens.extend([pad_id] * (seq_length - context_length)) + return batch + + +def get_ltor_masks_and_position_ids( + data: torch.Tensor, + eod_token: int, + reset_position_ids: bool, + reset_attention_mask: bool, + eod_mask_loss: bool, +): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril( + torch.ones((att_mask_batch, seq_length, seq_length), + device=data.device)).view(att_mask_batch, 1, seq_length, + seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. 
+ position_ids = torch.arange(seq_length, + dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modified based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + + # Find indices where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indices from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indices: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1):] -= i + 1 - prev_index + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + + return attention_mask, loss_mask, position_ids + + +def get_batch(context_tokens: torch.LongTensor, eod_id: int): + """Generate batch from context tokens.""" + # Move to GPU. + tokens = context_tokens.contiguous().to(context_tokens.device) + # Get the attention mask and position ids. + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + tokens, + eod_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + return tokens, attention_mask, position_ids + + +def get_stop_words_ids(chat_format: str, tokenizer: PreTrainedTokenizer): + if chat_format == 'raw': + stop_words_ids = [tokenizer.encode('Human:'), [tokenizer.eod_id]] + elif chat_format == 'chatml': + stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] + else: + raise NotImplementedError(f'Unknown chat format {chat_format!r}') + return stop_words_ids + + +def make_context( + tokenizer: PreTrainedTokenizer, + query: str, + history: List[Tuple[str, str]] = None, + system: str = '', + max_window_size: int = 6144, + chat_format: str = 'chatml', +): + if history is None: + history = [] + + if chat_format == 'chatml': + im_start, im_end = '<|im_start|>', '<|im_end|>' + im_start_tokens = [tokenizer.im_start_id] + im_end_tokens = [tokenizer.im_end_id] + nl_tokens = tokenizer.encode('\n') + + def _tokenize_str(role, content): + return f'{role}\n{content}', tokenizer.encode( + role, allowed_special=set( + tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode( + content, allowed_special=set(tokenizer.IMAGE_ST)) + + system_text, system_tokens_part = _tokenize_str('system', system) + system_tokens = im_start_tokens + system_tokens_part + im_end_tokens + + raw_text = '' + context_tokens = [] + + for turn_query, turn_response in reversed(history): + query_text, query_tokens_part = _tokenize_str('user', turn_query) + query_tokens = im_start_tokens + query_tokens_part + im_end_tokens + if turn_response is not None: + response_text, response_tokens_part = _tokenize_str( + 'assistant', turn_response) + response_tokens = im_start_tokens + response_tokens_part + im_end_tokens # noqa + + next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens # noqa + prev_chat = ( + f'\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}' # noqa + ) + else: + next_context_tokens = nl_tokens + query_tokens + nl_tokens + prev_chat = f'\n{im_start}{query_text}{im_end}\n' + + current_context_size = (len(system_tokens) + + len(next_context_tokens) + + len(context_tokens)) + if 
current_context_size < max_window_size: + context_tokens = next_context_tokens + context_tokens + raw_text = prev_chat + raw_text + else: + break + + context_tokens = system_tokens + context_tokens + raw_text = f'{im_start}{system_text}{im_end}' + raw_text + context_tokens += (nl_tokens + im_start_tokens + + _tokenize_str('user', query)[1] + im_end_tokens + + nl_tokens + im_start_tokens + + tokenizer.encode('assistant') + nl_tokens) + raw_text += f'\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n' + + elif chat_format == 'raw': + raw_text = query + context_tokens = tokenizer.encode(raw_text) + else: + raise NotImplementedError(f'Unknown chat format {chat_format!r}') + + return raw_text, context_tokens + + +def _decode_default( + tokens: List[int], + *, + stop_words: List[str], + eod_words: List[str], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + verbose: bool = False, + return_end_reason: bool = False, + errors: str = 'replace', +): + trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] + if verbose: + print('\nRaw Generate: ', trim_decode_tokens) + + end_reason = f'Gen length {len(tokens)}' + for stop_word in stop_words: + trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip() + for eod_word in eod_words: + if eod_word in trim_decode_tokens: + end_reason = f'Gen {eod_word!r}' + trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] + trim_decode_tokens = trim_decode_tokens.strip() + if verbose: + print('\nEnd Reason:', end_reason) + print('\nGenerate: ', trim_decode_tokens) + + if return_end_reason: + return trim_decode_tokens, end_reason + else: + return trim_decode_tokens + + +def _decode_chatml(tokens: List[int], + *, + stop_words: List[str], + eod_token_ids: List[int], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + context_length: int, + verbose: bool = False, + return_end_reason: bool = False, + errors: str = 'replace'): + end_reason = f'Gen length {len(tokens)}' + eod_token_idx = context_length + for eod_token_idx in range(context_length, len(tokens)): + if tokens[eod_token_idx] in eod_token_ids: + end_reason = f'Gen {tokenizer.decode([tokens[eod_token_idx]])!r}' + break + + trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], + errors=errors)[raw_text_len:] + if verbose: + print('\nRaw Generate w/o EOD:', + tokenizer.decode(tokens, errors=errors)[raw_text_len:]) + print('\nRaw Generate:', trim_decode_tokens) + print('\nEnd Reason:', end_reason) + for stop_word in stop_words: + trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip() + trim_decode_tokens = trim_decode_tokens.strip() + if verbose: + print('\nGenerate:', trim_decode_tokens) + + if return_end_reason: + return trim_decode_tokens, end_reason + else: + return trim_decode_tokens + + +def decode_tokens( + tokens: Union[torch.LongTensor, TokensType], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + context_length: int, + chat_format: str, + verbose: bool = False, + return_end_reason: bool = False, + errors: str = 'replace', +) -> str: + if torch.is_tensor(tokens): + tokens = tokens.cpu().numpy().tolist() + + if chat_format == 'chatml': + return _decode_chatml( + tokens, + stop_words=[], + eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], + tokenizer=tokenizer, + raw_text_len=raw_text_len, + context_length=context_length, + verbose=verbose, + return_end_reason=return_end_reason, + errors=errors, + ) + elif chat_format == 'raw': + return _decode_default( + tokens, + stop_words=['<|endoftext|>'], + 
eod_words=['<|endoftext|>'],
+            tokenizer=tokenizer,
+            raw_text_len=raw_text_len,
+            verbose=verbose,
+            return_end_reason=return_end_reason,
+            errors=errors,
+        )
+    else:
+        raise NotImplementedError(f'Unknown chat format {chat_format!r}')
diff --git a/opencompass/multimodal/models/qwen/post_processor.py b/opencompass/multimodal/models/qwen/post_processor.py
new file mode 100644
index 00000000..4382622f
--- /dev/null
+++ b/opencompass/multimodal/models/qwen/post_processor.py
@@ -0,0 +1,16 @@
+from typing import Any
+
+import torch
+
+
+class QwenVLBasePostProcessor:
+    """Post processor for Qwen-VL-Base."""
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, pred: torch.Tensor, tokenizer: Any,
+                 input_len: int) -> str:
+        response = tokenizer.decode(pred)[input_len:]
+        response = response.replace('<|endoftext|>', '').strip()
+        return response
diff --git a/opencompass/multimodal/models/qwen/prompt_constructor.py b/opencompass/multimodal/models/qwen/prompt_constructor.py
new file mode 100644
index 00000000..476e1958
--- /dev/null
+++ b/opencompass/multimodal/models/qwen/prompt_constructor.py
@@ -0,0 +1,29 @@
+class QwenVLMMBenchPromptConstructor:
+    """MMBench prompt constructor for Qwen-VL.
+
+    The output is a list of dicts in the Qwen-VL tokenizer input format.
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def __call__(self, inputs: dict) -> list:
+        data_samples = inputs['data_samples']
+        assert len(data_samples) == 1
+        data_sample = data_samples[0]
+        question = data_sample.get('question')
+        options = data_sample.get('options')
+        context = data_sample.get('context')
+        if context is not None:
+            prompt = context + ' ' + question + ' ' + options
+        else:
+            prompt = question + ' ' + options
+        format_input = [
+            {
+                'image': 'This_is_path_to_an_image.'
+            },  # Just a placeholder for image tokens
+            {
+                'text': prompt
+            },
+        ]
+        return format_input
diff --git a/opencompass/multimodal/models/qwen/qwen.py b/opencompass/multimodal/models/qwen/qwen.py
new file mode 100644
index 00000000..9682b5c9
--- /dev/null
+++ b/opencompass/multimodal/models/qwen/qwen.py
@@ -0,0 +1,324 @@
+import types
+from typing import Optional, Tuple
+
+import mmengine
+import torch
+import torch.nn as nn
+from mmengine.device import get_device
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+from opencompass.registry import MM_MODELS
+
+from .generation_utils import decode_tokens, make_context
+
+
+@MM_MODELS.register_module('qwen-vl-base')
+class QwenVLBase(nn.Module):
+    """Inference code of Qwen-VL.
+
+    We load the Qwen model via Huggingface.
+    Args:
+        pretrained_path (str): Path to Qwen checkpoint or repo id.
+        prompt_constructor (dict): The config of prompt constructor.
+        post_processor (dict): The config of post processor.
+        is_caption_task (bool): Whether the task is a caption task.
+            Defaults to False.
+        commit_id (str): Use the given version of Qwen-VL.
+            Warning: the latest version may have some conflicts;
+            using the given default version is recommended.
+    """
+
+    def __init__(
+        self,
+        pretrained_path: str,
+        prompt_constructor: dict = None,
+        post_processor: dict = None,
+        is_caption_task: bool = False,
+        commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
+    ) -> None:
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
+                                                       trust_remote_code=True,
+                                                       revision=commit_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            pretrained_path,
+            device_map=get_device(),
+            trust_remote_code=True,
+            revision=commit_id)
+        self.model.generation_config = GenerationConfig.from_pretrained(
+            pretrained_path, trust_remote_code=True, revision=commit_id)
+        if prompt_constructor is not None:
+            self.prompt_constructor = mmengine.registry.build_from_cfg(
+                prompt_constructor, MM_MODELS)
+        if post_processor is not None:
+            self.post_processor = mmengine.registry.build_from_cfg(
+                post_processor, MM_MODELS)
+        self.is_caption_task = is_caption_task
+        self.model.transformer.forward = types.MethodType(
+            forward_hack, self.model.transformer)
+
+    def _build_embeds(self, images, input_ids):
+        # encode image
+        images = self.model.transformer.visual(images)
+        # compute image position
+        bos_pos = torch.where(input_ids == self.model.transformer.config.
+                              visual['image_start_id'])
+        eos_pos = torch.where(
+            input_ids ==
+            self.model.transformer.config.visual['image_start_id'] + 1)
+        assert (bos_pos[0] == eos_pos[0]).all()
+        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+        # embed words
+        inputs_embeds = self.model.transformer.wte(input_ids)
+        # embed image tokens
+        for idx, (i, a, b) in enumerate(img_pos):
+            inputs_embeds[i][a + 1:b] = images[idx]
+        return inputs_embeds
+
+    def generate(self, batch):
+        images = batch.pop('inputs')
+        images = torch.stack(images, dim=0)
+        format_input = self.prompt_constructor(batch)
+        query = self.tokenizer.from_list_format(format_input)
+
+        inputs = self.tokenizer(query, return_tensors='pt')
+        inputs = inputs.to(get_device())
+        input_ids, token_type_ids, attention_mask = inputs[
+            'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
+        inputs_embeds = self._build_embeds(images, input_ids)
+        pred = self.model.generate(input_ids=input_ids,
+                                   inputs_embeds=inputs_embeds,
+                                   attention_mask=attention_mask,
+                                   token_type_ids=token_type_ids)
+        response = self.post_processor(pred.cpu()[0], self.tokenizer, len(query))  # noqa
+
+        data_sample = batch['data_samples'][0]
+        if self.is_caption_task:
+            data_sample.pred_caption = response
+        else:
+            data_sample.pred_answer = response
+        return data_sample
+
+    def forward(self, batch):
+        return self.generate(batch)
+
+
+@MM_MODELS.register_module('qwen-vl-chat')
+class QwenVLChat(QwenVLBase):
+    """Inference code of Qwen-VL-Chat.
+
+    We load the Qwen model via Huggingface.
+    Args:
+        pretrained_path (str): Path to Qwen checkpoint or repo id.
+        prompt_constructor (dict): The config of prompt constructor.
+        post_processor (dict): The config of post processor.
+        is_caption_task (bool): Whether the task is a caption task.
+            Defaults to False.
+ """ + + def __init__(self, + pretrained_path: str, + prompt_constructor: dict = None, + post_processor: dict = None, + is_caption_task: bool = False) -> None: + super().__init__(pretrained_path, prompt_constructor, post_processor, + is_caption_task) + + def generate(self, batch): + images = batch.pop('inputs') + images = torch.stack(images, dim=0) + format_input = self.prompt_constructor(batch) + query = self.tokenizer.from_list_format(format_input) + + raw_text, context_tokens = make_context( + self.tokenizer, + query, + system='You are a helpful assistant.', + chat_format=self.model.generation_config.chat_format, + ) + + input_ids = torch.tensor([context_tokens]).to(get_device()) + + inputs_embeds = self._build_embeds(images, input_ids) + pred = self.model.generate(input_ids=input_ids, + inputs_embeds=inputs_embeds) + + response = decode_tokens( + pred[0], + self.tokenizer, + raw_text_len=len(raw_text), + context_length=len(context_tokens), + chat_format=self.model.generation_config.chat_format, + verbose=False, + errors='replace') + + data_sample = batch['data_samples'][0] + if self.is_caption_task: + data_sample.pred_caption = response + else: + data_sample.pred_answer = response + return data_sample + + +def forward_hack(self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + if past_key_values is None and input_ids is not None and torch.any( + input_ids == self.config.visual['image_start_id']): + bos_pos = torch.where( + input_ids == self.config.visual['image_start_id']) + eos_pos = torch.where( + input_ids == self.config.visual['image_start_id'] + 1) + assert (bos_pos[0] == eos_pos[0]).all() + img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) + images = [] + for i, a, b in img_pos: + image = input_ids[i][a + 1:b - 1].tolist() + image = image[:image.index(self.config.visual['image_start_id'] + + 2)] + images.append(bytes(image).decode('utf-8')) + + images = self.visual.encode(images) + assert images.shape[0] == len(images) + else: + images = None + + output_attentions = (output_attentions if output_attentions is not None + else self.config.output_attentions) + output_hidden_states = (output_hidden_states if output_hidden_states + is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict + if return_dict is not None else self.config.use_return_dict) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' # noqa + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = 
input_ids.device if input_ids is not None else inputs_embeds.device # noqa + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + if batch_size <= 0: + raise ValueError('batch_size has to be defined and > 0') + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_length) + + hidden_states = inputs_embeds + + hidden_states = self.drop(hidden_states) + if images is not None: + for idx, (i, a, b) in enumerate(img_pos): + hidden_states[i][a + 1:b] = images[idx] + output_shape = input_shape + (hidden_states.size(-1), ) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1], ) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states, ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states] + if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + )
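The snippet below is not part of the patch; it is a minimal sketch of how QwenVLMMBenchPromptConstructor is consumed, assuming OpenCompass is installed with this patch applied. The sample fields are made up; QwenVLBase.generate() passes the returned list to tokenizer.from_list_format() to build the final query string.

from opencompass.multimodal.models.qwen import QwenVLMMBenchPromptConstructor

# A fake MMBench-style batch; a plain dict supports the .get() calls
# used by the constructor.
batch = {
    'data_samples': [{
        'question': 'What shape is the road sign?',
        'options': 'A. Circle B. Square C. Triangle D. Star',
        'context': None,
    }]
}

constructor = QwenVLMMBenchPromptConstructor()
format_input = constructor(batch)
print(format_input)
# [{'image': 'This_is_path_to_an_image.'},
#  {'text': 'What shape is the road sign? A. Circle B. Square C. Triangle D. Star'}]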
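A second standalone sketch (pure torch, runnable as-is) of the image-position bookkeeping shared by _build_embeds() and forward_hack(): torch.where locates the image start/end marker tokens, and the rows between them are later overwritten with visual features. The token ids below are placeholders, not the real Qwen-VL ids, which come from config.visual['image_start_id'] and image_start_id + 1.

import torch

IMAGE_START_ID = 900               # placeholder id for the <img> marker
IMAGE_END_ID = IMAGE_START_ID + 1  # the code above assumes end = start + 1

input_ids = torch.tensor([[11, IMAGE_START_ID, 7, 7, 7, IMAGE_END_ID, 42]])
bos_pos = torch.where(input_ids == IMAGE_START_ID)
eos_pos = torch.where(input_ids == IMAGE_END_ID)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
print(img_pos)  # tensor([[0, 1, 5]]): (batch index, <img> position, </img> position)
# inputs_embeds[i][a + 1:b] (rows 2..4 here) are then replaced with image features.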
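Finally, an illustration (again not part of the patch) of the raw_text that make_context() builds for QwenVLChat with an empty history and chat_format='chatml'. The query string is only an example stand-in for whatever tokenizer.from_list_format() returns; decode_tokens() later receives len(raw_text) as raw_text_len to strip this prompt from the decoded output.

im_start, im_end = '<|im_start|>', '<|im_end|>'
system = 'You are a helpful assistant.'
query = 'Picture 1: <img>demo.jpg</img>\nWhat is in the image?'  # example only

raw_text = (f'{im_start}system\n{system}{im_end}'
            f'\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n')
print(raw_text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Picture 1: <img>demo.jpg</img>
# What is in the image?<|im_end|>
# <|im_start|>assistant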