From f33e706dd102248d4963f7914901ddcd45c7257b Mon Sep 17 00:00:00 2001 From: jiangzhs Date: Fri, 25 Oct 2024 17:22:35 +0800 Subject: [PATCH] feat: delete some unused files --- ihp/config/data/cfg.hairuo.2b.efficiency.json | 24 - ihp/config/data/cfg.hairuo.2b.stable.json | 24 - ihp/config/model/cfg.hairuo.2b.json | 37 - ihp/model/__init__.py | 147 ---- ihp/model/loss.py | 22 - ihp/zoo/llama/__init__.py | 10 - ihp/zoo/llama/modeling_llama.py | 704 ------------------ ihp/zoo/qwen/__init__.py | 10 - ihp/zoo/qwen/modeling_qwen2.py | 698 ----------------- 9 files changed, 1676 deletions(-) delete mode 100644 ihp/config/data/cfg.hairuo.2b.efficiency.json delete mode 100644 ihp/config/data/cfg.hairuo.2b.stable.json delete mode 100644 ihp/config/model/cfg.hairuo.2b.json delete mode 100644 ihp/model/__init__.py delete mode 100644 ihp/model/loss.py delete mode 100644 ihp/zoo/llama/__init__.py delete mode 100644 ihp/zoo/llama/modeling_llama.py delete mode 100644 ihp/zoo/qwen/__init__.py delete mode 100644 ihp/zoo/qwen/modeling_qwen2.py diff --git a/ihp/config/data/cfg.hairuo.2b.efficiency.json b/ihp/config/data/cfg.hairuo.2b.efficiency.json deleted file mode 100644 index 9647d8c..0000000 --- a/ihp/config/data/cfg.hairuo.2b.efficiency.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "wudao": { - "group": "wudao", - "name": "wudao", - "epoch": 1, - "path": "wudao", - "strategy": { - "st_segment": "naive", - "st_tokenize": "legacy" - }, - "weight": 0.5 - }, - "zwjcylk": { - "group": "zwjcylk", - "name": "zwjcylk", - "epoch": 1, - "path": "zwjcylk", - "strategy": { - "st_segment": "naive", - "st_tokenize": "legacy" - }, - "weight": 0.5 - } -} \ No newline at end of file diff --git a/ihp/config/data/cfg.hairuo.2b.stable.json b/ihp/config/data/cfg.hairuo.2b.stable.json deleted file mode 100644 index 9647d8c..0000000 --- a/ihp/config/data/cfg.hairuo.2b.stable.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "wudao": { - "group": "wudao", - "name": "wudao", - "epoch": 1, - "path": "wudao", - "strategy": { - "st_segment": "naive", - "st_tokenize": "legacy" - }, - "weight": 0.5 - }, - "zwjcylk": { - "group": "zwjcylk", - "name": "zwjcylk", - "epoch": 1, - "path": "zwjcylk", - "strategy": { - "st_segment": "naive", - "st_tokenize": "legacy" - }, - "weight": 0.5 - } -} \ No newline at end of file diff --git a/ihp/config/model/cfg.hairuo.2b.json b/ihp/config/model/cfg.hairuo.2b.json deleted file mode 100644 index 6d59827..0000000 --- a/ihp/config/model/cfg.hairuo.2b.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "config_clz": "hairuo.HairuoConfig", - "config": { - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 2304, - "initializer_range": 0.02, - "intermediate_size": 5760, - "max_position_embeddings": 131072, - "model_type": "hairuo", - "num_attention_heads": 36, - "num_hidden_layers": 40, - "num_key_value_heads": 36, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "use_cache": true, - "vocab_size": 152064, - "_attn_implementation": "eager", - "_flash_attn_2_enabled": false, - "mup_scale_emb": 12, - "mup_scale_depth": 1.4, - "mup_scale_width": 9.0 - }, - "model_clz": "hairuo.HairuoForCausalLM", - "tokenizer_clz": "hairuo.HairuoTokenizer" -} diff --git a/ihp/model/__init__.py b/ihp/model/__init__.py deleted file mode 100644 index 9ff039d..0000000 --- a/ihp/model/__init__.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: sunxian -# @date: 2024/07/21 -# - -import json -import math -import tempfile -from contextlib import nullcontext - -import torch -from colossalai.lazy import LazyInitContext -from colossalai.moe.utils import skip_init -from colossalai.utils import get_current_device - -from ihp.optim import get_lr_scheduler -from ihp.optim import get_optimizer -from ihp.util import importer -from ihp.util.booster import get_booster -from ihp.util.io import save_json -from ihp.util.logger import logger -from ihp.util.metric import format_numel -from ihp.util.metric import get_model_numel - - -def create_and_load_model( - storage, coordinator, args, config, with_optimizer, with_scheduler, extra_config=None, extra_info="" -): - rank = coordinator.rank - logger.info(f"rank-{rank} -> init {extra_info} model") - - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - - if args.use_lazy_ctx: - init_ctx = LazyInitContext(default_device=get_current_device()) - elif args.use_skip_ctx: - init_ctx = skip_init() - else: - init_ctx = nullcontext() - - if coordinator.is_master(): - logger.info(f"init ctx: {init_ctx}") - - with init_ctx: - logger.info(f"rank-{rank} -> [start]init config({args.config_clz}) from {config}") - config_clz = importer.from_name_to_clz(args.config_clz) - if isinstance(config, dict): - with tempfile.NamedTemporaryFile() as tmp: - save_json(config, tmp.name) - model_config = config_clz.from_pretrained(tmp.name, trust_remote_code=True) - else: - model_config = config_clz.from_pretrained(config, trust_remote_code=True) - - if extra_config is None: - extra_config = {} - - if hasattr(model_config, "_flash_attn_2_enabled"): - extra_config["_flash_attn_2_enabled"] = args.use_flash_attn - if hasattr(model_config, "_attn_implementation"): - extra_config["_attn_implementation"] = "flash_attention_2" if args.use_flash_attn else "eager" - if hasattr(model_config, "use_cache"): - extra_config["use_cache"] = not args.use_flash_attn - if hasattr(model_config, "output_router_logits"): - extra_config["output_router_logits"] = True - if hasattr(model_config, "router_aux_loss_coef") and args.router_aux_loss_coef > 0.0: - extra_config["router_aux_loss_coef"] = args.router_aux_loss_coef - if hasattr(model_config, "initializer_range"): - extra_config["initializer_range"] = args.init_std - - if not args.use_mup: - args.mup_scale_emb = 1.0 - args.mup_scale_depth = math.sqrt(model_config.num_hidden_layers) - args.mup_scale_width = 1.0 - else: - if hasattr(model_config, "mup_scale_emb"): - args.mup_scale_emb = args.mup_scale_emb or model_config.mup_scale_emb - if hasattr(model_config, "mup_scale_depth"): - args.mup_scale_depth = args.mup_scale_depth or model_config.mup_scale_depth - if hasattr(model_config, "mup_scale_width"): - args.mup_scale_width = args.mup_scale_width or model_config.mup_scale_width - - extra_config["mup_scale_emb"] = args.mup_scale_emb - extra_config["mup_scale_depth"] = args.mup_scale_depth - extra_config["mup_scale_width"] = args.mup_scale_width - - if isinstance(extra_config, dict): - for k, v in extra_config.items(): - if hasattr(model_config, k): - model_config.__setattr__(k, v) - - logger.info(f"rank-{rank} -> [start]init model({args.model_clz}) with {model_config}") - model_clz = importer.from_name_to_clz(args.model_clz) - if hasattr(model_clz, "from_config"): - model = model_clz.from_config( - model_config, use_flash_attention_2=args.use_flash_attn, trust_remote_code=True - ) - else: - model = model_clz(model_config) - - if args.grad_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": args.use_reentrant}) - - architecture = model.__class__.__name__ - model_numel, non_embed_numel, trainable_numel = get_model_numel(model) - if coordinator.is_master(): - args_dict = {key: value for key, value in args.__dict__.items() if key != "dataset"} - logger.info( - f"{extra_info} model specs architecture: {architecture}" - f", parameters full: {format_numel(model_numel)}" - f", non-embed: {format_numel(non_embed_numel)}" - f", trainable: {format_numel(trainable_numel)}" - f", args: {json.dumps(args_dict, ensure_ascii=False)}" - ) - - optimizer = get_optimizer(args, model) if with_optimizer else None - lr_scheduler = get_lr_scheduler(args, optimizer) if with_scheduler else None - booster = get_booster(args) - - model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) - - st_step, extra_states = 1, {} - - if args.load: - states = storage.load( - booster, - model, - optimizer, - lr_scheduler, - args.load, - coordinator.rank, - args.reset_states, - not args.not_use_strict, - ) - st_step, extra_states = states - - # fmt: off - return ( - booster, model, optimizer, lr_scheduler, - architecture, model_config, model_numel, - st_step, extra_states - ) - # fmt: on diff --git a/ihp/model/loss.py b/ihp/model/loss.py deleted file mode 100644 index 8a7b93a..0000000 --- a/ihp/model/loss.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: sunxian -# @date: 2024/07/21 -# - -import torch -from hakkero.dataset import IGNORE_INDEX - - -def lm_cross_entropy(logits, labels, reduction="mean"): - # we do not do the stupid shifting as in huggingface since we shift it in dataset - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - return torch.nn.functional.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX, reduction=reduction) - - -def lm_z_loss(logits): - return logits.max(-1).values().square().mean() diff --git a/ihp/zoo/llama/__init__.py b/ihp/zoo/llama/__init__.py deleted file mode 100644 index a8724ee..0000000 --- a/ihp/zoo/llama/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: jiangzhs -# @date: 2024/10/10 -# - -from ihp.zoo.llama.modeling_llama import LlamaForCausalLM diff --git a/ihp/zoo/llama/modeling_llama.py b/ihp/zoo/llama/modeling_llama.py deleted file mode 100644 index b69b02c..0000000 --- a/ihp/zoo/llama/modeling_llama.py +++ /dev/null @@ -1,704 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: jiangzhs -# @date: 2024/10/10 -# - -import math -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache -from transformers.cache_utils import DynamicCache -from transformers.cache_utils import StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from transformers.models.llama.modeling_llama import LlamaMLP -from transformers.models.llama.modeling_llama import LlamaPreTrainedModel -from transformers.models.llama.modeling_llama import LlamaRMSNorm -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding -from transformers.models.llama.modeling_llama import repeat_kv -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import add_start_docstrings_to_model_forward -from transformers.utils import is_flash_attn_greater_or_equal_2_10 -from transformers.utils import logging - -from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - -ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) - - -class LlamaAttention(nn.Module): - - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = LlamaRotaryEmbedding(config=self.config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2(LlamaAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - cu_seqlens=cu_seqlens, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -LLAMA_ATTENTION_CLASSES = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, -} - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - cu_seqlens=cu_seqlens, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class LlamaModel(LlamaPreTrainedModel): - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - cu_seqlens, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - cu_seqlens=cu_seqlens, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - ): - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - cu_seqlens=cu_seqlens, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/ihp/zoo/qwen/__init__.py b/ihp/zoo/qwen/__init__.py deleted file mode 100644 index a8078a5..0000000 --- a/ihp/zoo/qwen/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: jiangzhs -# @date: 2024/10/08 -# - -from ihp.zoo.qwen.modeling_qwen2 import Qwen2ForCausalLM diff --git a/ihp/zoo/qwen/modeling_qwen2.py b/ihp/zoo/qwen/modeling_qwen2.py deleted file mode 100644 index 0edb3bc..0000000 --- a/ihp/zoo/qwen/modeling_qwen2.py +++ /dev/null @@ -1,698 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright @2024 AI. Inspur Inc. -# -# @author: jiangzhs -# @date: 2024/10/08 -# - -import math -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers import Cache -from transformers import DynamicCache -from transformers import StaticCache -from transformers.generation import GenerationMixin -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.qwen2 import Qwen2Config -from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb -from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP -from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel -from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm -from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding -from transformers.models.qwen2.modeling_qwen2 import repeat_kv -from transformers.utils import is_flash_attn_2_available -from transformers.utils import is_flash_attn_greater_or_equal_2_10 -from transformers.utils import logging - - -if is_flash_attn_2_available(): - from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - - -class Qwen2Attention(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2FlashAttention2(Qwen2Attention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - cu_seqlens=cu_seqlens, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - cu_seqlens: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - cu_seqlens=cu_seqlens, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class Qwen2Model(Qwen2PreTrainedModel): - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - cu_seqlens, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - cu_seqlens=cu_seqlens, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - ): - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - cu_seqlens: Optional[torch.Tensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - cu_seqlens=cu_seqlens, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - )