feat: first submit
This commit is contained in:
parent
cb5abad7f8
commit
c644085c99
8
ihp/__init__.py
Normal file
8
ihp/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
24
ihp/config/data/cfg.hairuo.2b.efficiency.json
Normal file
24
ihp/config/data/cfg.hairuo.2b.efficiency.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"wudao": {
|
||||
"group": "wudao",
|
||||
"name": "wudao",
|
||||
"epoch": 1,
|
||||
"path": "wudao",
|
||||
"strategy": {
|
||||
"st_segment": "naive",
|
||||
"st_tokenize": "legacy"
|
||||
},
|
||||
"weight": 0.5
|
||||
},
|
||||
"zwjcylk": {
|
||||
"group": "zwjcylk",
|
||||
"name": "zwjcylk",
|
||||
"epoch": 1,
|
||||
"path": "zwjcylk",
|
||||
"strategy": {
|
||||
"st_segment": "naive",
|
||||
"st_tokenize": "legacy"
|
||||
},
|
||||
"weight": 0.5
|
||||
}
|
||||
}
|
24
ihp/config/data/cfg.hairuo.2b.stable.json
Normal file
24
ihp/config/data/cfg.hairuo.2b.stable.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"wudao": {
|
||||
"group": "wudao",
|
||||
"name": "wudao",
|
||||
"epoch": 1,
|
||||
"path": "wudao",
|
||||
"strategy": {
|
||||
"st_segment": "naive",
|
||||
"st_tokenize": "legacy"
|
||||
},
|
||||
"weight": 0.5
|
||||
},
|
||||
"zwjcylk": {
|
||||
"group": "zwjcylk",
|
||||
"name": "zwjcylk",
|
||||
"epoch": 1,
|
||||
"path": "zwjcylk",
|
||||
"strategy": {
|
||||
"st_segment": "naive",
|
||||
"st_tokenize": "legacy"
|
||||
},
|
||||
"weight": 0.5
|
||||
}
|
||||
}
|
37
ihp/config/model/cfg.hairuo.2b.json
Normal file
37
ihp/config/model/cfg.hairuo.2b.json
Normal file
@ -0,0 +1,37 @@
|
||||
{
|
||||
"config_clz": "hairuo.HairuoConfig",
|
||||
"config": {
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2304,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 5760,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "hairuo",
|
||||
"num_attention_heads": 36,
|
||||
"num_hidden_layers": 40,
|
||||
"num_key_value_heads": 36,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"original_max_position_embeddings": 8192,
|
||||
"rope_type": "llama3"
|
||||
},
|
||||
"rope_theta": 500000.0,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"use_cache": true,
|
||||
"vocab_size": 152064,
|
||||
"_attn_implementation": "eager",
|
||||
"_flash_attn_2_enabled": false,
|
||||
"mup_scale_emb": 12,
|
||||
"mup_scale_depth": 1.4,
|
||||
"mup_scale_width": 9.0
|
||||
},
|
||||
"model_clz": "hairuo.HairuoForCausalLM",
|
||||
"tokenizer_clz": "hairuo.HairuoTokenizer"
|
||||
}
|
147
ihp/model/__init__.py
Normal file
147
ihp/model/__init__.py
Normal file
@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/21
|
||||
#
|
||||
|
||||
import json
|
||||
import math
|
||||
import tempfile
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ihp.optim import get_lr_scheduler
|
||||
from ihp.optim import get_optimizer
|
||||
from ihp.util import importer
|
||||
from ihp.util.booster import get_booster
|
||||
from ihp.util.io import save_json
|
||||
from ihp.util.logger import logger
|
||||
from ihp.util.metric import format_numel
|
||||
from ihp.util.metric import get_model_numel
|
||||
|
||||
|
||||
def create_and_load_model(
|
||||
storage, coordinator, args, config, with_optimizer, with_scheduler, extra_config=None, extra_info=""
|
||||
):
|
||||
rank = coordinator.rank
|
||||
logger.info(f"rank-{rank} -> init {extra_info} model")
|
||||
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
if args.use_lazy_ctx:
|
||||
init_ctx = LazyInitContext(default_device=get_current_device())
|
||||
elif args.use_skip_ctx:
|
||||
init_ctx = skip_init()
|
||||
else:
|
||||
init_ctx = nullcontext()
|
||||
|
||||
if coordinator.is_master():
|
||||
logger.info(f"init ctx: {init_ctx}")
|
||||
|
||||
with init_ctx:
|
||||
logger.info(f"rank-{rank} -> [start]init config({args.config_clz}) from {config}")
|
||||
config_clz = importer.from_name_to_clz(args.config_clz)
|
||||
if isinstance(config, dict):
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
save_json(config, tmp.name)
|
||||
model_config = config_clz.from_pretrained(tmp.name, trust_remote_code=True)
|
||||
else:
|
||||
model_config = config_clz.from_pretrained(config, trust_remote_code=True)
|
||||
|
||||
if extra_config is None:
|
||||
extra_config = {}
|
||||
|
||||
if hasattr(model_config, "_flash_attn_2_enabled"):
|
||||
extra_config["_flash_attn_2_enabled"] = args.use_flash_attn
|
||||
if hasattr(model_config, "_attn_implementation"):
|
||||
extra_config["_attn_implementation"] = "flash_attention_2" if args.use_flash_attn else "eager"
|
||||
if hasattr(model_config, "use_cache"):
|
||||
extra_config["use_cache"] = not args.use_flash_attn
|
||||
if hasattr(model_config, "output_router_logits"):
|
||||
extra_config["output_router_logits"] = True
|
||||
if hasattr(model_config, "router_aux_loss_coef") and args.router_aux_loss_coef > 0.0:
|
||||
extra_config["router_aux_loss_coef"] = args.router_aux_loss_coef
|
||||
if hasattr(model_config, "initializer_range"):
|
||||
extra_config["initializer_range"] = args.init_std
|
||||
|
||||
if not args.use_mup:
|
||||
args.mup_scale_emb = 1.0
|
||||
args.mup_scale_depth = math.sqrt(model_config.num_hidden_layers)
|
||||
args.mup_scale_width = 1.0
|
||||
else:
|
||||
if hasattr(model_config, "mup_scale_emb"):
|
||||
args.mup_scale_emb = args.mup_scale_emb or model_config.mup_scale_emb
|
||||
if hasattr(model_config, "mup_scale_depth"):
|
||||
args.mup_scale_depth = args.mup_scale_depth or model_config.mup_scale_depth
|
||||
if hasattr(model_config, "mup_scale_width"):
|
||||
args.mup_scale_width = args.mup_scale_width or model_config.mup_scale_width
|
||||
|
||||
extra_config["mup_scale_emb"] = args.mup_scale_emb
|
||||
extra_config["mup_scale_depth"] = args.mup_scale_depth
|
||||
extra_config["mup_scale_width"] = args.mup_scale_width
|
||||
|
||||
if isinstance(extra_config, dict):
|
||||
for k, v in extra_config.items():
|
||||
if hasattr(model_config, k):
|
||||
model_config.__setattr__(k, v)
|
||||
|
||||
logger.info(f"rank-{rank} -> [start]init model({args.model_clz}) with {model_config}")
|
||||
model_clz = importer.from_name_to_clz(args.model_clz)
|
||||
if hasattr(model_clz, "from_config"):
|
||||
model = model_clz.from_config(
|
||||
model_config, use_flash_attention_2=args.use_flash_attn, trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
model = model_clz(model_config)
|
||||
|
||||
if args.grad_checkpointing:
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": args.use_reentrant})
|
||||
|
||||
architecture = model.__class__.__name__
|
||||
model_numel, non_embed_numel, trainable_numel = get_model_numel(model)
|
||||
if coordinator.is_master():
|
||||
args_dict = {key: value for key, value in args.__dict__.items() if key != "dataset"}
|
||||
logger.info(
|
||||
f"{extra_info} model specs architecture: {architecture}"
|
||||
f", parameters full: {format_numel(model_numel)}"
|
||||
f", non-embed: {format_numel(non_embed_numel)}"
|
||||
f", trainable: {format_numel(trainable_numel)}"
|
||||
f", args: {json.dumps(args_dict, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
optimizer = get_optimizer(args, model) if with_optimizer else None
|
||||
lr_scheduler = get_lr_scheduler(args, optimizer) if with_scheduler else None
|
||||
booster = get_booster(args)
|
||||
|
||||
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
|
||||
|
||||
st_step, extra_states = 1, {}
|
||||
|
||||
if args.load:
|
||||
states = storage.load(
|
||||
booster,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
args.load,
|
||||
coordinator.rank,
|
||||
args.reset_states,
|
||||
not args.not_use_strict,
|
||||
)
|
||||
st_step, extra_states = states
|
||||
|
||||
# fmt: off
|
||||
return (
|
||||
booster, model, optimizer, lr_scheduler,
|
||||
architecture, model_config, model_numel,
|
||||
st_step, extra_states
|
||||
)
|
||||
# fmt: on
|
22
ihp/model/loss.py
Normal file
22
ihp/model/loss.py
Normal file
@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/21
|
||||
#
|
||||
|
||||
import torch
|
||||
from hakkero.dataset import IGNORE_INDEX
|
||||
|
||||
|
||||
def lm_cross_entropy(logits, labels, reduction="mean"):
|
||||
# we do not do the stupid shifting as in huggingface since we shift it in dataset
|
||||
logits = logits.view(-1, logits.shape[-1])
|
||||
labels = labels.view(-1)
|
||||
return torch.nn.functional.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX, reduction=reduction)
|
||||
|
||||
|
||||
def lm_z_loss(logits):
|
||||
return logits.max(-1).values().square().mean()
|
8
ihp/zoo/__init__.py
Normal file
8
ihp/zoo/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
15
ihp/zoo/hairuo/__init__.py
Normal file
15
ihp/zoo/hairuo/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
||||
|
||||
from ihp.zoo.hairuo.configuration_hairuo import HairuoConfig
|
||||
from ihp.zoo.hairuo.modeling_hairuo import HairuoForCausalLM
|
||||
from ihp.zoo.hairuo.modeling_hairuo import HairuoForSequenceClassification
|
||||
from ihp.zoo.hairuo.modeling_hairuo import HairuoForTokenClassification
|
||||
from ihp.zoo.hairuo.modeling_hairuo import HairuoModel
|
||||
from ihp.zoo.hairuo.tokenization_hairuo import HairuoTokenizer
|
73
ihp/zoo/hairuo/configuration_hairuo.py
Normal file
73
ihp/zoo/hairuo/configuration_hairuo.py
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class HairuoConfig(PretrainedConfig):
|
||||
model_type = "hairuo"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
_auto_class = "AutoConfig"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096 * 32,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=8.4e5,
|
||||
rope_scaling=None,
|
||||
attention_dropout=0.0,
|
||||
mup_scale_emb=1,
|
||||
mup_scale_depth=32,
|
||||
mup_scale_width=1,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.mup_scale_emb = mup_scale_emb
|
||||
self.mup_scale_depth = mup_scale_depth
|
||||
self.mup_scale_width = mup_scale_width
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
183
ihp/zoo/hairuo/modeling_flash_attention_utils.py
Normal file
183
ihp/zoo/hairuo/modeling_flash_attention_utils.py
Normal file
@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/22
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis
|
||||
from flash_attn.bert_padding import pad_input
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
|
||||
|
||||
def _get_unpad_data(
|
||||
attention_mask: torch.Tensor, cu_seqlens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
if cu_seqlens is not None:
|
||||
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
||||
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
|
||||
else:
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
|
||||
return indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
|
||||
def _unpad_input(
|
||||
query_layer: torch.Tensor,
|
||||
key_layer: torch.Tensor,
|
||||
value_layer: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
cu_seqlens,
|
||||
):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
# There is a memcpy here, that is very bad.
|
||||
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||
query = query.view(-1, query.size(-2), query.size(-1))
|
||||
key = key.view(-1, key.size(-2), key.size(-1))
|
||||
value = value.view(-1, value.size(-2), value.size(-1))
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||
)
|
||||
)
|
||||
|
||||
max_length = position_ids.max() + 1
|
||||
|
||||
return query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: Optional[float] = None,
|
||||
deterministic: bool = None,
|
||||
cu_seqlens=None,
|
||||
):
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
|
||||
# For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||
causal = is_causal and query_length != 1
|
||||
|
||||
if deterministic is None:
|
||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs = {"deterministic": deterministic}
|
||||
|
||||
if softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
|
||||
if attention_mask is not None or cu_seqlens is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _unpad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1:
|
||||
batch_size = query_states.size(0)
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||
)
|
||||
|
||||
return attn_output
|
1039
ihp/zoo/hairuo/modeling_hairuo.py
Normal file
1039
ihp/zoo/hairuo/modeling_hairuo.py
Normal file
File diff suppressed because it is too large
Load Diff
247
ihp/zoo/hairuo/tokenization_hairuo.py
Normal file
247
ihp/zoo/hairuo/tokenization_hairuo.py
Normal file
@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/08/08
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import regex as re
|
||||
from transformers.tokenization_utils import AddedToken
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
|
||||
|
||||
MAX_MODEL_INPUT_SIZES = {"hairuo/hairuo-tokenizer": 32768}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class HairuoTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
_auto_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
errors="replace",
|
||||
unk_token="<|end_of_text|>",
|
||||
bos_token=None,
|
||||
eos_token="<|end_of_text|>",
|
||||
pad_token="<|end_of_text|>",
|
||||
clean_up_tokenization_spaces=False,
|
||||
split_special_tokens=False,
|
||||
**kwargs,
|
||||
):
|
||||
bos_token = (
|
||||
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(bos_token, str)
|
||||
else bos_token
|
||||
)
|
||||
eos_token = (
|
||||
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(eos_token, str)
|
||||
else eos_token
|
||||
)
|
||||
unk_token = (
|
||||
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(unk_token, str)
|
||||
else unk_token
|
||||
)
|
||||
pad_token = (
|
||||
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(pad_token, str)
|
||||
else pad_token
|
||||
)
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
bpe_merges = []
|
||||
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||
for i, line in enumerate(merges_handle):
|
||||
line = line.strip()
|
||||
if (i == 0 and line.startswith("#version:")) or not line:
|
||||
continue
|
||||
bpe_merges.append(tuple(line.split()))
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
self.pat = re.compile(
|
||||
r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
if kwargs.get("add_prefix_space", False):
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
errors=errors,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
unk_token=unk_token,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
split_special_tokens=split_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.encoder)
|
||||
|
||||
def get_vocab(self):
|
||||
return dict(self.encoder, **self.added_tokens_encoder)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
bpe_tokens = []
|
||||
for token in self.pat.findall(text):
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
return self.decoder.get(index)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = False,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
return super().decode(
|
||||
token_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
merge_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
||||
)
|
||||
|
||||
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write("#version: 0.2\n")
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(" ".join(bpe_tokens) + "\n")
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
def prepare_for_tokenization(self, text, **kwargs):
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
return text, kwargs
|
609
ihp/zoo/hairuo/vllm_hairuo.py
Normal file
609
ihp/zoo/hairuo/vllm_hairuo.py
Normal file
@ -0,0 +1,609 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import MSELoss
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
from .configuration_hairuo import HairuoConfig
|
||||
|
||||
|
||||
class HairuoMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
class HairuoAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
is_neox_style = True
|
||||
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# orig_dtype = q.dtype
|
||||
# q, k = q.float(), k.float()
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
# q, k = q.to(orig_dtype), k.to(orig_dtype)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class HairuoDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
self.self_attn = HairuoAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = HairuoMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states= self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class HairuoModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: HairuoDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
# return self.embed_tokens(input_ids) * self.config.mup_scale_emb
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
scaling_factor *= 2
|
||||
if hasattr(layer_self_attn, "kv_scale"):
|
||||
layer_self_attn.attn._kv_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError("Self attention has no KV cache scaling "
|
||||
"factor attribute!")
|
||||
|
||||
|
||||
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
|
||||
"lm_head"
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings"
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
mistral_mapping = {
|
||||
"layers": "model.layers",
|
||||
"attention": "self_attn",
|
||||
"wq": "q_proj",
|
||||
"wk": "k_proj",
|
||||
"wv": "v_proj",
|
||||
"wo": "o_proj",
|
||||
"attention_norm": "input_layernorm",
|
||||
"feed_forward": "mlp",
|
||||
"w1": "gate_proj",
|
||||
"w2": "down_proj",
|
||||
"w3": "up_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"tok_embeddings": "model.embed_tokens",
|
||||
"output": "lm_head",
|
||||
"norm": "model.norm"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = HairuoModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config = lora_config,
|
||||
prefix="model")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else
|
||||
lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.embed_tokens)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
self.sampler = Sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = hidden_states / self.config.mup_scale_width
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(
|
||||
self.maybe_remap_mistral(name, loaded_weight)
|
||||
for name, loaded_weight in weights)
|
||||
|
||||
|
||||
def maybe_remap_mistral(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = self.config.hidden_size
|
||||
|
||||
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
||||
|
||||
mapping = self.mistral_mapping
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if "wk" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads)
|
||||
elif "wq" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads)
|
||||
|
||||
for item in modules:
|
||||
if item in mapping and mapping[item] not in name:
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
669
ihp/zoo/hairuo/vllm_hairuo_v2.py
Normal file
669
ihp/zoo/hairuo/vllm_hairuo_v2.py
Normal file
@ -0,0 +1,669 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: sunxian <sunxian@inspur.com>
|
||||
# @date: 2024/07/18
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import MSELoss
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
from .configuration_hairuo import HairuoConfig
|
||||
|
||||
|
||||
class HairuoMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
class HairuoAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
is_neox_style = True
|
||||
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
# orig_dtype = q.dtype
|
||||
# q, k = q.float(), k.float()
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
# q, k = q.to(orig_dtype), k.to(orig_dtype)
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class HairuoDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
self.self_attn = HairuoAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = HairuoMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states= self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states * self.mup_scale_hidden_states
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class HairuoModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: HairuoDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.config.mup_scale_emb
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
"""
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
scaling_factor *= 2
|
||||
if hasattr(layer_self_attn, "kv_scale"):
|
||||
layer_self_attn.attn._kv_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError("Self attention has no KV cache scaling "
|
||||
"factor attribute!")
|
||||
"""
|
||||
|
||||
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
|
||||
"lm_head"
|
||||
]
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings"
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
# in TP, these weights are partitioned along the column dimension (dim=-1)
|
||||
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
mistral_mapping = {
|
||||
"layers": "model.layers",
|
||||
"attention": "self_attn",
|
||||
"wq": "q_proj",
|
||||
"wk": "k_proj",
|
||||
"wv": "v_proj",
|
||||
"wo": "o_proj",
|
||||
"attention_norm": "input_layernorm",
|
||||
"feed_forward": "mlp",
|
||||
"w1": "gate_proj",
|
||||
"w2": "down_proj",
|
||||
"w3": "up_proj",
|
||||
"ffn_norm": "post_attention_layernorm",
|
||||
"tok_embeddings": "model.embed_tokens",
|
||||
"output": "lm_head",
|
||||
"norm": "model.norm"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HairuoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = HairuoModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config = lora_config,
|
||||
prefix="model")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else
|
||||
lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.embed_tokens)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
self.sampler = Sampler()
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = hidden_states / self.config.mup_scale_width
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(
|
||||
self.maybe_remap_mistral(name, loaded_weight)
|
||||
for name, loaded_weight in weights)
|
||||
"""
|
||||
|
||||
def maybe_remap_mistral(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = self.config.hidden_size
|
||||
|
||||
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
||||
|
||||
mapping = self.mistral_mapping
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if "wk" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads)
|
||||
elif "wq" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads)
|
||||
|
||||
for item in modules:
|
||||
if item in mapping and mapping[item] not in name:
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
10
ihp/zoo/llama/__init__.py
Normal file
10
ihp/zoo/llama/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||
# @date: 2024/10/10
|
||||
#
|
||||
|
||||
from ihp.zoo.llama.modeling_llama import LlamaForCausalLM
|
704
ihp/zoo/llama/modeling_llama.py
Normal file
704
ihp/zoo/llama/modeling_llama.py
Normal file
@ -0,0 +1,704 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||
# @date: 2024/10/10
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.cache_utils import StaticCache
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.utils import add_start_docstrings_to_model_forward
|
||||
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||
from transformers.utils import logging
|
||||
|
||||
from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||
|
||||
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
|
||||
# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LlamaFlashAttention2(LlamaAttention):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
LLAMA_ATTENTION_CLASSES = {
|
||||
"eager": LlamaAttention,
|
||||
"flash_attention_2": LlamaFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = LlamaMLP(config)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
cu_seqlens=cu_seqlens,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class LlamaModel(LlamaPreTrainedModel):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
cu_seqlens,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LlamaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
197
ihp/zoo/modeling_flash_attention_utils.py
Normal file
197
ihp/zoo/modeling_flash_attention_utils.py
Normal file
@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||
# @date: 2024/10/08
|
||||
#
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
from transformers.utils import is_flash_attn_greater_or_equal
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis # noqa
|
||||
from flash_attn.bert_padding import pad_input
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
|
||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
|
||||
def _get_unpad_data(
|
||||
attention_mask: torch.Tensor, cu_seqlens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
if cu_seqlens is not None:
|
||||
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
||||
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
|
||||
else:
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (indices, cu_seqlens, max_seqlen_in_batch)
|
||||
|
||||
|
||||
def _upad_input(
|
||||
query_layer: torch.Tensor,
|
||||
key_layer: torch.Tensor,
|
||||
value_layer: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
cu_seqlens,
|
||||
):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||
query = query.view(-1, query.size(-2), query.size(-1))
|
||||
key = key.view(-1, key.size(-2), key.size(-1))
|
||||
value = value.view(-1, value.size(-2), value.size(-1))
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(indices_q[position_ids == 0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32))
|
||||
)
|
||||
|
||||
max_length = position_ids.max() + 1
|
||||
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: Optional[float] = None,
|
||||
deterministic: bool = None,
|
||||
cu_seqlens=None,
|
||||
):
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||
causal = is_causal and query_length != 1
|
||||
|
||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
||||
)
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
if deterministic is None:
|
||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs["deterministic"] = deterministic
|
||||
|
||||
if softcap is not None:
|
||||
flash_kwargs["softcap"] = softcap
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None or cu_seqlens is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
|
||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
|
||||
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
||||
batch_size = query_states.size(0)
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||
)
|
||||
|
||||
return attn_output
|
10
ihp/zoo/qwen/__init__.py
Normal file
10
ihp/zoo/qwen/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||
# @date: 2024/10/08
|
||||
#
|
||||
|
||||
from ihp.zoo.qwen.modeling_qwen2 import Qwen2ForCausalLM
|
698
ihp/zoo/qwen/modeling_qwen2.py
Normal file
698
ihp/zoo/qwen/modeling_qwen2.py
Normal file
@ -0,0 +1,698 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright @2024 AI. Inspur Inc.
|
||||
#
|
||||
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||
# @date: 2024/10/08
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import Cache
|
||||
from transformers import DynamicCache
|
||||
from transformers import StaticCache
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.qwen2 import Qwen2Config
|
||||
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding
|
||||
from transformers.models.qwen2.modeling_qwen2 import repeat_kv
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
||||
_CONFIG_FOR_DOC = "Qwen2Config"
|
||||
|
||||
|
||||
class Qwen2Attention(nn.Module):
|
||||
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2FlashAttention2(Qwen2Attention):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
kv_seq_len = key_states.shape[-2] + cache_position[0]
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
and cache_has_contents
|
||||
):
|
||||
slicing_tokens = 1 - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[self.layer_idx][0]
|
||||
past_value = past_key_value[self.layer_idx][1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, slicing_tokens:]
|
||||
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
QWEN2_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2Attention,
|
||||
"flash_attention_2": Qwen2FlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: Qwen2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Qwen2Model(Qwen2PreTrainedModel):
|
||||
def __init__(self, config: Qwen2Config):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
cu_seqlens,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Qwen2Model(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
55
test_hairuo.py
Normal file
55
test_hairuo.py
Normal file
@ -0,0 +1,55 @@
|
||||
prompt = "你好,我的名字是"
|
||||
model_path = './model_ckpt/hairuo'
|
||||
# run_type = 'vllm'
|
||||
run_type = 'transformers'
|
||||
|
||||
if run_type == 'transformers':
|
||||
from ihp.zoo.hairuo import HairuoTokenizer
|
||||
from ihp.zoo.hairuo import HairuoForCausalLM
|
||||
|
||||
model = HairuoForCausalLM.from_pretrained(model_path)
|
||||
tokenizer = HairuoTokenizer.from_pretrained(model_path)
|
||||
|
||||
model.requires_grad_(False)
|
||||
model.eval()
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
generate_ids = model.generate(inputs.input_ids, attention_mask = inputs.attention_mask, max_length=200, temperature=0.8, do_sample=True, eos_token_id=151644, pad_token_id=151644)
|
||||
generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
print(generated_text)
|
||||
|
||||
|
||||
if run_type == 'vllm':
|
||||
|
||||
# 载入 LLM 和 SamplingParams
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import ModelRegistry
|
||||
from ihp.zoo.hairuo.vllm_hairuo import HairuoForCausalLM
|
||||
ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
|
||||
# 推理数据以List[str]格式组织
|
||||
prompts = [
|
||||
"你好,我的名字是",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"AI的未来是什么?",
|
||||
]
|
||||
# 设置采样参数
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=1)
|
||||
# 加载模型
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=1,
|
||||
# dtype='float32',
|
||||
gpu_memory_utilization=0.95,
|
||||
max_model_len=100,
|
||||
enforce_eager=True,
|
||||
)
|
||||
# 执行推理
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# 输出推理结果
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
Loading…
Reference in New Issue
Block a user