
Commit f7b9937

Support internlm2 wquant. (#345)
1 parent b1a73ec commit f7b9937

File tree

6 files changed: +121 -11 lines changed

lightllm/models/internlm2_wquant/__init__.py

Whitespace-only changes.

lightllm/models/internlm2_wquant/layer_weights/__init__.py

Whitespace-only changes.
lightllm/models/internlm2_wquant/layer_weights/transformer_layer_weight.py

+79
@@ -0,0 +1,79 @@
from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeightQuantized


class Internlm2TransformerLayerWeightQuantized(InternlmTransformerLayerWeightQuantized):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.attention_norm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.attention_norm.weight"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
        head_dim = n_embed // self.network_config_["num_attention_heads"]
        # q k v weights for llama
        if f"model.layers.{self.layer_num_}.attention.wqkv.weight" in weights:
            qkv_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wqkv.weight"]
            q_groups = self.network_config_["num_attention_heads"] // self.network_config_["num_key_value_heads"]
            qkv_weight_ = qkv_weight_.reshape(self.network_config_["num_key_value_heads"], q_groups + 2, head_dim, -1)
            q_weight_ = qkv_weight_[:, :q_groups, :, :].reshape(-1, qkv_weight_.shape[-1])
            q_weight_ = q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) :].transpose(0, 1)
            self.q_weight_ = self.quantize_weight(q_weight_)

            k_weight_ = qkv_weight_[:, -2, :, :].reshape(-1, qkv_weight_.shape[-1])
            self.k_weight_ = k_weight_[
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :
            ].transpose(0, 1)
            v_weight_ = qkv_weight_[:, -1, :, :].reshape(-1, qkv_weight_.shape[-1])
            self.v_weight_ = v_weight_[
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :
            ].transpose(0, 1)

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1, handle_func=self.quantize_weight)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.attention.wo.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wo.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self.quantize_weight(self.o_weight_.transpose(0, 1))
        if f"model.layers.{self.layer_num_}.attention.wo.bias" in weights:
            self.o_bias_ = weights[f"model.layers.{self.layer_num_}.attention.wo.bias"]
            self.o_bias_ = self._cuda(self.o_bias_)
        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.ffn_norm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.ffn_norm.weight"])

        inter_size = self.network_config_["intermediate_size"]
        split_inter_size = inter_size // self.world_size_

        if f"model.layers.{self.layer_num_}.feed_forward.w3.weight" in weights:
            up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.up_proj = up_proj.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.feed_forward.w1.weight" in weights:
            gate_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.gate_proj = gate_proj.transpose(0, 1)

        self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1, handle_func=self.quantize_weight)

        if f"model.layers.{self.layer_num_}.feed_forward.w2.weight" in weights:
            self.down_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w2.weight"][
                :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
            ]
            self.down_proj = self.quantize_weight(self.down_proj.transpose(0, 1))
        return
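
The key step in the loader above is unpacking InternLM2's fused wqkv tensor, which packs the grouped-query heads per kv head. The following standalone sketch shows the same reshape-and-slice pattern; the toy sizes and variable names are illustrative assumptions, not taken from this commit:

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4        # assumed toy config
hidden = num_heads * head_dim                       # 32
q_groups = num_heads // num_kv_heads                # 4 query heads per kv head

# wqkv packs, for each kv head: q_groups query heads, then one k head, then one v head
wqkv = torch.randn((q_groups + 2) * num_kv_heads * head_dim, hidden)

packed = wqkv.reshape(num_kv_heads, q_groups + 2, head_dim, hidden)
q = packed[:, :q_groups, :, :].reshape(-1, hidden)  # (num_heads * head_dim, hidden)
k = packed[:, -2, :, :].reshape(-1, hidden)         # (num_kv_heads * head_dim, hidden)
v = packed[:, -1, :, :].reshape(-1, hidden)         # (num_kv_heads * head_dim, hidden)

assert q.shape == (num_heads * head_dim, hidden)
assert k.shape == v.shape == (num_kv_heads * head_dim, hidden)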
lightllm/models/internlm2_wquant/model.py

+33
@@ -0,0 +1,33 @@
import os
import json
import torch

from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
from lightllm.models.internlm2_wquant.layer_weights.transformer_layer_weight import Internlm2TransformerLayerWeightQuantized
from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
from lightllm.common.mem_utils import select_mem_manager_class


class Internlm2TpPartModelWQuant(InternlmTpPartModelWQuant):
    # weight class
    pre_and_post_weight_class = Internlm2PreAndPostLayerWeight
    transformer_weight_class = Internlm2TransformerLayerWeightQuantized

    def __init__(self, kvargs):
        super().__init__(kvargs)

    def _verify_params(self):
        assert self.load_way in ["HF", "DS"], "llama only supports HF and DS format to load Now!"
        assert any("w4a16" in mode_ or "w8a16" in mode_ for mode_ in self.mode), "only for weight quant model"
        assert self.config["num_key_value_heads"] % self.world_size_ == 0
        assert self.config["num_attention_heads"] % self.world_size_ == 0
        return

    def _init_mem_manager(self):
        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
                                                               dtype=torch.float16,
                                                               head_num=self.config["num_key_value_heads"] // self.world_size_,
                                                               head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
                                                               layer_num=self.config["num_hidden_layers"],
                                                               always_copy=True)
        return
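
To make the split arithmetic behind _verify_params and _init_mem_manager concrete, here is a small worked example with an assumed InternLM2-7B-style config and an assumed tensor-parallel degree of 2; the numbers are illustrative and not taken from this diff:

config = {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8}
world_size = 2                                        # assumed TP degree

n_embed = config["hidden_size"]
head_dim = n_embed // config["num_attention_heads"]   # 128
q_split_n_embed = n_embed // world_size               # 2048 q rows per rank
kv_split_n_embed = (
    n_embed // config["num_attention_heads"] * config["num_key_value_heads"] // world_size
)                                                     # 512 k (and v) rows per rank

# the memory manager sizes the per-rank KV cache with the same split
kv_heads_per_rank = config["num_key_value_heads"] // world_size   # 4
print(head_dim, q_split_n_embed, kv_split_n_embed, kv_heads_per_rank)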

lightllm/models/internlm_wquant/layer_weights/transformer_layer_weight.py

+1-6
@@ -16,7 +16,7 @@ def verify_load(self):
 
         # handle internlm 20b, which has no bias, so set q k v o bias to zero
         if not self.network_config_.get("bias", True):
-            for layer_type in ("q", "k", "v", "o"):
+            for layer_type in ("q", "kv", "o"):
                 attr_name = f"{layer_type}_bias_"
                 if hasattr(self, attr_name):
                     continue
@@ -44,15 +44,13 @@ def _load_qkvo_weights(self, weights):
             self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])
 
         n_embed = self.network_config_["hidden_size"]
-
         q_split_n_embed = n_embed // self.world_size_
         kv_split_n_embed = (
             n_embed
             // self.network_config_["num_attention_heads"]
             * self.network_config_["num_key_value_heads"]
             // self.world_size_
         )
-
         # q k v weights for llama
         if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
             q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
@@ -71,17 +69,14 @@ def _load_qkvo_weights(self, weights):
             k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
             k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
             self.k_weight_ = k_weight_.transpose(0, 1).to(self.data_type_)
-
         if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
             self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][
                 kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
             ]
-
         if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
             v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
             v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
             self.v_weight_ = v_weight_.transpose(0, 1).to(self.data_type_)
-
         if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
             self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][
                 kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
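
For context on the ("q", "kv", "o") change above: once the k and v projections are concatenated into a single kv weight, one fused kv bias stands in for the former k_bias_/v_bias_ pair, so the no-bias fallback only needs three entries. Below is a hedged sketch of that fallback; the sizes and the zero-fill behaviour are assumptions based on the comment in the hunk, not the literal upstream code:

import torch

class _BiasFallbackSketch:
    def __init__(self, q_out_dim, kv_out_dim):
        # assumed output sizes per projection; kv bias covers k and v concatenated
        self.sizes = {"q": q_out_dim, "kv": 2 * kv_out_dim, "o": q_out_dim}

    def fill_missing_biases(self):
        for layer_type in ("q", "kv", "o"):
            attr_name = f"{layer_type}_bias_"
            if hasattr(self, attr_name):
                continue
            # assumed behaviour: a bias-less checkpoint (e.g. internlm-20b) gets zeros
            setattr(self, attr_name, torch.zeros(self.sizes[layer_type], dtype=torch.float16))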

lightllm/server/router/model_infer/model_rpc.py

+8-5
@@ -27,6 +27,7 @@
 from lightllm.models.stablelm.model import StablelmTpPartModel
 from lightllm.models.internlm2.model import Internlm2TpPartModel
 from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
+from lightllm.models.internlm2_wquant.model import Internlm2TpPartModelWQuant
 from lightllm.models.yi.model import YiTpPartModel
 from lightllm.models.mistral.model import MistralTpPartModel
 from lightllm.models.minicpm.model import MiniCPMTpPartModel
@@ -126,14 +127,16 @@ def exposed_init_model(self, kvargs):
             self.model = StarcoderTpPartModel(model_kvargs)
         elif self.model_type == 'chatglm':
             self.model = ChatGlm2TpPartModel(model_kvargs)
-        elif self.model_type == 'internlm' or self.model_type == 'internlm2':
+        elif self.model_type == 'internlm':
             if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
                 self.model = InternlmTpPartModelWQuant(model_kvargs)
             else:
-                if model_cfg["architectures"][0] == 'InternLM2ForCausalLM':
-                    self.model = Internlm2TpPartModel(model_kvargs)
-                else:
-                    self.model = InternlmTpPartModel(model_kvargs)
+                self.model = InternlmTpPartModel(model_kvargs)
+        elif self.model_type == 'internlm2':
+            if any('w8a16' in mode_ or 'w4a16' in mode_ for mode_ in self.mode):
+                self.model = Internlm2TpPartModelWQuant(model_kvargs)
+            else:
+                self.model = Internlm2TpPartModel(model_kvargs)
         elif self.model_type == "Yi":
             self.model = YiTpPartModel(model_kvargs)
         elif self.model_type == "mistral":
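
The net effect of the routing hunk above is that internlm and internlm2 checkpoints are dispatched purely by model_type, with w4a16/w8a16 mode strings selecting the weight-quantized variants. A minimal illustration of that branching follows; the helper function and the example mode string are assumptions made for demonstration, since upstream performs this inline in exposed_init_model:

def pick_internlm_class(model_type: str, mode: list) -> str:
    # mirrors the new elif branches: quantized variant only when a weight-quant mode is set
    wquant = any("w8a16" in m or "w4a16" in m for m in mode)
    if model_type == "internlm":
        return "InternlmTpPartModelWQuant" if wquant else "InternlmTpPartModel"
    if model_type == "internlm2":
        return "Internlm2TpPartModelWQuant" if wquant else "Internlm2TpPartModel"
    raise ValueError(f"unhandled model_type: {model_type}")

# "triton_w4a16" is an assumed example value containing the checked "w4a16" substring
assert pick_internlm_class("internlm2", ["triton_w4a16"]) == "Internlm2TpPartModelWQuant"
assert pick_internlm_class("internlm2", []) == "Internlm2TpPartModel"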
