@@ -140,7 +140,10 @@ def _init_to_get_dynamic_ntk_rotary(self):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
         base = self.config.get("rope_theta", 10000.0)
-        scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+        if self.config.get("rope_scaling", {}) is None:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
         max_seq_len = 32 * max_position_embeddings  # 64k
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
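The guard added above matters because `dict.get` only falls back to its default when the key is *missing*: if the checkpoint's `config.json` contains `"rope_scaling": null`, the lookup returns `None` and the chained `.get("factor", ...)` raises `AttributeError`. A minimal, self-contained sketch of that failure mode (using a hypothetical `config` dict, not the repo's loader):

```python
# Sketch of the failure this patch guards against: "rope_scaling": null in config.json.
config = {"rope_scaling": None}

# dict.get only uses the default when the key is absent, so this is None, not {}.
rope_scaling = config.get("rope_scaling", {})
assert rope_scaling is None

try:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)  # old code path
except AttributeError as err:
    print(f"old lookup fails: {err}")  # 'NoneType' object has no attribute 'get'

# Patched behaviour: treat an explicit null the same as "no scaling".
if config.get("rope_scaling", {}) is None:
    scaling_factor = 1.0
else:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)
print(scaling_factor)  # 1.0
```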
@@ -165,7 +168,10 @@ def _init_to_get_yarn_rotary(self):
         dim = self.head_dim_
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
         base = self.config.get("rope_theta", 10000.0)
-        scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
+        if self.config.get("rope_scaling", {}) is None:
+            scale = 1.0
+        else:
+            scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
         original_max_position_embeddings = self.config.get("original_max_position_embeddings", 2048)
         extrapolation_factor = 1.0
         attn_factor = 1.0
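The same `None` guard is repeated in both `_init_to_get_dynamic_ntk_rotary` and `_init_to_get_yarn_rotary`. As an alternative (a sketch of an equivalent pattern, not what the patch itself uses), the lookup could be folded into a small hypothetical helper built on `or {}`, which returns the same factor whether the key is missing, explicitly `null`, or a normal dict:

```python
def rope_scaling_factor(config: dict, default: float = 1.0) -> float:
    """Hypothetical helper: read rope_scaling.factor, tolerating a missing or null entry."""
    return (config.get("rope_scaling") or {}).get("factor", default)

assert rope_scaling_factor({}) == 1.0                                  # key absent
assert rope_scaling_factor({"rope_scaling": None}) == 1.0              # explicit null
assert rope_scaling_factor({"rope_scaling": {"factor": 4.0}}) == 4.0   # normal case
```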