feat(hairuo): add hairuo file

This commit is contained in:
jiangzhs 2025-01-14 17:47:08 +08:00
commit f51b7719b8
8 changed files with 153344 additions and 0 deletions

67
configuration_hairuo.py Normal file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from transformers.configuration_utils import PretrainedConfig
class HairuoConfig(PretrainedConfig):
model_type = "hairuo"
keys_to_ignore_at_inference = ["past_key_values"]
_auto_class = "AutoConfig"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=8.4e5,
rope_scaling=None,
attention_dropout=0.0,
mup_scale_emb=1,
mup_scale_depth=32,
mup_scale_width=1,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
self.mup_scale_emb = mup_scale_emb
self.mup_scale_depth = mup_scale_depth
self.mup_scale_width = mup_scale_width
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

151387
merges.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,177 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
from flash_attn import flash_attn_func
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
def _get_unpad_data(
attention_mask: torch.Tensor, cu_seqlens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, int]:
if cu_seqlens is not None:
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
else:
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
def _unpad_input(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
cu_seqlens,
):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
# There is a memcpy here, that is very bad.
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def prepare_fa2_from_position_ids(query, key, value, position_ids):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
)
)
max_length = position_ids.max() + 1
return query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
cu_seqlens=None,
):
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
# For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
causal = is_causal and query_length != 1
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs = {"deterministic": deterministic}
if softcap is not None:
flash_kwargs["softcap"] = softcap
if attention_mask is not None or cu_seqlens is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _unpad_input(
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1:
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
return attn_output

1008
modeling_hairuo.py Normal file

File diff suppressed because it is too large Load Diff

241
tokenization_hairuo.py Normal file
View File

@ -0,0 +1,241 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import os
import unicodedata
from functools import lru_cache
from typing import Optional
from typing import Tuple
import regex as re
from transformers.tokenization_utils import AddedToken
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
MAX_MODEL_INPUT_SIZES = {"hairuo/hairuo-tokenizer": 32768}
@lru_cache()
def bytes_to_unicode():
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class HairuoTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
_auto_class = "AutoTokenizer"
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|end_of_text|>",
bos_token=None,
eos_token="<|end_of_text|>",
pad_token="<|end_of_text|>",
clean_up_tokenization_spaces=False,
split_special_tokens=False,
**kwargs,
):
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(pad_token, str)
else pad_token
)
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_merges = []
with open(merges_file, encoding="utf-8") as merges_handle:
for i, line in enumerate(merges_handle):
line = line.strip()
if (i == 0 and line.startswith("#version:")) or not line:
continue
bpe_merges.append(tuple(line.split()))
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.pat = re.compile(
r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)
if kwargs.get("add_prefix_space", False):
logger.warning_once(
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
)
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
unk_token=unk_token,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
split_special_tokens=split_special_tokens,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self.encoder)
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def _tokenize(self, text, **kwargs):
bpe_tokens = []
for token in self.pat.findall(text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def _convert_token_to_id(self, token):
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
return super().decode(
token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
text = unicodedata.normalize("NFC", text)
return text, kwargs

14
vllm_hairuo.py Normal file
View File

@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from vllm import ModelRegistry
from vllm_modeling_hairuo import HairuoForCausalLM
ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
if __name__ == "__main__":
import runpy
runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__")

449
vllm_modeling_hairuo.py Normal file
View File

@ -0,0 +1,449 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
from vllm.model_executor.models import SupportsLoRA
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
from vllm.model_executor.models.utils import make_layers
from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
class HairuoMLP(nn.Module):
def __init__(
self,
hidden_size,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj"
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj"
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class HairuoAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
o_dtype = q.dtype
q, k = self.rotary_emb(positions, q.float(), k.float())
q, k = q.to(o_dtype), k.to(o_dtype)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class HairuoDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.cache_config = cache_config
self.quant_config = quant_config
self.hidden_size = config.hidden_size
self.rope_theta = getattr(config, "rope_theta", 10000.0)
self.rope_scaling = getattr(config, "rope_scaling", None)
if self.rope_scaling is not None and getattr(config, "original_max_position_embeddings", None):
self.rope_scaling["original_max_position_embeddings"] = config.original_max_position_embeddings
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = HairuoAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
rope_theta=self.rope_theta,
rope_scaling=self.rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = HairuoMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata
)
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
return hidden_states, None
class HairuoModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: HairuoDecoderLayer(
config=config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers"
),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
return embedding * self.config.mup_scale_emb
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.model = HairuoModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.config.mup_scale_width
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
loader.load_weights(weights)

1
vocab.json Normal file

File diff suppressed because one or more lines are too long