#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
from vllm.model_executor.models import SupportsLoRA
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
from vllm.model_executor.models.utils import make_layers
from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors


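# Feed-forward block: a gated (SwiGLU-style) MLP. gate_proj and up_proj are
# fused into a single MergedColumnParallelLinear so one GEMM produces both
# halves, SiluAndMul computes silu(gate) * up, and the row-parallel down_proj
# maps the result back to hidden_size across tensor-parallel ranks.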
" "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x class HairuoAttention(nn.Module): def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= tp_size: assert self.total_num_kv_heads % tp_size == 0 else: assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, ) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) o_dtype = q.dtype q, k = self.rotary_emb(positions, q.float(), k.float()) q, k = q.to(o_dtype), k.to(o_dtype) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output class HairuoDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config self.cache_config = cache_config self.quant_config = quant_config self.hidden_size = config.hidden_size self.rope_theta = getattr(config, "rope_theta", 10000.0) self.rope_scaling = getattr(config, "rope_scaling", None) if self.rope_scaling is not None and getattr(config, "original_max_position_embeddings", None): self.rope_scaling["original_max_position_embeddings"] = config.original_max_position_embeddings max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = HairuoAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads), rope_theta=self.rope_theta, rope_scaling=self.rope_scaling, max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, 
prefix=f"{prefix}.self_attn", ) self.mlp = HairuoMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata ) hidden_states = residual + hidden_states * self.mup_scale_hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states * self.mup_scale_hidden_states return hidden_states, None class HairuoModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: HairuoDecoderLayer( config=config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers" ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) return embedding * self.config.mup_scale_emb def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i - self.start_layer], attn_metadata, residual, ) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states, "residual": residual}) hidden_states = self.norm(hidden_states) return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", 
"q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], "gate_up_proj": [ "gate_proj", "up_proj", ], } # LoRA specific attributes supported_lora_modules = [ "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", "lm_head", ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.model = HairuoModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size self.lm_head = ParallelLMHead( unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=( DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: hidden_states = hidden_states / self.config.mup_scale_width logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> 
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.model = HairuoModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config
                else lora_config.lora_vocab_padding_size
            ),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        hidden_states = hidden_states / self.config.mup_scale_width
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        loader.load_weights(weights)
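

# Usage sketch (not executed here): if this file lives outside the vLLM tree,
# the architecture can typically be exposed to vLLM through its out-of-tree
# model registration API, along the lines of:
#
#     from vllm import ModelRegistry
#     from hairuo_vllm import HairuoForCausalLM  # hypothetical module path
#
#     ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
#
# after which the model can be served with the usual vLLM entry points.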