|
| 1 | +import torch |
| 2 | +import math |
| 3 | +import numpy as np |
| 4 | +from lightllm.common.basemodel import TransformerLayerWeight |
| 5 | + |
| 6 | +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight |
| 7 | + |
| 8 | + |
| 9 | +class MiniCPMTransformerLayerWeight(LlamaTransformerLayerWeight): |
| 10 | + def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): |
| 11 | + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) |
| 12 | + num_hidden_layers = self.network_config_["num_hidden_layers"] |
| 13 | + scale_depth = self.network_config_.get("scale_depth", math.sqrt(num_hidden_layers)) |
| 14 | + self.layer_scale =scale_depth / math.sqrt(num_hidden_layers) |
| 15 | + return |
| 16 | + |
| 17 | + def _load_qkvo_weights(self, weights): |
| 18 | + # input layernorm params |
| 19 | + if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: |
| 20 | + self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"]) |
| 21 | + |
| 22 | + n_embed = self.network_config_["hidden_size"] |
| 23 | + q_split_n_embed = n_embed // self.world_size_ |
| 24 | + kv_split_n_embed = ( |
| 25 | + n_embed |
| 26 | + // self.network_config_["num_attention_heads"] |
| 27 | + * self.network_config_["num_key_value_heads"] |
| 28 | + // self.world_size_ |
| 29 | + ) |
| 30 | + if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: |
| 31 | + self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] |
| 32 | + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] |
| 33 | + self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) |
| 34 | + |
| 35 | + if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: |
| 36 | + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] |
| 37 | + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] |
| 38 | + self.k_weight_ = k_weight_.transpose(0, 1) |
| 39 | + |
| 40 | + if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: |
| 41 | + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] |
| 42 | + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] |
| 43 | + self.v_weight_ = v_weight_.transpose(0, 1) |
| 44 | + |
| 45 | + # attention output dense params |
| 46 | + if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: |
| 47 | + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] |
| 48 | + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] |
| 49 | + self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) * self.layer_scale |
| 50 | + |
| 51 | + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) |
| 52 | + |
| 53 | + return |
| 54 | + |
| 55 | + def _load_ffn_weights(self, weights): |
| 56 | + if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: |
| 57 | + self.ffn_norm_weight_ = self._cuda( |
| 58 | + weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] |
| 59 | + ) |
| 60 | + |
| 61 | + inter_size = self.network_config_["intermediate_size"] |
| 62 | + split_inter_size = inter_size // self.world_size_ |
| 63 | + |
| 64 | + if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: |
| 65 | + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ |
| 66 | + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : |
| 67 | + ] |
| 68 | + self.up_proj = up_proj.transpose(0, 1) |
| 69 | + |
| 70 | + if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: |
| 71 | + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ |
| 72 | + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : |
| 73 | + ] |
| 74 | + self.gate_proj = gate_proj.transpose(0, 1) |
| 75 | + |
| 76 | + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) |
| 77 | + |
| 78 | + if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: |
| 79 | + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ |
| 80 | + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) |
| 81 | + ] |
| 82 | + self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) * self.layer_scale |
| 83 | + return |
0 commit comments