torch.autocast context manager got broken for prims.linear #725

Closed
IvanYashchuk opened this issue Jul 8, 2024 · 3 comments · Fixed by #810
Comments

@IvanYashchuk
Collaborator

IvanYashchuk commented Jul 8, 2024

🐛 Bug

Before #705, Thunder applied the autocast transform to all functions invoked within a torch.autocast region. In #705 the logic was changed to apply autocast dispatch only to "torchsymbols", introducing a regression for all other symbols.

Here's a reproducer, run both before and after the above PR was merged:

import thunder
import torch

def foo(x, w):
    return thunder.prims.linear(x, w, None)

device = torch.device("cuda")
with device:
    # fp32 inputs
    x, w = torch.randn(16, 16), torch.randn(16, 16)
    print(x.dtype, w.dtype)
    
with torch.autocast("cuda", torch.bfloat16):
    jfoo = thunder.jit(foo)
    jit_out = jfoo(x, w)

print(thunder.last_traces(jfoo)[-1])

trace before c816506:

def computation(x, w):
  # x: "cuda:0 f32[16, 16]"
  # w: "cuda:0 f32[16, 16]"
  [t0, t1] = nvFusion0(w, x)
    # t0 = prims.convert_element_type(x, dtypes.bfloat16)  # t0: "cuda:0 bf16[16, 16]"
    # t1 = prims.convert_element_type(w, dtypes.bfloat16)  # t1: "cuda:0 bf16[16, 16]"
  del w, x
  t2 = torch.nn.functional.linear(t0, t1, None)  # t2: "cuda:0 bf16[16, 16]"
    # t2 = ltorch.linear(t0, t1, None)  # t2: "cuda:0 bf16[16, 16]"
      # t2 = prims.linear(t0, t1, None)  # t2: "cuda:0 bf16[16, 16]"
  del t0, t1
  return t2

trace after c816506:

def computation(x, w):
  # x: "cuda:0 f32[16, 16]"
  # w: "cuda:0 f32[16, 16]"
  t0 = torch.nn.functional.linear(x, w, None)  # t0: "cuda:0 f32[16, 16]"
    # t0 = ltorch.linear(x, w, None)  # t0: "cuda:0 f32[16, 16]"
      # t0 = prims.linear(x, w, None)  # t0: "cuda:0 f32[16, 16]"
  del x, w
  return t0

cc @crcrpar @apaz-cli

@t-vi
Collaborator

t-vi commented Jul 8, 2024

So the unexpected change in behaviour is bad, but given that it's torch.autocast, is it unreasonable not to autocast prims?

@IvanYashchuk
Collaborator Author

PyTorch users are used to enabling mixed-precision training with torch.autocast, and I think Thunder should support it as a way of invoking the autocast transform for Thunder functions.
It's confusing for thunder.torch.linear and thunder.prims.linear to behave differently when both have an autocast rule registered. And if a custom executor registers an autocast rule for its own op, that rule is likewise out of reach for the same PyTorch code that worked without the custom executor.
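[Editor's note: a minimal sketch, not part of the original report, illustrating the inconsistency described above. It assumes a CUDA device and that thunder.torch.linear is the torchsymbol counterpart of prims.linear, as the traces above suggest.]

import thunder
import torch

def via_torchsymbol(x, w):
    # torchsymbol path: the registered autocast rule is still applied after #705
    return thunder.torch.linear(x, w, None)

def via_prim(x, w):
    # prim path: the registered autocast rule is skipped after #705
    return thunder.prims.linear(x, w, None)

x = torch.randn(16, 16, device="cuda")
w = torch.randn(16, 16, device="cuda")

with torch.autocast("cuda", torch.bfloat16):
    out_torchsymbol = thunder.jit(via_torchsymbol)(x, w)
    out_prim = thunder.jit(via_prim)(x, w)

# Expected before #705: both bfloat16. Observed after #705: bfloat16 vs. float32.
print(out_torchsymbol.dtype, out_prim.dtype)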

@tfogal
Collaborator

tfogal commented Jul 8, 2024

triage review:

  • we need to be tracking how autocast works more closely (done in autocast: support mixed dtypes #705)
  • we don't need to put autocast logic inside torchsymbol, could just be in symbol
  • is autocast the way for us to apply casting to models?
  • can we translate the context manager into the canonical "thunder" context manager? then this applies to everything (see the sketch after this list)
  • autocast isn't a transform? the transform actually is still there
  • changing the implicit contract silently wasn't expected/intended
  • current behavior is maybe not unreasonable.
  • @IvanYashchuk to follow up
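
[Editor's note: as a follow-up to the bullet about translating the context manager, here is a minimal sketch of one possible approach, an assumption rather than Thunder's actual mechanism. It only uses PyTorch's autocast state queries; mapping the result onto Thunder's own autocast transform so it covers every symbol with a registered rule, not just torchsymbols, is the part that would live in Thunder.]

import torch

def active_autocast_dtype(device_type="cuda"):
    # Return the dtype requested by an enclosing torch.autocast region,
    # or None if autocast is not enabled for this device type.
    if device_type == "cuda" and torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if device_type == "cpu" and torch.is_autocast_cpu_enabled():
        return torch.get_autocast_cpu_dtype()
    return None

with torch.autocast("cuda", torch.bfloat16):
    print(active_autocast_dtype("cuda"))  # torch.bfloat16 (on a CUDA-enabled build)
print(active_autocast_dtype("cuda"))      # None outside the autocast region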
