fix: use miniCPM's way to init model
commit aef0641a79
parent 6dce85771e
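The diff below removes the commented-out Llama-style `load_weights` / `load_kv_cache_scales` block from `HairuoModel` and rewrites `HairuoForCausalLM.load_weights` to follow the MiniCPM pattern: dot-less `stacked_params_mapping` entries, an early skip of `lm_head.weight` when `tie_word_embeddings` is set, and a (currently commented-out) `expert_params_mapping` branch. As a rough illustration of the stacked-parameter idea only -- the helper `remap_stacked` and the toy checkpoint names are hypothetical and not part of this repository -- a minimal, self-contained sketch:

    # Minimal sketch of stacked-parameter name remapping (toy example only;
    # remap_stacked and TOY_CHECKPOINT are hypothetical, not Hairuo/vLLM code).
    from typing import Iterable, Optional, Tuple

    # (fused_param_suffix, checkpoint_suffix, shard_id), MiniCPM-style without
    # leading dots, matching the new mapping in the diff below.
    STACKED_PARAMS_MAPPING = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def remap_stacked(name: str) -> Tuple[str, Optional[object]]:
        """Map a per-projection checkpoint name onto its fused parameter name.

        Returns the (possibly rewritten) name plus the shard id that tells the
        fused parameter's weight loader which slice to fill, or None when the
        name does not belong to a fused parameter.
        """
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name in name:
                return name.replace(weight_name, param_name), shard_id
        return name, None

    if __name__ == "__main__":
        TOY_CHECKPOINT: Iterable[str] = [
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.mlp.up_proj.weight",
            "model.layers.0.input_layernorm.weight",
        ]
        for ckpt_name in TOY_CHECKPOINT:
            print(remap_stacked(ckpt_name))
        # ('model.layers.0.self_attn.qkv_proj.weight', 'q')
        # ('model.layers.0.mlp.gate_up_proj.weight', 1)
        # ('model.layers.0.input_layernorm.weight', None)

In the vLLM-style loader that the diff keeps, the shard id is then passed to the fused parameter's `weight_loader`, which copies the tensor into the corresponding q/k/v (or gate/up) slice of the merged weight.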
@@ -346,91 +346,7 @@ class HairuoModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
-    """
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = loaded_weight[0]
-                weight_loader(param, loaded_weight)
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
-            if not isinstance(self.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.layers[layer_idx].self_attn
-
-            if is_hip():
-                # The scaling factor convention we are assuming is
-                # quantized_value * scaling_factor ~= true_value
-                # which is consistent with the practice of setting
-                # scaling_factor = tensor_amax / FPtype_max
-                scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
-            else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
-    """
 
 class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
@@ -568,11 +484,11 @@ class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -583,87 +499,47 @@ class HairuoForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = loaded_weight[0]
-                weight_loader(param, loaded_weight)
+            # With tie_word_embeddings, we can skip lm_head.weight
+            # The weight might appear unnecessarily in the files if the model is
+            # processed with quantization, LoRA, fine-tuning, etc.
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
                 if is_pp_missing_parameter(name, self):
                     continue
 
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
 
                 break
+            """
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-        """
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        loader.load_weights(
-            self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights)
-        """
-
-    def maybe_remap_mistral(
-        self,
-        name: str,
-        loaded_weight: torch.Tensor,
-    ) -> Tuple[str, torch.Tensor]:
-
-        def permute(w: torch.Tensor, n_heads: int):
-            attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
-
-            return w.view(n_heads, attn_in // n_heads // 2, 2,
-                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)
-
-        mapping = self.mistral_mapping
-        modules = name.split(".")
-
-        # rotary embeds should be sliced
-        if "wk" in modules:
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
-        elif "wq" in modules:
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
-
-        for item in modules:
-            if item in mapping and mapping[item] not in name:
-                name = name.replace(item, mapping[item])
-
-        return name, loaded_weight
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            """
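One control-flow detail worth noting in the rewritten `load_weights`: it relies on Python's `for`/`else`, where the `else` suite runs only if the loop finished without hitting `break` (here, when a checkpoint name matched no stacked-parameter entry); the diff currently comments that branch out with a triple-quoted string. A minimal sketch of the pattern, independent of this repository (the `classify` helper is hypothetical):

    # for/else sketch: the else suite runs only when the loop did not break,
    # i.e. when the name matched none of the stacked-parameter fragments.
    def classify(name: str, fragments: list) -> str:
        for fragment in fragments:
            if fragment in name:
                break  # handled as a stacked/fused parameter
        else:
            return "plain parameter"  # no fragment matched; fall back
        return f"stacked parameter via {fragment}"

    print(classify("model.layers.0.self_attn.q_proj.weight", ["q_proj", "k_proj"]))
    # stacked parameter via q_proj
    print(classify("model.norm.weight", ["q_proj", "k_proj"]))
    # plain parameter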