import os
import json
import torch
+import math
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer

logger = init_logger(__name__)

+
class LlamaTpPartModel(TpPartBaseModel):
    # weight class
    pre_and_post_weight_class = LlamaPreAndPostLayerWeight
@@ -34,14 +36,14 @@ class LlamaTpPartModel(TpPartBaseModel):
    def __init__(self, kvargs):
        super().__init__(kvargs)
        return
-
+
    def _init_config(self):
        super()._init_config()
        # rename key
        # repair_config()
        self._reset_num_key_value_heads()
-        return
-
+        return
+
    def _reset_num_key_value_heads(self):
        if "num_key_value_heads" not in self.config:
            self.config["num_key_value_heads"] = self.config["num_attention_heads"]
@@ -52,13 +54,15 @@ def _verify_params(self):
        assert self.config["num_key_value_heads"] % self.world_size_ == 0
        assert self.config["num_attention_heads"] % self.world_size_ == 0
        return
-
+
    def _init_mem_manager(self):
-        self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
-                                                     dtype=self.data_type,
-                                                     head_num=self.config["num_key_value_heads"] // self.world_size_,
-                                                     head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
-                                                     layer_num=self.config["num_hidden_layers"])
+        self.mem_manager = select_mem_manager_class(self.mode)(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"] // self.world_size_,
+            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+            layer_num=self.config["num_hidden_layers"],
+        )
        return

    def _init_custom(self):
@@ -67,37 +71,51 @@ def _init_custom(self):
        """
        if self.config.get("use_rope_yarn", False):
            self._init_to_get_yarn_rotary()
-        elif self.config.get("use_dynamic_ntk", False) or (self.config.get("rope_scaling", None) is not None and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"):
+        elif self.config.get("use_dynamic_ntk", False) or (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"
+        ):
            self._init_to_get_dynamic_ntk_rotary()
+        elif (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "su"
+        ):
+            self._init_to_su_rotary()
        else:
            self._init_to_get_rotary()
        return

    def _init_weights(self):
-        self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+        self.pre_post_weight = self.pre_and_post_weight_class(
+            self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
+        )
        self.trans_layers_weight = [
-            self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+            self.transformer_weight_class(
+                i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
+            )
            for i in range(self.config["n_layer"])
        ]
-        if self.load_way == 'HF':
+        if self.load_way == "HF":
            load_hf_weights(
                self.data_type,
                weight_dir=self.weight_dir_,
                pre_post_layer=self.pre_post_weight,
                transformer_layer_list=self.trans_layers_weight,
-                weight_dict=self.weight_dict)
+                weight_dict=self.weight_dict,
+            )
        else:
            load_ds_weights(
                self.data_type,
                weight_dir=self.weight_dir_,
                pre_post_layer=self.pre_post_weight,
                transformer_layer_list=self.trans_layers_weight,
                weight_dict=self.weight_dict,
-                prefix='model.layers.',
-                num_layer=self.config["n_layer"])
+                prefix="model.layers.",
+                num_layer=self.config["n_layer"],
+            )
        self.pre_post_weight.verify_load()
-        [weight.verify_load() for weight in self.trans_layers_weight]
-        return
+        [weight.verify_load() for weight in self.trans_layers_weight]
+        return

    def _init_to_get_rotary(self, default_base=10000):
        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -112,8 +130,7 @@ def _init_to_get_rotary(self, default_base=10000):
            max_seq_len = self.config["max_sequence_length"]
        else:
            max_position_embeddings = self.config.get(
-                "max_position_embeddings",
-                2048 if base <= 10000.0 + 1e-5 else 16384
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
            )
            max_seq_len = max_position_embeddings * rope_scaling_factor

@@ -124,11 +141,13 @@ def _init_to_get_rotary(self, default_base=10000):
            if ntk_alpha > 1:
                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
                max_seq_len *= ntk_alpha
-                base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
+                base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
        except:
            pass

-        inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
        t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
        freqs = torch.outer(t, inv_freq)

@@ -147,24 +166,37 @@ def _init_to_get_dynamic_ntk_rotary(self):
        max_seq_len = max(self.max_seq_length, max_position_embeddings)
        self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
        self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
-
-        inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
        t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
        self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()

        for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
-            new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2))
-            inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
-            t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
+            new_base = base * (
+                (scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
+            ) ** (partial_head_dim / (partial_head_dim - 2))
+            inv_freq = 1.0 / (
+                new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+            )
+            t = torch.tensor(
+                [
+                    seq_loc_index,
+                ],
+                device="cpu",
+                dtype=torch.float32,
+            )
            freqs = torch.outer(t, inv_freq)
-            self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
-            self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
+            self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
+            self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
        return

    def _init_to_get_yarn_rotary(self):
        from .yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_mscale
+
        dim = self.head_dim_
        max_position_embeddings = self.config.get("max_position_embeddings", 2048)
        base = self.config.get("rope_theta", 10000.0)
@@ -183,10 +215,12 @@ def _init_to_get_yarn_rotary(self):
        inv_freq_interpolation = 1.0 / (scale * pos_freqs)

        low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
-        inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).float().cuda()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1 - linear_ramp_mask(low, high, dim // 2).float().cuda()
+        ) * extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

-        mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
+        mscale = float(get_mscale(scale) * attn_factor)  # Get n-d magnitude scaling corrected for interpolation

        # Build here to make `torch.jit.trace` work.
        max_seq_len_cached = max_position_embeddings
@@ -199,4 +233,50 @@ def _init_to_get_yarn_rotary(self):

        return

+    def _init_to_su_rotary(self):
+        rope_scaling = self.config["rope_scaling"]
+        short_factor = rope_scaling["short_factor"]
+        long_factor = rope_scaling["long_factor"]
+        original_max_position_embeddings = self.config["original_max_position_embeddings"]
+        max_position_embeddings = self.config.get("max_position_embeddings", original_max_position_embeddings)
+        base = self.config.get("rope_theta", 10000.0)
+        short_factor = torch.tensor(short_factor, dtype=torch.float32, device="cpu")
+        long_factor = torch.tensor(long_factor, dtype=torch.float32, device="cpu")
+
+        scale = max_position_embeddings / original_max_position_embeddings
+        if scale <= 1.0:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+
+        max_seq_len = max(self.max_seq_length, max_position_embeddings)
+        self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
+        self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
+
+        inv_freq = 1.0 / (
+            short_factor
+            * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
+        t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32)
+        freqs = torch.outer(t, inv_freq)
+        self._cos_cached[0:original_max_position_embeddings, :] = (
+            (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+        self._sin_cached[0:original_max_position_embeddings, :] = (
+            (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )

+        inv_freq = 1.0 / (
+            long_factor
+            * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
+        )
+        t = torch.arange(original_max_position_embeddings, max_seq_len, device="cpu", dtype=torch.float32)
+        freqs = torch.outer(t, inv_freq)
+        self._cos_cached[original_max_position_embeddings:, :] = (
+            (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+        self._sin_cached[original_max_position_embeddings:, :] = (
+            (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
+        )
+
+        return
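For context on what the new "su"-type rope_scaling branch computes, below is a minimal standalone sketch of the same cos/sin cache construction done by _init_to_su_rotary, assuming a Phi-3-style config with "short_factor"/"long_factor" lists and "original_max_position_embeddings". The function name build_su_rope_cache, its freestanding signature, and the CPU-only tensors are illustrative assumptions, not lightllm API.

import math
import torch


def build_su_rope_cache(config, head_dim, max_seq_len, dtype=torch.float32):
    # Illustrative helper (not part of lightllm): builds the rotary cos/sin caches
    # the same way _init_to_su_rotary does, but as a free function on CPU tensors.
    rope_scaling = config["rope_scaling"]
    short_factor = torch.tensor(rope_scaling["short_factor"], dtype=torch.float32)
    long_factor = torch.tensor(rope_scaling["long_factor"], dtype=torch.float32)
    orig_max = config["original_max_position_embeddings"]
    max_pos = config.get("max_position_embeddings", orig_max)
    base = config.get("rope_theta", 10000.0)

    # Magnitude correction applied to both cos and sin once the context window
    # is stretched beyond the original training length.
    scale = max_pos / orig_max
    mag_scale = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(orig_max))

    exponent = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    cos = torch.zeros((max_seq_len, head_dim // 2), dtype=dtype)
    sin = torch.zeros((max_seq_len, head_dim // 2), dtype=dtype)

    # Short-factor frequencies cover positions inside the original window,
    # long-factor frequencies cover every position beyond it.
    for factors, lo, hi in ((short_factor, 0, min(orig_max, max_seq_len)),
                            (long_factor, min(orig_max, max_seq_len), max_seq_len)):
        if lo >= hi:
            continue
        inv_freq = 1.0 / (factors * base ** exponent)
        t = torch.arange(lo, hi, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        cos[lo:hi, :] = (torch.cos(freqs) * mag_scale).to(dtype)
        sin[lo:hi, :] = (torch.sin(freqs) * mag_scale).to(dtype)
    return cos, sin

In the diff above the same two-segment fill is performed in place on the CUDA-resident _cos_cached/_sin_cached buffers sized to max(self.max_seq_length, max_position_embeddings).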