From aa98b3536d6486a65c4abaed551c651a1de3d6d6 Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Tue, 26 Mar 2024 15:06:56 +0800
Subject: [PATCH] Update model.py To Fix ntk length bug. (#375)

---
 lightllm/models/llama/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 044ad63bd..8b5a4d549 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -144,7 +144,7 @@ def _init_to_get_dynamic_ntk_rotary(self):
             scaling_factor = 1.0
         else:
             scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-        max_seq_len = self.max_seq_length
+        max_seq_len = max(self.max_seq_length, max_position_embeddings)
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
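
For context on why taking the max matters: the dynamic-NTK path fills rotary cos/sin rows for every position up to the model's native max_position_embeddings, then continues past it with NTK-rescaled frequencies. If the cache were sized to self.max_seq_length alone, a deployment configured with max_seq_length smaller than max_position_embeddings would write (and later read) past the end of _cos_cached/_sin_cached. Below is a minimal, self-contained sketch of the patched method under that reading. The class, constructor, and config keys ("rope_theta", "partial_rotary_factor") are illustrative assumptions, not lightllm's actual model.py, and it builds tensors on CPU rather than device="cuda" so it runs anywhere.

    import torch

    class DynamicNTKRotarySketch:
        # Hypothetical stand-in for the model class; only the fields the
        # patched method touches are kept.
        def __init__(self, config: dict, max_seq_length: int, head_dim: int = 128):
            self.config = config
            self.max_seq_length = max_seq_length
            self.head_dim = head_dim

        def _init_to_get_dynamic_ntk_rotary(self):
            partial_head_dim = int(self.config.get("partial_rotary_factor", 1.0) * self.head_dim)
            max_position_embeddings = self.config.get("max_position_embeddings", 2048)
            base = self.config.get("rope_theta", 10000.0)
            if self.config.get("rope_scaling", None) is None:
                scaling_factor = 1.0
            else:
                scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
            # The fix from this patch: size the cache to cover BOTH the
            # serving-time max_seq_length and the model's native window,
            # otherwise positions in [max_seq_length, max_position_embeddings)
            # would index past the end of the cache.
            max_seq_len = max(self.max_seq_length, max_position_embeddings)
            self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16)
            self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16)

            # Plain (unscaled) frequencies for positions inside the native window.
            inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, dtype=torch.float32) / partial_head_dim))
            t = torch.arange(min(max_position_embeddings, max_seq_len), dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached[: t.numel()] = torch.cos(freqs).to(torch.float16)
            self._sin_cached[: t.numel()] = torch.sin(freqs).to(torch.float16)

            # Dynamic NTK: for each position past the native window, grow the
            # base so the effective wavelength stretches with sequence length.
            for seq_loc_index in range(max_position_embeddings, max_seq_len):
                new_base = base * (
                    (scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
                ) ** (partial_head_dim / (partial_head_dim - 2))
                inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, dtype=torch.float32) / partial_head_dim))
                freqs = float(seq_loc_index) * inv_freq
                self._cos_cached[seq_loc_index] = torch.cos(freqs).to(torch.float16)
                self._sin_cached[seq_loc_index] = torch.sin(freqs).to(torch.float16)

    # Usage: a server capped at 2048 tokens still gets a cache that covers the
    # model's native 4096-position window, which is exactly what the patch ensures.
    m = DynamicNTKRotarySketch(
        {"max_position_embeddings": 4096, "rope_scaling": {"factor": 2.0}},
        max_seq_length=2048,
    )
    m._init_to_get_dynamic_ntk_rotary()
    print(m._cos_cached.shape)  # torch.Size([4096, 64]), not [2048, 64]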