feat: first submit

parent cb5abad7f8
commit c644085c99
ihp/__init__.py  (new file, 8 lines)
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#
|
ihp/config/data/cfg.hairuo.2b.efficiency.json  (new file, 24 lines)
@@ -0,0 +1,24 @@
{
    "wudao": {
        "group": "wudao",
        "name": "wudao",
        "epoch": 1,
        "path": "wudao",
        "strategy": {
            "st_segment": "naive",
            "st_tokenize": "legacy"
        },
        "weight": 0.5
    },
    "zwjcylk": {
        "group": "zwjcylk",
        "name": "zwjcylk",
        "epoch": 1,
        "path": "zwjcylk",
        "strategy": {
            "st_segment": "naive",
            "st_tokenize": "legacy"
        },
        "weight": 0.5
    }
}
ihp/config/data/cfg.hairuo.2b.stable.json  (new file, 24 lines)
@@ -0,0 +1,24 @@
{
    "wudao": {
        "group": "wudao",
        "name": "wudao",
        "epoch": 1,
        "path": "wudao",
        "strategy": {
            "st_segment": "naive",
            "st_tokenize": "legacy"
        },
        "weight": 0.5
    },
    "zwjcylk": {
        "group": "zwjcylk",
        "name": "zwjcylk",
        "epoch": 1,
        "path": "zwjcylk",
        "strategy": {
            "st_segment": "naive",
            "st_tokenize": "legacy"
        },
        "weight": 0.5
    }
}
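Each dataset entry in these configs carries a sampling weight (0.5 + 0.5 here). A minimal sketch of how such a config could be read and its weights normalized into sampling probabilities; load_data_config is a hypothetical helper for illustration, not part of this commit:

import json

def load_data_config(path):
    # each entry: group/name/epoch/path/strategy/weight, as in the JSON above
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    total = sum(item["weight"] for item in cfg.values())
    # normalize weights so they can be used directly as sampling probabilities
    return {name: {**item, "weight": item["weight"] / total} for name, item in cfg.items()}

datasets = load_data_config("ihp/config/data/cfg.hairuo.2b.stable.json")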
ihp/config/model/cfg.hairuo.2b.json  (new file, 37 lines)
@@ -0,0 +1,37 @@
{
    "config_clz": "hairuo.HairuoConfig",
    "config": {
        "attention_dropout": 0.0,
        "bos_token_id": 151643,
        "eos_token_id": 151645,
        "hidden_act": "silu",
        "hidden_size": 2304,
        "initializer_range": 0.02,
        "intermediate_size": 5760,
        "max_position_embeddings": 131072,
        "model_type": "hairuo",
        "num_attention_heads": 36,
        "num_hidden_layers": 40,
        "num_key_value_heads": 36,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        },
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "use_cache": true,
        "vocab_size": 152064,
        "_attn_implementation": "eager",
        "_flash_attn_2_enabled": false,
        "mup_scale_emb": 12,
        "mup_scale_depth": 1.4,
        "mup_scale_width": 9.0
    },
    "model_clz": "hairuo.HairuoForCausalLM",
    "tokenizer_clz": "hairuo.HairuoTokenizer"
}
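The "config_clz" / "model_clz" / "tokenizer_clz" strings are resolved to classes at runtime (see ihp.util.importer.from_name_to_clz used in ihp/model/__init__.py below). The real importer is not shown in this commit; a rough guess at the mapping, with the "ihp.zoo" prefix assumed, purely for illustration:

import importlib

def from_name_to_clz(name):
    # e.g. "hairuo.HairuoConfig" -> class object; the package prefix is an assumption
    module_name, clz_name = name.rsplit(".", 1)
    module = importlib.import_module(f"ihp.zoo.{module_name}")
    return getattr(module, clz_name)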
ihp/model/__init__.py  (new file, 147 lines)
@@ -0,0 +1,147 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/21
#

import json
import math
import tempfile
from contextlib import nullcontext

import torch
from colossalai.lazy import LazyInitContext
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device

from ihp.optim import get_lr_scheduler
from ihp.optim import get_optimizer
from ihp.util import importer
from ihp.util.booster import get_booster
from ihp.util.io import save_json
from ihp.util.logger import logger
from ihp.util.metric import format_numel
from ihp.util.metric import get_model_numel


def create_and_load_model(
    storage, coordinator, args, config, with_optimizer, with_scheduler, extra_config=None, extra_info=""
):
    rank = coordinator.rank
    logger.info(f"rank-{rank} -> init {extra_info} model")

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)

    if args.use_lazy_ctx:
        init_ctx = LazyInitContext(default_device=get_current_device())
    elif args.use_skip_ctx:
        init_ctx = skip_init()
    else:
        init_ctx = nullcontext()

    if coordinator.is_master():
        logger.info(f"init ctx: {init_ctx}")

    with init_ctx:
        logger.info(f"rank-{rank} -> [start]init config({args.config_clz}) from {config}")
        config_clz = importer.from_name_to_clz(args.config_clz)
        if isinstance(config, dict):
            with tempfile.NamedTemporaryFile() as tmp:
                save_json(config, tmp.name)
                model_config = config_clz.from_pretrained(tmp.name, trust_remote_code=True)
        else:
            model_config = config_clz.from_pretrained(config, trust_remote_code=True)

        if extra_config is None:
            extra_config = {}

        if hasattr(model_config, "_flash_attn_2_enabled"):
            extra_config["_flash_attn_2_enabled"] = args.use_flash_attn
        if hasattr(model_config, "_attn_implementation"):
            extra_config["_attn_implementation"] = "flash_attention_2" if args.use_flash_attn else "eager"
        if hasattr(model_config, "use_cache"):
            extra_config["use_cache"] = not args.use_flash_attn
        if hasattr(model_config, "output_router_logits"):
            extra_config["output_router_logits"] = True
        if hasattr(model_config, "router_aux_loss_coef") and args.router_aux_loss_coef > 0.0:
            extra_config["router_aux_loss_coef"] = args.router_aux_loss_coef
        if hasattr(model_config, "initializer_range"):
            extra_config["initializer_range"] = args.init_std

        if not args.use_mup:
            args.mup_scale_emb = 1.0
            args.mup_scale_depth = math.sqrt(model_config.num_hidden_layers)
            args.mup_scale_width = 1.0
        else:
            if hasattr(model_config, "mup_scale_emb"):
                args.mup_scale_emb = args.mup_scale_emb or model_config.mup_scale_emb
            if hasattr(model_config, "mup_scale_depth"):
                args.mup_scale_depth = args.mup_scale_depth or model_config.mup_scale_depth
            if hasattr(model_config, "mup_scale_width"):
                args.mup_scale_width = args.mup_scale_width or model_config.mup_scale_width

        extra_config["mup_scale_emb"] = args.mup_scale_emb
        extra_config["mup_scale_depth"] = args.mup_scale_depth
        extra_config["mup_scale_width"] = args.mup_scale_width

        if isinstance(extra_config, dict):
            for k, v in extra_config.items():
                if hasattr(model_config, k):
                    model_config.__setattr__(k, v)

        logger.info(f"rank-{rank} -> [start]init model({args.model_clz}) with {model_config}")
        model_clz = importer.from_name_to_clz(args.model_clz)
        if hasattr(model_clz, "from_config"):
            model = model_clz.from_config(
                model_config, use_flash_attention_2=args.use_flash_attn, trust_remote_code=True
            )
        else:
            model = model_clz(model_config)

    if args.grad_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": args.use_reentrant})

    architecture = model.__class__.__name__
    model_numel, non_embed_numel, trainable_numel = get_model_numel(model)
    if coordinator.is_master():
        args_dict = {key: value for key, value in args.__dict__.items() if key != "dataset"}
        logger.info(
            f"{extra_info} model specs architecture: {architecture}"
            f", parameters full: {format_numel(model_numel)}"
            f", non-embed: {format_numel(non_embed_numel)}"
            f", trainable: {format_numel(trainable_numel)}"
            f", args: {json.dumps(args_dict, ensure_ascii=False)}"
        )

    optimizer = get_optimizer(args, model) if with_optimizer else None
    lr_scheduler = get_lr_scheduler(args, optimizer) if with_scheduler else None
    booster = get_booster(args)

    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    st_step, extra_states = 1, {}

    if args.load:
        states = storage.load(
            booster,
            model,
            optimizer,
            lr_scheduler,
            args.load,
            coordinator.rank,
            args.reset_states,
            not args.not_use_strict,
        )
        st_step, extra_states = states

    # fmt: off
    return (
        booster, model, optimizer, lr_scheduler,
        architecture, model_config, model_numel,
        st_step, extra_states
    )
    # fmt: on
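A hypothetical call site for create_and_load_model, with the "config" branch for a dict exercised by passing the "config" block of cfg.hairuo.2b.json; storage, coordinator, and args are assumed to come from the surrounding training script and are not defined in this commit:

model_cfg = json.load(open("ihp/config/model/cfg.hairuo.2b.json", encoding="utf-8"))
booster, model, optimizer, lr_scheduler, architecture, model_config, model_numel, st_step, extra_states = (
    create_and_load_model(
        storage, coordinator, args,
        config=model_cfg["config"],          # dict path -> serialized to a temp file, then from_pretrained
        with_optimizer=True, with_scheduler=True, extra_info="train",
    )
)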
ihp/model/loss.py  (new file, 22 lines)
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/21
#

import torch
from hakkero.dataset import IGNORE_INDEX


def lm_cross_entropy(logits, labels, reduction="mean"):
    # labels are already shifted in the dataset, so no extra shift is applied here (unlike the HuggingFace default)
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.view(-1)
    return torch.nn.functional.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX, reduction=reduction)


def lm_z_loss(logits):
    return logits.max(-1).values.square().mean()
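A minimal sanity-check sketch of the two loss helpers above; shapes and values are illustrative, and IGNORE_INDEX is the same constant imported from hakkero.dataset:

import torch

logits = torch.randn(2, 4, 10)             # batch of 2 sequences, 4 positions, 10-token vocab
labels = torch.randint(0, 10, (2, 4))
labels[0, -1] = IGNORE_INDEX               # padded position, excluded from the cross entropy
loss = lm_cross_entropy(logits, labels)
aux = lm_z_loss(logits)                    # penalizes large per-position max logits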
ihp/zoo/__init__.py  (new file, 8 lines)
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#
ihp/zoo/hairuo/__init__.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#

from ihp.zoo.hairuo.configuration_hairuo import HairuoConfig
from ihp.zoo.hairuo.modeling_hairuo import HairuoForCausalLM
from ihp.zoo.hairuo.modeling_hairuo import HairuoForSequenceClassification
from ihp.zoo.hairuo.modeling_hairuo import HairuoForTokenClassification
from ihp.zoo.hairuo.modeling_hairuo import HairuoModel
from ihp.zoo.hairuo.tokenization_hairuo import HairuoTokenizer
ihp/zoo/hairuo/configuration_hairuo.py  (new file, 73 lines)
@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#

from transformers.configuration_utils import PretrainedConfig


class HairuoConfig(PretrainedConfig):
    model_type = "hairuo"
    keys_to_ignore_at_inference = ["past_key_values"]
    _auto_class = "AutoConfig"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=8.4e5,
        rope_scaling=None,
        attention_dropout=0.0,
        mup_scale_emb=1,
        mup_scale_depth=32,
        mup_scale_width=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout

        self.mup_scale_emb = mup_scale_emb
        self.mup_scale_depth = mup_scale_depth
        self.mup_scale_width = mup_scale_width

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
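A short instantiation sketch of HairuoConfig using the values from ihp/config/model/cfg.hairuo.2b.json above (fields not listed keep their defaults):

config = HairuoConfig(
    vocab_size=152064,
    hidden_size=2304,
    intermediate_size=5760,
    num_hidden_layers=40,
    num_attention_heads=36,
    num_key_value_heads=36,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    bos_token_id=151643,
    eos_token_id=151645,
    mup_scale_emb=12,
    mup_scale_depth=1.4,
    mup_scale_width=9.0,
)
assert config.model_type == "hairuo"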
183
ihp/zoo/hairuo/modeling_flash_attention_utils.py
Normal file
183
ihp/zoo/hairuo/modeling_flash_attention_utils.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2024 AI. Inspur Inc.
|
||||||
|
#
|
||||||
|
# @author: sunxian <sunxian@inspur.com>
|
||||||
|
# @date: 2024/07/22
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis
|
||||||
|
from flash_attn.bert_padding import pad_input
|
||||||
|
from flash_attn.bert_padding import unpad_input
|
||||||
|
|
||||||
|
|
||||||
|
def _get_unpad_data(
|
||||||
|
attention_mask: torch.Tensor, cu_seqlens: torch.Tensor = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
||||||
|
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
|
||||||
|
else:
|
||||||
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
|
||||||
|
return indices, cu_seqlens, max_seqlen_in_batch
|
||||||
|
|
||||||
|
|
||||||
|
def _unpad_input(
|
||||||
|
query_layer: torch.Tensor,
|
||||||
|
key_layer: torch.Tensor,
|
||||||
|
value_layer: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
cu_seqlens,
|
||||||
|
):
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||||
|
value_layer = index_first_axis(
|
||||||
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
# There is a memcpy here, that is very bad.
|
||||||
|
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
attention_mask = attention_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||||
|
query = query.view(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.view(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.view(-1, value.size(-2), value.size(-1))
|
||||||
|
position_ids = position_ids.flatten()
|
||||||
|
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
cu_seq_lens = torch.cat(
|
||||||
|
(
|
||||||
|
indices_q[position_ids == 0],
|
||||||
|
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
max_length = position_ids.max() + 1
|
||||||
|
|
||||||
|
return query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_forward(
|
||||||
|
query_states: torch.Tensor,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
is_causal: bool,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
use_top_left_mask: bool = False,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
deterministic: bool = None,
|
||||||
|
cu_seqlens=None,
|
||||||
|
):
|
||||||
|
if not use_top_left_mask:
|
||||||
|
causal = is_causal
|
||||||
|
else:
|
||||||
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||||
|
causal = is_causal and query_length != 1
|
||||||
|
|
||||||
|
if deterministic is None:
|
||||||
|
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||||
|
flash_kwargs = {"deterministic": deterministic}
|
||||||
|
|
||||||
|
if softcap is not None:
|
||||||
|
flash_kwargs["softcap"] = softcap
|
||||||
|
|
||||||
|
if attention_mask is not None or cu_seqlens is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _unpad_input(
|
||||||
|
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
|
||||||
|
)
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1:
|
||||||
|
batch_size = query_states.size(0)
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||||
|
query_states, key_states, value_states, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
|
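A small sanity-check sketch of what _get_unpad_data produces for a right-padded attention mask (values worked out by hand from the code above; it requires no flash-attn install since only torch is used):

import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.int32)
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices     -> tensor([0, 1, 2, 4, 5])   positions of non-pad tokens in the flattened 2x4 mask
# cu_seqlens  -> tensor([0, 3, 5])         cumulative sequence lengths [3, 2] with a leading 0
# max_seqlen  -> 3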
ihp/zoo/hairuo/modeling_hairuo.py  (new file, 1039 lines)
File diff suppressed because it is too large.
ihp/zoo/hairuo/tokenization_hairuo.py  (new file, 247 lines)
@@ -0,0 +1,247 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/08/08
#

import json
import os
import unicodedata
from functools import lru_cache
from typing import Optional
from typing import Tuple

import regex as re
from transformers.tokenization_utils import AddedToken
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


MAX_MODEL_INPUT_SIZES = {"hairuo/hairuo-tokenizer": 32768}


@lru_cache()
def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class HairuoTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|end_of_text|>",
        bos_token=None,
        eos_token="<|end_of_text|>",
        pad_token="<|end_of_text|>",
        clean_up_tokenization_spaces=False,
        split_special_tokens=False,
        **kwargs,
    ):
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(pad_token, str)
            else pad_token
        )

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_merges = []
        with open(merges_file, encoding="utf-8") as merges_handle:
            for i, line in enumerate(merges_handle):
                line = line.strip()
                if (i == 0 and line.startswith("#version:")) or not line:
                    continue
                bpe_merges.append(tuple(line.split()))
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        self.pat = re.compile(
            r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
        )

        if kwargs.get("add_prefix_space", False):
            logger.warning_once(
                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
            )

        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def _tokenize(self, text, **kwargs):
        bpe_tokens = []
        for token in self.pat.findall(text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text

    def decode(
        self,
        token_ids,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        return super().decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, **kwargs):
        text = unicodedata.normalize("NFC", text)
        return text, kwargs
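A minimal usage sketch of the tokenizer above; it assumes a checkpoint directory containing vocab.json, merges.txt and a tokenizer_config.json whose auto_map points at HairuoTokenizer, and the path is illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/hairuo", trust_remote_code=True)
ids = tokenizer("hello hairuo")["input_ids"]
print(tokenizer.decode(ids))   # round-trips back to the input text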
ihp/zoo/hairuo/vllm_hairuo.py  (new file, 609 lines)
@@ -0,0 +1,609 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#

import math
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                                              make_empty_intermediate_tensors_factory, make_layers)

from .configuration_hairuo import HairuoConfig


class HairuoMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class HairuoAttention(nn.Module):
    def __init__(
        self,
        config: HairuoConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # orig_dtype = q.dtype
        # q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        # q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class HairuoDecoderLayer(nn.Module):

    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = HairuoAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = HairuoMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states * self.mup_scale_hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.mup_scale_hidden_states
        return hidden_states, residual


class HairuoModel(nn.Module):
    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: HairuoDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # return self.embed_tokens(input_ids) * self.config.mup_scale_emb
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")


class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = HairuoModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config,
                                 prefix="model")

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        hidden_states = hidden_states / self.config.mup_scale_width
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
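A minimal sketch of serving the model above as an out-of-tree vLLM architecture; it assumes a vLLM version that exposes ModelRegistry.register_model and a checkpoint directory whose config reports model_type "hairuo" (the path is illustrative):

from vllm import LLM, ModelRegistry
from ihp.zoo.hairuo.vllm_hairuo import HairuoForCausalLM

# register the custom architecture before constructing the engine
ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
llm = LLM(model="path/to/hairuo-checkpoint", trust_remote_code=True)
print(llm.generate("Hello")[0].outputs[0].text)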
669
ihp/zoo/hairuo/vllm_hairuo_v2.py
Normal file
669
ihp/zoo/hairuo/vllm_hairuo_v2.py
Normal file
@ -0,0 +1,669 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: sunxian <sunxian@inspur.com>
# @date: 2024/07/18
#

import math

from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from torch import nn
from torch.nn import BCEWithLogitsLoss
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                                              make_empty_intermediate_tensors_factory, make_layers)

from .configuration_hairuo import HairuoConfig

class HairuoMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

class HairuoAttention(nn.Module):
    def __init__(
        self,
        config: HairuoConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # orig_dtype = q.dtype
        # q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        # q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

class HairuoDecoderLayer(nn.Module):

    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mup_scale_hidden_states = config.mup_scale_depth / math.sqrt(config.num_hidden_layers)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = HairuoAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = HairuoMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states * self.mup_scale_hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * self.mup_scale_hidden_states
        return hidden_states, residual

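# Note (added annotation, not part of the original commit): HairuoDecoderLayer
# scales each residual branch by mup_scale_depth / sqrt(num_hidden_layers), a
# muP-style depth correction, instead of the plain `residual + output` used in
# stock Llama-style layers. As a rough illustration (values assumed from a
# 2B-style config with mup_scale_depth=1.4 and 40 layers):
#
#     import math
#     mup_scale_depth, num_hidden_layers = 1.4, 40
#     factor = mup_scale_depth / math.sqrt(num_hidden_layers)  # ~0.2214
#     # hidden_states = residual + sublayer_output * factor
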
class HairuoModel(nn.Module):
    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: HairuoDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids) * self.config.mup_scale_emb

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

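# Note (added annotation, not part of the original commit): besides the
# per-layer residual scaling above, the muP parametrization shows up in two
# more places in this file: get_input_embeddings() multiplies the token
# embeddings by config.mup_scale_emb, and HairuoForCausalLM.compute_logits()
# divides the final hidden states by config.mup_scale_width before the LM
# head. With values assumed from the 2B config (mup_scale_emb=12,
# mup_scale_width=9.0), that amounts to `embeds * 12` on the way in and
# `hidden / 9.0` on the way out.
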
"""
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name
|
||||||
|
or "rotary_emb.sin_cached" in name):
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||||
|
# Loading kv cache scales for compressed-tensors quantization
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = loaded_weight[0]
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
continue
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
# If this function is called, it should always initialize KV cache scale
|
||||||
|
# factors (or else raise an exception). Thus, handled exceptions should
|
||||||
|
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||||
|
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||||
|
quantization_param_path, tp_rank, tp_size,
|
||||||
|
self.config.num_hidden_layers,
|
||||||
|
self.config.__class__.model_type):
|
||||||
|
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||||
|
layer_self_attn = self.layers[layer_idx].self_attn
|
||||||
|
|
||||||
|
if is_hip():
|
||||||
|
# The scaling factor convention we are assuming is
|
||||||
|
# quantized_value * scaling_factor ~= true_value
|
||||||
|
# which is consistent with the practice of setting
|
||||||
|
# scaling_factor = tensor_amax / FPtype_max
|
||||||
|
scaling_factor *= 2
|
||||||
|
if hasattr(layer_self_attn, "kv_scale"):
|
||||||
|
layer_self_attn.attn._kv_scale = scaling_factor
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Self attention has no KV cache scaling "
|
||||||
|
"factor attribute!")
|
||||||
|
"""
|
||||||
|
|
||||||
|
class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(
        self,
        config: HairuoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = HairuoModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config,
                                 prefix="model")

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        hidden_states = hidden_states / self.config.mup_scale_width
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)
        """

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight

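# Usage sketch (added annotation, not part of the original commit): this file
# defines an out-of-tree vLLM model, so it has to be registered before an
# engine can resolve the Hairuo architecture. A minimal sketch, assuming the
# import path ihp.zoo.hairuo.vllm_hairuo_v2 used in this repo and a local
# checkpoint path:
#
#     from vllm import LLM, ModelRegistry
#     from ihp.zoo.hairuo.vllm_hairuo_v2 import HairuoForCausalLM
#
#     ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
#     llm = LLM(model="/path/to/hairuo-2b", trust_remote_code=True)
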
10 ihp/zoo/llama/__init__.py Normal file
@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: jiangzhs <jiangzhs@inspur.com>
# @date: 2024/10/10
#

from ihp.zoo.llama.modeling_llama import LlamaForCausalLM

704 ihp/zoo/llama/modeling_llama.py Normal file
@ -0,0 +1,704 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2024 AI. Inspur Inc.
#
# @author: jiangzhs <jiangzhs@inspur.com>
# @date: 2024/10/10
#

import math

from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import is_flash_attn_greater_or_equal_2_10
from transformers.utils import logging

from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"

ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)

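# Note (added annotation, not part of the original commit): compared with the
# upstream HuggingFace implementation, the attention classes below accept an
# extra `cu_seqlens` argument that is threaded through to
# _flash_attention_forward so packed batches can be attended without
# cross-sequence leakage. As a rough illustration (hypothetical lengths, not
# taken from this repo): packing three prompts of lengths 5, 7 and 4 into one
# row would use cumulative boundaries
#
#     cu_seqlens = torch.tensor([0, 5, 12, 16], dtype=torch.int32)
#
# so flash attention treats positions [0, 5), [5, 12) and [12, 16) as
# independent sequences.
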
class LlamaAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||||
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||||
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||||
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||||
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||||
|
|
||||||
|
# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
||||||
|
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||||
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaFlashAttention2(LlamaAttention):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||||
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||||
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||||
|
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_attentions = False
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
# therefore we just need to keep the original shape
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
position_ids=position_ids,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
sliding_window=getattr(self, "sliding_window", None),
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
is_causal=self.is_causal,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
LLAMA_ATTENTION_CLASSES = {
|
||||||
|
"eager": LlamaAttention,
|
||||||
|
"flash_attention_2": LlamaFlashAttention2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
|
self.mlp = LlamaMLP(config)
|
||||||
|
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaModel(LlamaPreTrainedModel):
|
||||||
|
def __init__(self, config: LlamaConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
)
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
causal_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
position_embeddings,
|
||||||
|
cu_seqlens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||||
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = LlamaModel(config)
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
num_logits_to_keep: int = 0,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||||
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
logits = torch.cat(logits, dim=-1)
|
||||||
|
else:
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
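# --- Illustrative sketch (not part of the original file): the next-token shift used in
# `forward` above, spelled out on hypothetical toy logits/labels so the indexing is easy to
# verify. Call manually; not used by the model code.
def _sketch_shifted_lm_loss():
    vocab_size, seq_len = 7, 4
    logits = torch.randn(1, seq_len, vocab_size)
    labels = torch.tensor([[1, 2, 3, 4]])
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)  # positions 0..2 predict ...
    shift_labels = labels[..., 1:].contiguous().view(-1)                  # ... tokens 1..3
    return CrossEntropyLoss()(shift_logits, shift_labels)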
|
ihp/zoo/modeling_flash_attention_utils.py (new file, +197 lines)
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2024 AI. Inspur Inc.
|
||||||
|
#
|
||||||
|
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||||
|
# @date: 2024/10/08
|
||||||
|
#
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
from transformers.utils import is_flash_attn_greater_or_equal
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
from flash_attn.bert_padding import index_first_axis # noqa
|
||||||
|
from flash_attn.bert_padding import pad_input
|
||||||
|
from flash_attn.bert_padding import unpad_input
|
||||||
|
|
||||||
|
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_unpad_data(
|
||||||
|
attention_mask: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
|
if cu_seqlens is not None:
|
||||||
|
max_seqlen_in_batch = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
||||||
|
indices = torch.arange(0, cu_seqlens[-1].item(), device=cu_seqlens.device)
|
||||||
|
else:
|
||||||
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
|
return (indices, cu_seqlens, max_seqlen_in_batch)
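# --- Illustrative sketch (not part of the original file): what `_get_unpad_data` computes for
# a hypothetical right-padded batch of two sequences with lengths 3 and 1. Call manually.
def _sketch_get_unpad_data():
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 0, 0, 0]])
    indices, cu_seqlens, max_len = _get_unpad_data(attention_mask)
    # indices    -> tensor([0, 1, 2, 4]) : flat positions of the real (non-pad) tokens
    # cu_seqlens -> tensor([0, 3, 4])    : cumulative sequence boundaries, as flash-attn expects
    # max_len    -> 3                    : longest sequence in the batch
    return indices, cu_seqlens, max_len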
|
||||||
|
|
||||||
|
|
||||||
|
def _upad_input(
|
||||||
|
query_layer: torch.Tensor,
|
||||||
|
key_layer: torch.Tensor,
|
||||||
|
value_layer: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
cu_seqlens,
|
||||||
|
):
|
||||||
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask, cu_seqlens)
|
||||||
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||||
|
|
||||||
|
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
|
||||||
|
value_layer = index_first_axis(
|
||||||
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||||
|
)
|
||||||
|
if query_length == kv_seq_len:
|
||||||
|
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||||
|
indices_q = indices_k
|
||||||
|
elif query_length == 1:
|
||||||
|
max_seqlen_in_batch_q = 1
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||||
|
) # There is a memcpy here, that is very bad.
|
||||||
|
indices_q = cu_seqlens_q[:-1]
|
||||||
|
query_layer = query_layer.squeeze(1)
|
||||||
|
else:
|
||||||
|
# The -q_len: slice assumes left padding.
|
||||||
|
attention_mask = attention_mask[:, -query_length:]
|
||||||
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||||
|
|
||||||
|
return (
|
||||||
|
query_layer,
|
||||||
|
key_layer,
|
||||||
|
value_layer,
|
||||||
|
indices_q,
|
||||||
|
(cu_seqlens_q, cu_seqlens_k),
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||||
|
query = query.view(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.view(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.view(-1, value.size(-2), value.size(-1))
|
||||||
|
position_ids = position_ids.flatten()
|
||||||
|
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
cu_seq_lens = torch.cat(
|
||||||
|
(indices_q[position_ids == 0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32))
|
||||||
|
)
|
||||||
|
|
||||||
|
max_length = position_ids.max() + 1
|
||||||
|
|
||||||
|
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
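# --- Illustrative sketch (not part of the original file): how `prepare_fa2_from_position_ids`
# recovers sequence boundaries from packed position ids, assuming a 3-token and a 2-token
# example packed into one row. Call manually.
def _sketch_cu_seqlens_from_position_ids():
    q = k = v = torch.zeros(1, 5, 2, 4)  # (batch, packed_len, heads, head_dim) dummies
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    _, _, _, indices_q, (cu_seq_lens, _), (max_length, _) = prepare_fa2_from_position_ids(q, k, v, position_ids)
    # cu_seq_lens -> tensor([0, 3, 5]) : a new sequence starts wherever position_ids resets to 0
    # max_length  -> 3 (as a tensor)   : the longest packed example
    return indices_q, cu_seq_lens, max_length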
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_forward(
|
||||||
|
query_states: torch.Tensor,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
query_length: int,
|
||||||
|
is_causal: bool,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
softmax_scale: Optional[float] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
use_top_left_mask: bool = False,
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
deterministic: Optional[bool] = None,
|
||||||
|
cu_seqlens=None,
|
||||||
|
):
|
||||||
|
if not use_top_left_mask:
|
||||||
|
causal = is_causal
|
||||||
|
else:
|
||||||
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||||
|
causal = is_causal and query_length != 1
|
||||||
|
|
||||||
|
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||||
|
use_sliding_windows = (
|
||||||
|
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
||||||
|
)
|
||||||
|
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||||
|
|
||||||
|
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||||
|
if deterministic is None:
|
||||||
|
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||||
|
flash_kwargs["deterministic"] = deterministic
|
||||||
|
|
||||||
|
if softcap is not None:
|
||||||
|
flash_kwargs["softcap"] = softcap
|
||||||
|
|
||||||
|
# The batch contains at least one padding token, or explicit packed-sequence boundaries (`cu_seqlens`) were provided
|
||||||
|
if attention_mask is not None or cu_seqlens is not None:
|
||||||
|
batch_size = query_states.shape[0]
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
|
||||||
|
query_states, key_states, value_states, attention_mask, query_length, cu_seqlens
|
||||||
|
)
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output_unpad = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
|
|
||||||
|
# If position_ids is provided, check whether the batch really contains more than one packed example: if the
# position ids are monotonically non-decreasing we probably have a single sequence, otherwise it is packed.
# Additionally check that we are in the pre-fill/training stage. Use `flash_attn_varlen_func` to prevent
# cross-example attention and also to allow a padding-free approach.
# Note: the `torch.diff(...)` condition is checked last to short-circuit and avoid the CUDA synchronization it incurs during inference (query_length == 1 always).
|
||||||
|
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
||||||
|
batch_size = query_states.size(0)
|
||||||
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||||
|
query_states, key_states, value_states, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||||
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_in_batch_q,
|
||||||
|
max_seqlen_k=max_seqlen_in_batch_k,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
**flash_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
||||||
|
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_func(
|
||||||
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output
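# --- Illustrative sketch (not part of the original file): a hypothetical direct call into
# `_flash_attention_forward` with (batch, seq_len, num_heads, head_dim) tensors. It only runs
# when flash-attn and a CUDA device are available, since the kernel requires fp16/bf16 on GPU.
def _sketch_flash_attention_forward():
    if not (is_flash_attn_2_available() and torch.cuda.is_available()):
        return None
    q = torch.randn(1, 16, 4, 32, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 16, 4, 32, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 16, 4, 32, dtype=torch.float16, device="cuda")
    # no padding mask and no packing, so this takes the plain `flash_attn_func` branch
    out = _flash_attention_forward(q, k, v, attention_mask=None, query_length=16, is_causal=True)
    return out.shape  # torch.Size([1, 16, 4, 32])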
|
ihp/zoo/qwen/__init__.py (new file, +10 lines)
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2024 AI. Inspur Inc.
|
||||||
|
#
|
||||||
|
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||||
|
# @date: 2024/10/08
|
||||||
|
#
|
||||||
|
|
||||||
|
from ihp.zoo.qwen.modeling_qwen2 import Qwen2ForCausalLM
|
ihp/zoo/qwen/modeling_qwen2.py (new file, +698 lines)
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Copyright @2024 AI. Inspur Inc.
|
||||||
|
#
|
||||||
|
# @author: jiangzhs <jiangzhs@inspur.com>
|
||||||
|
# @date: 2024/10/08
|
||||||
|
#
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers import Cache
|
||||||
|
from transformers import DynamicCache
|
||||||
|
from transformers import StaticCache
|
||||||
|
from transformers.generation import GenerationMixin
|
||||||
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.qwen2 import Qwen2Config
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding
|
||||||
|
from transformers.models.qwen2.modeling_qwen2 import repeat_kv
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
from ihp.zoo.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
||||||
|
_CONFIG_FOR_DOC = "Qwen2Config"
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Attention(nn.Module):
|
||||||
|
def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
if layer_idx is None:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||||
|
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.is_causal = True
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
|
||||||
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
|
f" and `num_heads`: {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||||
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||||
|
|
||||||
|
self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
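# --- Illustrative sketch (not part of the original file): what `repeat_kv` does for the
# grouped-query attention above, on hypothetical shapes with 2 query heads per key/value head.
def _sketch_repeat_kv():
    kv = torch.randn(1, 2, 5, 8)       # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=2)   # each kv head is duplicated for its query-head group
    return expanded.shape               # torch.Size([1, 4, 5, 8]) == num_attention_heads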
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2FlashAttention2(Qwen2Attention):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is None:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
else:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# Activate slicing cache only if the config has a non-None `sliding_window` attribute
|
||||||
|
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||||
|
kv_seq_len = key_states.shape[-2] + cache_position[0]
|
||||||
|
if (
|
||||||
|
getattr(self.config, "sliding_window", None) is not None
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
and cache_has_contents
|
||||||
|
):
|
||||||
|
slicing_tokens = 1 - self.config.sliding_window
|
||||||
|
|
||||||
|
past_key = past_key_value[self.layer_idx][0]
|
||||||
|
past_value = past_key_value[self.layer_idx][1]
|
||||||
|
|
||||||
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
|
||||||
|
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||||
|
f" {past_key.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, slicing_tokens:]
|
||||||
|
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
|
||||||
|
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in float16 just to be sure everything works as expected.
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
# Reshape to the expected shape for Flash Attention
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config.use_sliding_window
|
||||||
|
and getattr(self.config, "sliding_window", None) is not None
|
||||||
|
and self.layer_idx >= self.config.max_window_layers
|
||||||
|
):
|
||||||
|
sliding_window = self.config.sliding_window
|
||||||
|
else:
|
||||||
|
sliding_window = None
|
||||||
|
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
position_ids=position_ids,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
is_causal=self.is_causal,
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
QWEN2_ATTENTION_CLASSES = {
|
||||||
|
"eager": Qwen2Attention,
|
||||||
|
"flash_attention_2": Qwen2FlashAttention2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2DecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: Qwen2Config, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
||||||
|
logger.warning_once(
|
||||||
|
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||||
|
"unexpected results may be encountered."
|
||||||
|
)
|
||||||
|
self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||||
|
|
||||||
|
self.mlp = Qwen2MLP(config)
|
||||||
|
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
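# --- Illustrative sketch (not part of the original file): the pre-norm residual pattern used
# by `Qwen2DecoderLayer.forward` above, reduced to stand-in sub-modules so the data flow is
# explicit. Call manually; not used by the model code.
def _sketch_pre_norm_residual():
    norm1, norm2 = nn.LayerNorm(8), nn.LayerNorm(8)   # stand-ins for the two RMSNorms
    attn, mlp = nn.Linear(8, 8), nn.Linear(8, 8)      # stand-ins for self-attention and the MLP
    x = torch.randn(2, 5, 8)
    x = x + attn(norm1(x))   # attention block: normalize, transform, add residual
    x = x + mlp(norm2(x))    # feed-forward block: normalize, transform, add residual
    return x.shape           # torch.Size([2, 5, 8])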
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2Model(Qwen2PreTrainedModel):
|
||||||
|
def __init__(self, config: Qwen2Config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self._attn_implementation = config._attn_implementation
|
||||||
|
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embed_tokens = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache and not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
else:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
logger.warning_once(
|
||||||
|
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||||
|
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||||
|
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# create position embeddings to be shared across the decoder layers
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
causal_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
position_embeddings,
|
||||||
|
cu_seqlens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = next_cache.to_legacy_cache()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
||||||
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = Qwen2Model(config)
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.embed_tokens = value
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
num_logits_to_keep: int = 0,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
test_hairuo.py (new file, +55 lines)
prompt = "你好,我的名字是"
|
||||||
|
model_path = './model_ckpt/hairuo'
|
||||||
|
# run_type = 'vllm'
|
||||||
|
run_type = 'transformers'
|
||||||
|
|
||||||
|
if run_type == 'transformers':
|
||||||
|
from ihp.zoo.hairuo import HairuoTokenizer
|
||||||
|
from ihp.zoo.hairuo import HairuoForCausalLM
|
||||||
|
|
||||||
|
model = HairuoForCausalLM.from_pretrained(model_path)
|
||||||
|
tokenizer = HairuoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
model.requires_grad_(False)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
generate_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=200,
    temperature=0.8,
    do_sample=True,
    eos_token_id=151644,
    pad_token_id=151644,
)
|
||||||
|
generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
print(generated_text)
|
||||||
|
|
||||||
|
|
||||||
|
if run_type == 'vllm':
|
||||||
|
|
||||||
|
# Import LLM and SamplingParams
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm import ModelRegistry
|
||||||
|
from ihp.zoo.hairuo.vllm_hairuo import HairuoForCausalLM
|
||||||
|
ModelRegistry.register_model("HairuoForCausalLM", HairuoForCausalLM)
|
||||||
|
# Inference inputs are organized as a List[str]
|
||||||
|
prompts = [
|
||||||
|
"你好,我的名字是",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"AI的未来是什么?",
|
||||||
|
]
|
||||||
|
# Set the sampling parameters
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=1)
|
||||||
|
# Load the model
|
||||||
|
llm = LLM(
|
||||||
|
model=model_path,
|
||||||
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
# dtype='float32',
|
||||||
|
gpu_memory_utilization=0.95,
|
||||||
|
max_model_len=100,
|
||||||
|
enforce_eager=True,
|
||||||
|
)
|
||||||
|
# Run inference
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
# Print the inference results
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")