Commit f50db5c

add support for MiniCPM (#338)
1 parent 35f6d85 commit f50db5c

File tree

10 files changed: +194 -3 lines changed

README.md

+1
@@ -47,6 +47,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework
 - [Llava-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b)
 - [Mixtral]()
 - [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
+- [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)

 > When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.

lightllm/models/minicpm/__init__.py

Whitespace-only changes.

lightllm/models/minicpm/layer_infer/__init__.py

Whitespace-only changes.
@@ -0,0 +1,48 @@
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np

from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class InternlmTransformerLayerInfer(LlamaTransformerLayerInfer):
    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        return

    def _get_qkv(
        self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight
    ) -> torch.Tensor:
        q = torch.addmm(
            layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0
        )
        torch.addmm(
            layer_weight.kv_bias_,
            input.view(-1, self.embed_dim_),
            layer_weight.kv_weight_,
            beta=1.0,
            alpha=1.0,
            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
        )
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, 0 : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _get_o(
        self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight
    ) -> torch.Tensor:
        o_tensor = torch.addmm(
            layer_weight.o_bias_,
            input.view(-1, self.tp_o_head_num_ * self.head_dim_),
            layer_weight.o_weight_,
            beta=1.0 / self.world_size_,
        )
        return o_tensor
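The two overrides above fuse the projection bias into the matmul via torch.addmm, which computes beta * bias + alpha * (mat1 @ mat2). In _get_o the bias term is scaled by 1.0 / self.world_size_, so when the per-rank partial outputs are summed across tensor-parallel ranks the bias is counted exactly once. A minimal sketch of that semantics with toy tensors (illustrative sizes, not the commit's code):

import torch

# Toy sizes for illustration only.
tokens, embed_dim, world_size = 4, 8, 2
x = torch.randn(tokens, embed_dim)       # flattened hidden states
w = torch.randn(embed_dim, embed_dim)    # projection weight, already transposed
b = torch.randn(embed_dim)               # projection bias

# torch.addmm: out = beta * b + alpha * (x @ w)
fused = torch.addmm(b, x, w, beta=1.0 / world_size, alpha=1.0)
manual = b / world_size + x @ w
assert torch.allclose(fused, manual, atol=1e-6)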

lightllm/models/minicpm/layer_weights/__init__.py

Whitespace-only changes.
@@ -0,0 +1,39 @@
import torch
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class MiniCPMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)
        hidden_size = self.network_config_["hidden_size"]
        dim_model_base = self.network_config_.get("dim_model_base", hidden_size)
        self.lm_head_scale = hidden_size / dim_model_base
        self.scale_emb = self.network_config_.get("scale_emb", 1)
        return

    def load_hf_weights(self, weights):
        vob_size = self.network_config_["vocab_size"]
        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
        split_start = split_indexes[self.tp_rank_]
        split_end = split_indexes[self.tp_rank_ + 1]
        if "model.embed_tokens.weight" in weights:
            # print(weights['model.embed_tokens.weight'].shape)
            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
        if "lm_head.weight" in weights:
            # print(weights['lm_head.weight'].shape)
            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) / self.lm_head_scale
        if "model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

        return

    def verify_load(self):
        if not hasattr(self, "lm_head_weight_"):
            self.lm_head_weight_ = self.wte_weight_ / self.lm_head_scale
        self.wte_weight_ = self.wte_weight_ * self.scale_emb
        errors = "weights load not ok"
        weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
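MiniCPMPreAndPostLayerWeight bakes MiniCPM's two global scale factors into the weights at load time: the token embedding is multiplied by scale_emb, and lm_head (or the tied embedding created in verify_load) is divided by hidden_size / dim_model_base. A small sketch of why that is equivalent to applying the factors at runtime; the sizes and factors below are assumed toy values, not MiniCPM's real config:

import torch

hidden_size, dim_model_base, scale_emb, vocab = 8, 4, 12.0, 16   # toy values
lm_head_scale = hidden_size / dim_model_base

wte = torch.randn(vocab, hidden_size)      # embedding table
lm_head = torch.randn(vocab, hidden_size)  # output head
token_ids = torch.tensor([3, 7])
h = torch.randn(2, hidden_size)            # stand-in for final hidden states

# Scaling applied at runtime.
emb_runtime = wte[token_ids] * scale_emb
logits_runtime = (h @ lm_head.t()) / lm_head_scale

# Scaling folded into the weights, as load_hf_weights / verify_load do above.
emb_folded = (wte * scale_emb)[token_ids]
logits_folded = h @ (lm_head / lm_head_scale).t()

assert torch.allclose(emb_runtime, emb_folded)
assert torch.allclose(logits_runtime, logits_folded, atol=1e-5)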
@@ -0,0 +1,83 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight

from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight


class MiniCPMTransformerLayerWeight(LlamaTransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        num_hidden_layers = self.network_config_["num_hidden_layers"]
        scale_depth = self.network_config_.get("scale_depth", math.sqrt(num_hidden_layers))
        self.layer_scale = scale_depth / math.sqrt(num_hidden_layers)
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

        if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
            k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
            k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.k_weight_ = k_weight_.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
            v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
            v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.v_weight_ = v_weight_.transpose(0, 1)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) * self.layer_scale

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)

        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(
                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
            )

        inter_size = self.network_config_["intermediate_size"]
        split_inter_size = inter_size // self.world_size_

        if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
            up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.up_proj = up_proj.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
            gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.gate_proj = gate_proj.transpose(0, 1)

        self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

        if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
            self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
                :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
            ]
            self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) * self.layer_scale
        return
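layer_scale = scale_depth / sqrt(num_hidden_layers) is MiniCPM's per-layer residual scaling. Rather than multiplying each attention/MLP branch output at runtime, the loader above folds the factor into the last linear weight of each branch (o_weight_ and down_proj). A sketch of that equivalence; the depth and shapes below are illustrative, not the model's actual values:

import math
import torch

num_hidden_layers, scale_depth = 40, 1.4   # assumed toy values
layer_scale = scale_depth / math.sqrt(num_hidden_layers)

x = torch.randn(2, 8)                      # branch activations before the final projection
w_out = torch.randn(8, 8)                  # final projection weight of the branch
residual = torch.randn(2, 8)

out_runtime = residual + layer_scale * (x @ w_out)  # scale applied to the branch output
out_folded = residual + x @ (w_out * layer_scale)   # scale folded into the weight
assert torch.allclose(out_runtime, out_folded, atol=1e-5)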

lightllm/models/minicpm/model.py

+16
@@ -0,0 +1,16 @@
import os
import json
import torch
from lightllm.models.minicpm.layer_weights.transformer_layer_weight import MiniCPMTransformerLayerWeight
from lightllm.models.minicpm.layer_weights.pre_and_post_layer_weight import MiniCPMPreAndPostLayerWeight
from lightllm.models.llama.model import LlamaTpPartModel


class MiniCPMTpPartModel(LlamaTpPartModel):
    # weight class
    transformer_weight_class = MiniCPMTransformerLayerWeight
    pre_and_post_weight_class = MiniCPMPreAndPostLayerWeight

    def __init__(self, kvargs):
        super().__init__(kvargs)
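MiniCPMTpPartModel adds no logic of its own; it only swaps in the MiniCPM weight classes through class attributes, so the inherited Llama loading and inference path instantiates them instead. A toy sketch of that pattern (simplified classes, not lightllm's actual API):

class ToyWeight:
    """Stands in for a weight-loading class."""

class BaseTpPartModel:
    transformer_weight_class = None   # subclasses point this at their weight class

    def __init__(self):
        # The base class builds whatever the class attribute names.
        self.layer_weight = self.transformer_weight_class()

class ToyMiniCPMModel(BaseTpPartModel):
    transformer_weight_class = ToyWeight

assert isinstance(ToyMiniCPMModel().layer_weight, ToyWeight)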

lightllm/server/router/model_infer/model_rpc.py

+4-1
@@ -29,6 +29,7 @@
 from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
 from lightllm.models.yi.model import YiTpPartModel
 from lightllm.models.mistral.model import MistralTpPartModel
+from lightllm.models.minicpm.model import MiniCPMTpPartModel
 from lightllm.models.llava.model import LlavaTpPartModel
 from lightllm.models.qwen_vl.model import QWenVLTpPartModel
 from lightllm.models.internlm_xcomposer.model import InternlmComposerTpPartModel
@@ -87,7 +88,7 @@ def exposed_init_model(self, kvargs):
         }

         try:
-            self.model_type = model_cfg["model_type"]
+            self.model_type = model_cfg.get("model_type", "")
             if self.model_type == "bloom":
                 self.model = BloomTpPartModel(model_kvargs)
             elif self.model_type == "llama":
@@ -141,6 +142,8 @@ def exposed_init_model(self, kvargs):
                 self.model = StablelmTpPartModel(model_kvargs)
             elif self.model_type == "mixtral":
                 self.model = MixtralTpPartModel(model_kvargs)
+            elif self.model_type == "minicpm" or model_cfg["architectures"][0] == "MiniCPMForCausalLM":
+                self.model = MiniCPMTpPartModel(model_kvargs)
             elif self.model_type == "llava":
                 self.model = LlavaTpPartModel(model_kvargs)
                 self.is_multimodal = True
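The dispatch above reads model_type with a default and additionally checks the architectures field, presumably so a MiniCPM checkpoint is still recognized when its config.json does not set model_type to "minicpm". A standalone sketch of the same check; the helper name and config-path handling are illustrative, not part of the commit:

import json

def looks_like_minicpm(config_path: str) -> bool:
    # Mirrors the condition used in exposed_init_model above.
    with open(config_path) as f:
        model_cfg = json.load(f)
    model_type = model_cfg.get("model_type", "")
    architectures = model_cfg.get("architectures", [])
    return model_type == "minicpm" or (len(architectures) > 0 and architectures[0] == "MiniCPMForCausalLM")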

lightllm/server/tokenizer.py

+3-2
@@ -69,9 +69,10 @@ def get_tokenizer(
         **kwargs)

     model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name)
-    if model_cfg["model_type"] == "llava" or model_cfg["model_type"] == "internlmxcomposer2":
+    model_type = model_cfg.get("model_type", "")
+    if model_type == "llava" or model_type == "internlmxcomposer2":
         tokenizer = LlavaTokenizer(tokenizer, model_cfg)
-    elif model_cfg["model_type"] == "qwen" and "visual" in model_cfg:
+    elif model_type == "qwen" and "visual" in model_cfg:
         tokenizer = QWenVLTokenizer(tokenizer, model_cfg)

     if not isinstance(tokenizer, PreTrainedTokenizerFast):
