remove OverridenKVCache and fix some peculiar cases of prims.copy_ + NVFuser (#788)
nikitaved authored Jul 17, 2024
1 parent efb6b4f commit 7fdf213
Showing 2 changed files with 16 additions and 39 deletions.
16 changes: 16 additions & 0 deletions thunder/executors/nvfuserex_impl.py
@@ -786,6 +786,22 @@ def _can_fuse_node(n: Node):
# (Used to name fusions like nvFusion0, nvFusion1, ...)
fusion_counter: int = 0
for bsyms in bound_symbol_groups:
# Related to in-place ops.
# prims.copy_ is a no-op for NVFuser in the sense that it neither
# generates nor runs any kernels.
# For that reason we should avoid fusing prims.copy_ unless
# it comes after other non-copy symbols in a fusion.
# See the following relevant issues:
# https://github.com/Lightning-AI/lightning-thunder/issues/789
# https://github.com/Lightning-AI/lightning-thunder/issues/791
# NOTE: filter out all leading "dangling" no-op copies
while len(bsyms) > 0 and bsyms[0].sym.id is prims.PrimIDs.COPY_:
fused_bsyms.append(bsyms[0])
bsyms = bsyms[1:]

if len(bsyms) == 0:
continue

# TODO The following allows generating single node fusions, which
# may be suboptimal for real-world performance.
# Provide a mechanism to switch between "test" and "perf" modes
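
To make the intent of the loop above concrete, here is a minimal, self-contained sketch of the same peeling idea — `peel_leading_copies` and the string-based symbols are hypothetical stand-ins, not Thunder's actual pass: no-op copies at the head of a group are emitted outside the fusion, so a fusion never begins with prims.copy_.

# Hypothetical stand-in for the loop above: "c" marks a prims.copy_-like
# no-op symbol, anything else is a fusible op.
def peel_leading_copies(bsyms, is_copy):
    peeled = []
    while bsyms and is_copy(bsyms[0]):
        peeled.append(bsyms[0])  # emitted outside any nvFuser fusion
        bsyms = bsyms[1:]
    return peeled, bsyms

outside, to_fuse = peel_leading_copies(["c", "c", "add", "c", "mul"], lambda s: s == "c")
assert outside == ["c", "c"]           # leading copies stay unfused
assert to_fuse == ["add", "c", "mul"]  # a copy after real ops may still fuse
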
39 changes: 0 additions & 39 deletions thunder/tests/litgpt_model.py
@@ -118,47 +118,8 @@
name_to_config = {config["name"]: config for config in configs}


class OverridenKVCache(nn.Module):
def __init__(
self,
k_shape: tuple[int, int, int, int],
v_shape: tuple[int, int, int, int],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# move the buffer to the activation dtype for when AMP is used
self.k = self.k.to(k.dtype)
self.v = self.v.to(v.dtype)
# update the cache
# NOTE: `torch._dynamo.is_compiling` is being deprecated; we should update this once all versions have `torch.compiler.is_compiling`.
is_compiling = (
torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling
)
if is_compiling():
# inductor doesn't support `index_add` with bfloat16
k = self.k.index_copy_(2, input_pos, k)
v = self.v.index_copy_(2, input_pos, v)
return k, v
# See issue: "Support more indexing operators (index_copy and index_add)"
k = self.k = torch.index_add(self.k, 2, input_pos, k)
v = self.v = torch.index_add(self.v, 2, input_pos, v)
# THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro)
return k, v

def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.k)
torch.nn.init.zeros_(self.v)


import litgpt

# override for operator workarounds
litgpt.model.KVCache = OverridenKVCache
# add the testing configurations
litgpt.config.name_to_config.update(name_to_config)
name_to_config.update(litgpt.config.name_to_config)
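
With the `litgpt.model.KVCache = OverridenKVCache` monkey-patch removed, only the two-way config merge above remains, and litgpt's stock KVCache is used directly. A minimal sketch of the merge's effect, with hypothetical dictionaries standing in for `litgpt.config.name_to_config` and this module's `name_to_config`:

# Hypothetical stand-ins for the two config registries being merged.
lib_configs = {"pythia-14m": {"n_layer": 6}}  # litgpt's stock configs
test_configs = {"tiny-test": {"n_layer": 1}}  # this module's test configs

lib_configs.update(test_configs)   # litgpt can now construct "tiny-test"
test_configs.update(lib_configs)   # local lookups also see stock configs

assert "tiny-test" in lib_configs and "pythia-14m" in test_configs
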
