feat(hairuo): add hairuo file
This commit is contained in:
commit
f51b7719b8
67
configuration_hairuo.py
Normal file
67
configuration_hairuo.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoConfig(PretrainedConfig):
|
||||||
|
model_type = "hairuo"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
_auto_class = "AutoConfig"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=8.4e5,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
mup_scale_emb=1,
|
||||||
|
mup_scale_depth=32,
|
||||||
|
mup_scale_width=1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
|
||||||
|
self.mup_scale_emb = mup_scale_emb
|
||||||
|
self.mup_scale_depth = mup_scale_depth
|
||||||
|
self.mup_scale_width = mup_scale_width
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
151387
merges.txt
Normal file
151387
merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
177
modeling_flash_attention_utils.py
Normal file
177
modeling_flash_attention_utils.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis
|
||||||
|
from flash_attn.bert_padding import pad_input
|
||||||
|
from flash_attn.bert_padding import unpad_input
|
||||||
|
|
||||||
|
|
||||||
|
def _get_unpad_data(
|
||||||
|
attention_mask: torch.Tensor, cu_seqlens: torch.Tensor = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
||||||
|
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
|
||||||
|
else:
|
||||||
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
|
||||||
|
return indices, cu_seqlens, max_seqlen_in_batch
|
||||||
|
|
||||||
|
|
||||||
|
def _unpad_input(
|
||||||
|
query_layer: torch.Tensor,
|
||||||
|
key_layer: torch.Tensor,
|
||||||
|
value_layer: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
cu_seqlens,
|
||||||
|
):
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||||
|
value_layer = index_first_axis(
|
||||||
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
# There is a memcpy here, that is very bad.
|
||||||
|
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
attention_mask = attention_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||||
|
query = query.view(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.view(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.view(-1, value.size(-2), value.size(-1))
|
||||||
|
position_ids = position_ids.flatten()
|
||||||
|
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
cu_seq_lens = torch.cat(
|
||||||
|
(
|
||||||
|
indices_q[position_ids == 0],
|
||||||
|
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
max_length = position_ids.max() + 1
|
||||||
|
|
||||||
|
return query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_forward(
|
||||||
|
query_states: torch.Tensor,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
is_causal: bool,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
use_top_left_mask: bool = False,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
deterministic: bool = None,
|
||||||
|
cu_seqlens=None,
|
||||||
|
):
|
||||||
|
if not use_top_left_mask:
|
||||||
|
causal = is_causal
|
||||||
|
else:
|
||||||
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||||
|
causal = is_causal and query_length != 1
|
||||||
|
|
||||||
|
if deterministic is None:
|
||||||
|
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||||
|
flash_kwargs = {"deterministic": deterministic}
|
||||||
|
|
||||||
|
if softcap is not None:
|
||||||
|
flash_kwargs["softcap"] = softcap
|
||||||
|
|
||||||
|
if attention_mask is not None or cu_seqlens is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _unpad_input(
|
||||||
|
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
|
||||||
|
)
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1:
|
||||||
|
batch_size = query_states.size(0)
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||||
|
query_states, key_states, value_states, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
1008
modeling_hairuo.py
Normal file
1008
modeling_hairuo.py
Normal file
File diff suppressed because it is too large
Load Diff
241
tokenization_hairuo.py
Normal file
241
tokenization_hairuo.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
from transformers.tokenization_utils import AddedToken
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {
|
||||||
|
"vocab_file": "vocab.json",
|
||||||
|
"merges_file": "merges.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MAX_MODEL_INPUT_SIZES = {"hairuo/hairuo-tokenizer": 32768}
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def bytes_to_unicode():
|
||||||
|
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8 + n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
|
||||||
|
def get_pairs(word):
|
||||||
|
pairs = set()
|
||||||
|
prev_char = word[0]
|
||||||
|
for char in word[1:]:
|
||||||
|
pairs.add((prev_char, char))
|
||||||
|
prev_char = char
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoTokenizer(PreTrainedTokenizer):
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
_auto_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
merges_file,
|
||||||
|
errors="replace",
|
||||||
|
unk_token="<|end_of_text|>",
|
||||||
|
bos_token=None,
|
||||||
|
eos_token="<|end_of_text|>",
|
||||||
|
pad_token="<|end_of_text|>",
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
split_special_tokens=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
bos_token = (
|
||||||
|
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||||
|
if isinstance(bos_token, str)
|
||||||
|
else bos_token
|
||||||
|
)
|
||||||
|
eos_token = (
|
||||||
|
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||||
|
if isinstance(eos_token, str)
|
||||||
|
else eos_token
|
||||||
|
)
|
||||||
|
unk_token = (
|
||||||
|
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||||
|
if isinstance(unk_token, str)
|
||||||
|
else unk_token
|
||||||
|
)
|
||||||
|
pad_token = (
|
||||||
|
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||||
|
if isinstance(pad_token, str)
|
||||||
|
else pad_token
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
|
self.encoder = json.load(vocab_handle)
|
||||||
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
self.errors = errors
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||||
|
bpe_merges = []
|
||||||
|
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||||
|
for i, line in enumerate(merges_handle):
|
||||||
|
line = line.strip()
|
||||||
|
if (i == 0 and line.startswith("#version:")) or not line:
|
||||||
|
continue
|
||||||
|
bpe_merges.append(tuple(line.split()))
|
||||||
|
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
self.pat = re.compile(
|
||||||
|
r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||||
|
)
|
||||||
|
|
||||||
|
if kwargs.get("add_prefix_space", False):
|
||||||
|
logger.warning_once(
|
||||||
|
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
errors=errors,
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
split_special_tokens=split_special_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return len(self.encoder)
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
return dict(self.encoder, **self.added_tokens_encoder)
|
||||||
|
|
||||||
|
def bpe(self, token):
|
||||||
|
if token in self.cache:
|
||||||
|
return self.cache[token]
|
||||||
|
word = tuple(token)
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
return token
|
||||||
|
|
||||||
|
while True:
|
||||||
|
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
first, second = bigram
|
||||||
|
new_word = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
try:
|
||||||
|
j = word.index(first, i)
|
||||||
|
except ValueError:
|
||||||
|
new_word.extend(word[i:])
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
new_word.extend(word[i:j])
|
||||||
|
i = j
|
||||||
|
|
||||||
|
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||||
|
new_word.append(first + second)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
new_word = tuple(new_word)
|
||||||
|
word = new_word
|
||||||
|
if len(word) == 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
word = " ".join(word)
|
||||||
|
self.cache[token] = word
|
||||||
|
return word
|
||||||
|
|
||||||
|
def _tokenize(self, text, **kwargs):
|
||||||
|
bpe_tokens = []
|
||||||
|
for token in self.pat.findall(text):
|
||||||
|
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||||
|
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token):
|
||||||
|
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
return self.decoder.get(index)
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
text = "".join(tokens)
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
token_ids,
|
||||||
|
skip_special_tokens: bool = False,
|
||||||
|
clean_up_tokenization_spaces: Optional[bool] = False,
|
||||||
|
spaces_between_special_tokens: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
return super().decode(
|
||||||
|
token_ids,
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||||
|
return
|
||||||
|
|
||||||
|
vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
merge_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||||
|
writer.write("#version: 0.2\n")
|
||||||
|
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||||
|
if index != token_index:
|
||||||
|
logger.warning(
|
||||||
|
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
||||||
|
" Please check that the tokenizer is not corrupted!"
|
||||||
|
)
|
||||||
|
index = token_index
|
||||||
|
writer.write(" ".join(bpe_tokens) + "\n")
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
return vocab_file, merge_file
|
||||||
|
|
||||||
|
def prepare_for_tokenization(self, text, **kwargs):
|
||||||
|
text = unicodedata.normalize("NFC", text)
|
||||||
|
return text, kwargs
|
14
vllm_hairuo.py
Normal file
14
vllm_hairuo.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from vllm import ModelRegistry
|
||||||
|
from vllm_modeling_hairuo import HairuoForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import runpy
|
||||||
|
|
||||||
|
runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__")
|
449
vllm_modeling_hairuo.py
Normal file
449
vllm_modeling_hairuo.py
Normal file
@ -0,0 +1,449 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.attention import Attention
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_pp_group
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.model_executor import SamplingMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import MergedColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||||
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
|
||||||
|
from vllm.model_executor.models import SupportsLoRA
|
||||||
|
from vllm.model_executor.models import SupportsPP
|
||||||
|
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||||
|
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||||
|
from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
|
||||||
|
from vllm.model_executor.models.utils import make_layers
|
||||||
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoMLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
intermediate_size, hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.")
|
||||||
|
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
total_num_kv_heads=self.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
input_size=self.total_num_heads * self.head_dim,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
|
o_dtype = q.dtype
|
||||||
|
q, k = self.rotary_emb(positions, q.float(), k.float())
|
||||||
|
q, k = q.to(o_dtype), k.to(o_dtype)
|
||||||
|
|
||||||
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoDecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.rope_theta = getattr(config, "rope_theta", 10000.0)
|
||||||
|
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
if self.rope_scaling is not None and getattr(config, "original_max_position_embeddings", None):
|
||||||
|
self.rope_scaling["original_max_position_embeddings"] = config.original_max_position_embeddings
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
|
|
||||||
|
self.self_attn = HairuoAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
|
||||||
|
rope_theta=self.rope_theta,
|
||||||
|
rope_scaling=self.rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp = HairuoMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||||
|
|
||||||
|
return hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoModel(nn.Module):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
prefix=f"{prefix}.embed_tokens",
|
||||||
|
)
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda prefix: HairuoDecoderLayer(
|
||||||
|
config=config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers"
|
||||||
|
),
|
||||||
|
prefix=f"{prefix}.layers",
|
||||||
|
)
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
embedding = self.embed_tokens(input_ids)
|
||||||
|
return embedding * self.config.mup_scale_emb
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
|
residual = None
|
||||||
|
else:
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
for i in range(self.start_layer, self.end_layer):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i - self.start_layer],
|
||||||
|
attn_metadata,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_up_proj",
|
||||||
|
"down_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.model = HairuoModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
|
unpadded_vocab_size = config.vocab_size
|
||||||
|
if lora_config:
|
||||||
|
unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=(
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config
|
||||||
|
else lora_config.lora_vocab_padding_size
|
||||||
|
),
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
hidden_states = hidden_states / self.config.mup_scale_width
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
loader.load_weights(weights)
|
1
vocab.json
Normal file
1
vocab.json
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user