From c19f2cd4de7378906e6db8efab7acb84ef0cc9b5 Mon Sep 17 00:00:00 2001 From: nikitaved Date: Wed, 17 Jul 2024 05:14:15 -0400 Subject: [PATCH] OverridenKVCache: remove index_add and use index_copy instead --- thunder/tests/litgpt_model.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py index 65aed8c386..caca9b982d 100644 --- a/thunder/tests/litgpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -134,20 +134,9 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> # move the buffer to the activation dtype for when AMP is used self.k = self.k.to(k.dtype) self.v = self.v.to(v.dtype) - # update the cache - # NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`. - is_compiling = ( - torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling - ) - if is_compiling(): - # inductor doesn't support `index_add` with bfloat16 - k = self.k.index_copy_(2, input_pos, k) - v = self.v.index_copy_(2, input_pos, v) - return k, v - # See issue: "Support more indexing operators (index_copy and index_add)" - k = self.k = torch.index_add(self.k, 2, input_pos, k) - v = self.v = torch.index_add(self.v, 2, input_pos, v) - # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) + + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) return k, v def reset_parameters(self) -> None: