
Commit 7a572f1

hiworldwzj and shihaobai authored and wangzaijun committed
add support for phi3-mini (#433) (#435)
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
1 parent dbfcafd commit 7a572f1

16 files changed: +1454 −31 lines

README.md (+3)

@@ -49,6 +49,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
 - [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
+- [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
 - [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 
 > When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
@@ -61,6 +62,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 
 > Stablelm needs to set the parameter '--trust_remote_code'.
 
+> Phi-3 only supports Mini and Small.
+
 ## Get started
 
 ### Requirements
lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py (+3)

@@ -19,6 +19,9 @@ def load_hf_weights(self, weights):
         if "lm_head.weight" in weights:
             # print(weights['lm_head.weight'].shape)
             self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
+        tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+        if tie_word_embeddings:
+            self.lm_head_weight_ = self.wte_weight_
         if "model.norm.weight" in weights:
             self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
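
The added branch above makes the LLaMA-family weight loader honor `tie_word_embeddings`: when a checkpoint declares tied weights, the output projection reuses the token-embedding matrix instead of loading a separate `lm_head.weight`. A minimal standalone sketch of the same idea in plain PyTorch (illustrative sizes, not LightLLM's loader):

import torch
import torch.nn as nn

vocab_size, hidden_size = 32064, 3072  # illustrative values; real ones come from config.json

embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

tie_word_embeddings = True  # e.g. config.get("tie_word_embeddings", False)
if tie_word_embeddings:
    # Share one parameter: logits are computed against the embedding matrix.
    lm_head.weight = embed_tokens.weight

hidden_states = torch.randn(4, hidden_size)
logits = lm_head(hidden_states)  # shape (4, vocab_size)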

lightllm/models/llama/model.py (+111 −31)

@@ -1,6 +1,7 @@
 import os
 import json
 import torch
+import math
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -17,6 +18,7 @@
 
 logger = init_logger(__name__)
 
+
 class LlamaTpPartModel(TpPartBaseModel):
     # weight class
     pre_and_post_weight_class = LlamaPreAndPostLayerWeight
@@ -34,14 +36,14 @@ class LlamaTpPartModel(TpPartBaseModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
-
+
     def _init_config(self):
         super()._init_config()
         # rename key
         # repair_config()
         self._reset_num_key_value_heads()
-        return
-
+        return
+
     def _reset_num_key_value_heads(self):
         if "num_key_value_heads" not in self.config:
             self.config["num_key_value_heads"] = self.config["num_attention_heads"]
@@ -52,13 +54,15 @@ def _verify_params(self):
         assert self.config["num_key_value_heads"] % self.world_size_ == 0
         assert self.config["num_attention_heads"] % self.world_size_ == 0
         return
-
+
     def _init_mem_manager(self):
-        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
-                                                     dtype=self.data_type,
-                                                     head_num=self.config["num_key_value_heads"] // self.world_size_,
-                                                     head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                                     layer_num=self.config["num_hidden_layers"])
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"] // self.world_size_,
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+        )
         return
 
     def _init_custom(self):
@@ -67,37 +71,51 @@ def _init_custom(self):
         """
         if self.config.get("use_rope_yarn", False):
             self._init_to_get_yarn_rotary()
-        elif self.config.get("use_dynamic_ntk", False) or (self.config.get("rope_scaling", None) is not None and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"):
+        elif self.config.get("use_dynamic_ntk", False) or (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"
+        ):
             self._init_to_get_dynamic_ntk_rotary()
+        elif (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "su"
+        ):
+            self._init_to_su_rotary()
         else:
             self._init_to_get_rotary()
         return
 
     def _init_weights(self):
-        self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+        self.pre_post_weight = self.pre_and_post_weight_class(
+            self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
+        )
         self.trans_layers_weight = [
-            self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+            self.transformer_weight_class(
+                i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
+            )
             for i in range(self.config["n_layer"])
         ]
-        if self.load_way == 'HF':
+        if self.load_way == "HF":
             load_hf_weights(
                 self.data_type,
                 weight_dir=self.weight_dir_,
                 pre_post_layer=self.pre_post_weight,
                 transformer_layer_list=self.trans_layers_weight,
-                weight_dict=self.weight_dict)
+                weight_dict=self.weight_dict,
+            )
         else:
             load_ds_weights(
                 self.data_type,
                 weight_dir=self.weight_dir_,
                 pre_post_layer=self.pre_post_weight,
                 transformer_layer_list=self.trans_layers_weight,
                 weight_dict=self.weight_dict,
-                prefix='model.layers.',
-                num_layer=self.config["n_layer"])
+                prefix="model.layers.",
+                num_layer=self.config["n_layer"],
+            )
         self.pre_post_weight.verify_load()
-        [weight.verify_load() for weight in self.trans_layers_weight]
-        return
+        [weight.verify_load() for weight in self.trans_layers_weight]
+        return
 
     def _init_to_get_rotary(self, default_base=10000):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
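
With the hunk above, `_init_custom` now recognizes a third rotary setup: configs whose `rope_scaling["type"]` is "su" (the long-context scaling type that Phi-3 checkpoints used at the time of this commit) are routed to the new `_init_to_su_rotary`, while "dynamic" and YaRN keep their existing paths. A small standalone sketch of that routing over a plain config dict (hypothetical helper, not part of the class):

def pick_rotary_init(config: dict) -> str:
    """Return which rotary-cache initializer this config would select."""
    rope_scaling = config.get("rope_scaling") or {}
    if config.get("use_rope_yarn", False):
        return "yarn"
    if config.get("use_dynamic_ntk", False) or rope_scaling.get("type", "base") == "dynamic":
        return "dynamic_ntk"
    if rope_scaling.get("type", "base") == "su":
        return "su"
    return "default"

assert pick_rotary_init({"rope_scaling": {"type": "su"}}) == "su"      # Phi-3 style config
assert pick_rotary_init({"use_dynamic_ntk": True}) == "dynamic_ntk"
assert pick_rotary_init({}) == "default"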
@@ -112,8 +130,7 @@ def _init_to_get_rotary(self, default_base=10000):
             max_seq_len = self.config["max_sequence_length"]
         else:
             max_position_embeddings = self.config.get(
-                "max_position_embeddings",
-                2048 if base <= 10000.0 + 1e-5 else 16384
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
             )
             max_seq_len = max_position_embeddings * rope_scaling_factor
 
@@ -124,11 +141,13 @@ def _init_to_get_rotary(self, default_base=10000):
             if ntk_alpha > 1:
                 logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
                 max_seq_len *= ntk_alpha
-                base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim-2))) #Base change formula
+                base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
         except:
             pass
 
-        inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
         t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
         freqs = torch.outer(t, inv_freq)
 
@@ -147,24 +166,37 @@ def _init_to_get_dynamic_ntk_rotary(self):
         max_seq_len = max(self.max_seq_length, max_position_embeddings)
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
-
-        inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
         t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
         freqs = torch.outer(t, inv_freq)
         self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()
 
         for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
-            new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2))
-            inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
-            t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
+            new_base = base * (
+                (scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
+            ) ** (partial_head_dim / (partial_head_dim - 2))
+            inv_freq = 1.0 / (
+                new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+            )
+            t = torch.tensor(
+                [
+                    seq_loc_index,
+                ],
+                device="cpu",
+                dtype=torch.float32,
+            )
             freqs = torch.outer(t, inv_freq)
-            self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
-            self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
+            self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
+            self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
         return
 
     def _init_to_get_yarn_rotary(self):
         from .yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_mscale
+
         dim = self.head_dim_
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
         base = self.config.get("rope_theta", 10000.0)
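
Aside from the black-style reflow, the dynamic-NTK path above fills the rotary cache for positions beyond `max_position_embeddings` one position at a time, enlarging the rotary base with the usual NTK exponent d/(d-2) so that angles stay close to the range seen during training. A short standalone sketch of the per-position base and inverse-frequency computation (plain torch, illustrative values):

import torch

head_dim = 96                   # illustrative; plays the role of partial_head_dim above
base = 10000.0
max_position_embeddings = 4096  # trained context window
scaling_factor = 2.0            # config["rope_scaling"]["factor"]

def dynamic_ntk_inv_freq(seq_loc_index: int) -> torch.Tensor:
    # Grow the base once the position index exceeds the trained window.
    new_base = base * (
        (scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
    ) ** (head_dim / (head_dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# Farther positions get a larger effective base, i.e. lower frequencies for the high dims.
print(dynamic_ntk_inv_freq(8191)[:4])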
@@ -183,10 +215,12 @@ def _init_to_get_yarn_rotary(self):
         inv_freq_interpolation = 1.0 / (scale * pos_freqs)
 
         low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
-        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).float().cuda()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1 - linear_ramp_mask(low, high, dim // 2).float().cuda()
+        ) * extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
         inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
 
-        mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
+        mscale = float(get_mscale(scale) * attn_factor)  # Get n-d magnitude scaling corrected for interpolation
 
         # Build here to make `torch.jit.trace` work.
         max_seq_len_cached = max_position_embeddings
@@ -199,4 +233,50 @@ def _init_to_get_yarn_rotary(self):
 
         return
 
+    def _init_to_su_rotary(self):
+        rope_scaling = self.config["rope_scaling"]
+        short_factor = rope_scaling["short_factor"]
+        long_factor = rope_scaling["long_factor"]
+        original_max_position_embeddings = self.config["original_max_position_embeddings"]
+        max_position_embeddings = self.config.get("max_position_embeddings", original_max_position_embeddings)
+        base = self.config.get("rope_theta", 10000.0)
+        short_factor = torch.tensor(short_factor, dtype=torch.float32, device="cpu")
+        long_factor = torch.tensor(long_factor, dtype=torch.float32, device="cpu")
+
+        scale = max_position_embeddings / original_max_position_embeddings
+        if scale <= 1.0:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+
+        max_seq_len = max(self.max_seq_length, max_position_embeddings)
+        self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
+        self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
+
+        inv_freq = 1.0 / (
+            short_factor
+            * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
+        t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32)
+        freqs = torch.outer(t, inv_freq)
+        self._cos_cached[0:original_max_position_embeddings, :] = (
+            (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+        self._sin_cached[0:original_max_position_embeddings, :] = (
+            (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
 
+        inv_freq = 1.0 / (
+            long_factor
+            * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
+        t = torch.arange(original_max_position_embeddings, max_seq_len, device="cpu", dtype=torch.float32)
+        freqs = torch.outer(t, inv_freq)
+        self._cos_cached[original_max_position_embeddings:, :] = (
+            (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+        self._sin_cached[original_max_position_embeddings:, :] = (
+            (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+
+        return
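
The new `_init_to_su_rotary` implements the "su" scaling this commit adds for Phi-3's long context: per-dimension `short_factor` multipliers rescale the inverse frequencies inside the original training window, `long_factor` multipliers apply beyond it, and every cos/sin entry is multiplied by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)) once the window is extended. A condensed standalone sketch with made-up factors (plain torch on CPU, not the GPU-cached version above):

import math
import torch

head_dim = 96                            # illustrative
base = 10000.0
original_max_position_embeddings = 4096
max_position_embeddings = 131072         # extended window

# Per-(head_dim // 2) multipliers; real values come from config["rope_scaling"].
short_factor = torch.ones(head_dim // 2)
long_factor = torch.full((head_dim // 2,), 4.0)

scale = max_position_embeddings / original_max_position_embeddings
attn_scale = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

def su_cos_sin(factor, positions):
    inv_freq = 1.0 / (factor * base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(positions, inv_freq)
    return torch.cos(freqs) * attn_scale, torch.sin(freqs) * attn_scale

# Short factors cover the original window, long factors everything past it.
cos_short, sin_short = su_cos_sin(short_factor, torch.arange(0, original_max_position_embeddings, dtype=torch.float32))
cos_long, sin_long = su_cos_sin(long_factor, torch.arange(original_max_position_embeddings, max_position_embeddings, dtype=torch.float32))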

lightllm/models/phi3/__init__.py

Whitespace-only changes.

lightllm/models/phi3/layer_infer/__init__.py

Whitespace-only changes.
lightllm/models/phi3/layer_infer/transformer_layer_infer.py (new file, +90)

@@ -0,0 +1,90 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+from functools import partial
+
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import (
+    context_attention_fwd,
+    context_attention_fwd_no_prompt_cache,
+)
+from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv
+from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+
+
+class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer):
+    """ """
+
+    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
+        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
+        return
+
+    def _bind_attention(self):
+        self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self)
+        self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+        self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self)
+        return
+
+    def _get_qkv(self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight):
+        q = torch.mm(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_)
+        torch.mm(
+            input_emb.view(-1, self.embed_dim_),
+            layer_weight.kv_weight_,
+            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
+        )
+        rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, 0 : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
+        return q, cache_kv
+
+    def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
+        destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
+        return
+
+    def _context_attention_kernel(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = torch.empty_like(q) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            context_attention_fwd(
+                q.view(-1, self.tp_q_head_num_, self.head_dim_),
+                kv[:, 0 : self.tp_k_head_num_, :],
+                kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q.view(-1, self.tp_q_head_num_, self.head_dim_),
+                kv[:, 0 : self.tp_k_head_num_, :],
+                kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+            )
+
+        return o_tensor
+
+    def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding
+
+        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        return token_decode_attention_flash_decoding(
+            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
+        )
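
Throughout the new layer, the keys and values of a layer live in one fused cache tensor shaped [total_tokens, k_heads + v_heads, head_dim]; the slices kv[:, 0 : k_heads, :] and kv[:, k_heads : k_heads + v_heads, :] used above are just views into that buffer, so no data is copied. A tiny sketch of the layout (hypothetical shapes, plain torch):

import torch

total_token_num, k_heads, v_heads, head_dim = 1024, 8, 8, 96  # illustrative shapes
kv_buffer = torch.zeros(total_token_num, k_heads + v_heads, head_dim)

cache_k = kv_buffer[:, 0:k_heads, :]                    # view, no copy
cache_v = kv_buffer[:, k_heads : k_heads + v_heads, :]  # view, no copy

cache_k[0] = 1.0                         # writing through the view updates the shared buffer
assert kv_buffer[0, 0, 0].item() == 1.0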

lightllm/models/phi3/layer_weights/__init__.py

Whitespace-only changes.
