
caching in make_aug_forward_and_backward breaks TE executor. #81

Closed
kshitij12345 opened this issue Mar 26, 2024 · 0 comments · Fixed by #82
Labels: bug (Something isn't working)

kshitij12345 (Collaborator)
As discussed offline, caching in make_aug_forward_and_backward leads to reusing the symbols created by transformer_engine_ex. These symbols are stateful, so reusing them produces an incorrect program.
Ref:

key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs))
cached_result = _cache.get(key, None) if subkey is not None else None
if cached_result is not None:
return cached_result
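
To make the failure mode concrete, here is a minimal, self-contained sketch (hypothetical names, not Thunder's actual implementation): because the key is derived only from the symbol and its arguments, two call sites with identical argument shapes hit the same cache entry and receive the same stateful object.

# Illustrative sketch of the failure mode; names are hypothetical.
_cache = {}

class StatefulLinear:
    """Stand-in for a symbol created by transformer_engine_ex: it owns
    per-layer state (e.g. FP8 scaling metadata) that must not be shared."""
    def __init__(self, name):
        self.name = name

def make_linear_symbol(sym_name, arg_shapes):
    key = (sym_name, arg_shapes)        # analogous to (bsym.sym, subkey)
    cached = _cache.get(key, None)
    if cached is not None:              # cache hit: the second layer gets
        return cached                   # the first layer's stateful symbol
    result = StatefulLinear(f"te_linear_{len(_cache)}")
    _cache[key] = result
    return result

# fc1 and fc2 have identical shapes, so they produce identical cache keys ...
fc1_sym = make_linear_symbol("linear", ((256, 256), (256, 256)))
fc2_sym = make_linear_symbol("linear", ((256, 256), (256, 256)))
assert fc1_sym is fc2_sym               # ... and end up sharing one state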

Sample Program

import torch
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex
from transformer_engine.pytorch import fp8_autocast
dim = 256

class ThunderModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Two Linear layers with identical shapes: their lowerings produce
        # identical cache keys, so both map to the same te_linear symbol.
        self.fc1 = torch.nn.Linear(dim, dim, bias=False)
        self.fc2 = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))

x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda()

thunder_model = ThunderModel().cuda()

jit_model = thunder.jit(thunder_model, executors=(transformer_engine_ex,))

with fp8_autocast():
    o = jit_model(x).sum()

print(thunder.last_traces(jit_model)[-1])

Generated Trace (te_linear_0 is called twice):

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # args: "Collection"
  t0, t1, t2, = args
  del args
  (t6, ctx_te_1) = te_linear_0(t0, t1, None)
  t7 = torch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
    # t7 = ltorch.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
      # t7 = prims.gt(t6, 0.0)  # t7: "cuda:0 b8[256, 256]"
  t8 = torch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
    # t8 = ltorch.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
      # t8 = prims.where(t7, t6, 0.0)  # t8: "cuda:0 f32[256, 256]"
  del t6
  (t13, C12) = te_linear_0(t8, t2, None)
  del t8
  return {'output': t13, 'flat_args': [t0, t1, t2], 'flat_output': (t13,)}, ((t7,), (C12, ctx_te_1))
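
In the trace above, both fc1 and fc2 are lowered to the same te_linear_0 symbol, so whatever per-layer state the executor attached to that symbol (e.g. FP8 scaling metadata) is shared between the two layers. Purely as an illustration of one possible direction (hypothetical attribute names, and not necessarily the approach taken in #82), a cache like the one referenced above could bypass caching for symbols whose lowering is stateful:

# Illustrative sketch only; "is_stateful" and "build" are hypothetical.
def make_aug_forward_and_backward_sketch(bsym, _cache, _make_cache_key, build):
    # Skip the cache for symbols whose lowering carries per-call state,
    # e.g. TransformerEngine's per-layer FP8 scaling metadata.
    stateful = getattr(bsym.sym, "is_stateful", False)

    key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs))
    if not stateful and subkey is not None:
        cached_result = _cache.get(key, None)
        if cached_result is not None:
            return cached_result

    result = build(bsym)  # construct fresh augmented-forward/backward symbols
    if not stateful and subkey is not None:
        _cache[key] = result
    return result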
kshitij12345 added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Mar 26, 2024
kshitij12345 removed the help wanted (Extra attention is needed) label on Mar 26, 2024
t-vi closed this as completed in #82 on Mar 26, 2024