Commit 537c871

Fix bug when rope_scaling in config.json is None (#359)

1 parent b445a6b commit 537c871

File tree

1 file changed: +8 −2 lines


lightllm/models/llama/model.py (+8, −2)
@@ -140,7 +140,10 @@ def _init_to_get_dynamic_ntk_rotary(self):
         partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
         base = self.config.get("rope_theta", 10000.0)
-        scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+        if self.config.get("rope_scaling", {}) is None:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
         max_seq_len = 32 * max_position_embeddings # 64k
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
@@ -165,7 +168,10 @@ def _init_to_get_yarn_rotary(self):
         dim = self.head_dim_
         max_position_embeddings = self.config.get("max_position_embeddings", 2048)
         base = self.config.get("rope_theta", 10000.0)
-        scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
+        if self.config.get("rope_scaling", {}) is None:
+            scale = 1.0
+        else:
+            scale = self.config.get("rope_scaling", {}).get("factor", 1.0)
         original_max_position_embeddings = self.config.get("original_max_position_embeddings", 2048)
         extrapolation_factor = 1.0
         attn_factor = 1.0
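
For context, the bug occurs because config.json can contain "rope_scaling": null. In that case self.config.get("rope_scaling", {}) returns None rather than the {} default (the key is present, its value is just None), so the chained .get("factor", 1.0) raises an AttributeError. A minimal sketch of the behavior with a hypothetical config dict (in lightllm this dict is loaded from the model's config.json):

# Hypothetical config as parsed from a config.json that sets "rope_scaling": null.
config = {"rope_theta": 10000.0, "rope_scaling": None}

# Before the fix: the key exists with value None, so .get() returns None,
# and chaining .get("factor", 1.0) raises AttributeError.
try:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)
except AttributeError:
    print("rope_scaling is None -> crash without the guard")

# After the fix: an explicit None check falls back to the default factor of 1.0.
if config.get("rope_scaling", {}) is None:
    scaling_factor = 1.0
else:
    scaling_factor = config.get("rope_scaling", {}).get("factor", 1.0)
print(scaling_factor)  # 1.0

A more compact guard such as (self.config.get("rope_scaling") or {}).get("factor", 1.0) would handle a null value the same way; the commit keeps the explicit if/else, which stays closer to the original line.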
