Skip to content

Commit

Permalink
OverridenKVCache: remove index_add and use index_copy instead
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved committed Jul 17, 2024
1 parent 29379ec commit c19f2cd
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions thunder/tests/litgpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,9 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
# move the buffer to the activation dtype for when AMP is used
self.k = self.k.to(k.dtype)
self.v = self.v.to(v.dtype)
# update the cache
# NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`.
is_compiling = (
torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling
)
if is_compiling():
# inductor doesn't support `index_add` with bfloat16
k = self.k.index_copy_(2, input_pos, k)
v = self.v.index_copy_(2, input_pos, v)
return k, v
# See issue: "Support more indexing operators (index_copy and index_add)"
k = self.k = torch.index_add(self.k, 2, input_pos, k)
v = self.v = torch.index_add(self.v, 2, input_pos, v)
# THUNDER bug: cannot return self.k, self.v here (may be CUDA-graphs related — no minimal repro)

k = self.k.index_copy_(2, input_pos, k)
v = self.v.index_copy_(2, input_pos, v)
return k, v

def reset_parameters(self) -> None:
Expand Down

0 comments on commit c19f2cd

Please sign in to comment.