Skip to content

Commit aa98b35

Browse files
authored
Update model.py to fix NTK length bug. (#375)
1 parent f5dc783 commit aa98b35

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

lightllm/models/llama/model.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -144,7 +144,7 @@ def _init_to_get_dynamic_ntk_rotary(self):
144144
scaling_factor = 1.0
145145
else:
146146
scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
147-
max_seq_len = self.max_seq_length
147+
max_seq_len = max(self.max_seq_length, max_position_embeddings)
148148
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
149149
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
150150

0 commit comments

Comments
 (0)