no_grad is lost in jitted functions #1486

Open · beverlylytle opened this issue Nov 27, 2024 · 1 comment

beverlylytle (Collaborator) commented Nov 27, 2024

🐛 Bug

For a function decorated with torch.no_grad, the compile data of the jitted version has is_grad_enabled set to True, whereas I would expect it to be False.

Code sample

import torch
import thunder

@torch.no_grad
def f(x):
    print(torch.is_grad_enabled())
    return x * 2


jf = thunder.jit(f)

x = torch.ones((2,2))

f(x)                                              # prints False
jf(x)                                             # prints True
thunder.compile_data(jf).is_grad_enabled          # True
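
For reference, here is the eager-mode behavior the jitted version would be expected to preserve. This is plain PyTorch without thunder, shown only to illustrate the semantics of the no_grad decorator:

import torch

@torch.no_grad
def f(x):
    return x * 2

x = torch.ones((2, 2), requires_grad=True)
y = f(x)

print(torch.is_grad_enabled())  # True: no_grad is only active inside the call
print(y.requires_grad)          # False: autograd did not track the multiplication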

kshitij12345 (Collaborator) commented Nov 28, 2024

Here is what is happening:

# This is how the thunder interpreter sees the above function.
def f(x):
    torch._C._set_grad_enabled(False)
    print(torch.is_grad_enabled())
    result = x * 2  # CompileData.is_grad_enabled = False only for this part.
    torch._C._set_grad_enabled(True)
    return result

We can verify this from thunder.last_traces by adding the following lines to the above script:

# Verifying from trace
traces = thunder.last_traces(jf)
print(traces[0])

Trace

def computation(x):
  # x: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(False)

  # /home/kkalambarkar/lightning-thunder/scratchpad/test.py:153:            return x * 2
  t0 = ltorch.mul(x, 2)  # t0: "cpu f32[2, 2]"
    # t0 = prims.mul(x, 2.0)  # t0: "cpu f32[2, 2]"

  # /home/kkalambarkar/git/pytorch/torch/autograd/grad_mode.py:187:             torch._C._set_grad_enabled(mode)
  ltorch._set_grad_enabled_with_warning(True)
  return {'output': t0, 'flat_args': [x]}

So, we can see that at the end of the function there is a _set_grad_enabled(True) call, which sets CompileData.is_grad_enabled back to True.

CompileData.is_grad_enabled is only False within the no_grad region. So, if ltorch.mul or another operator queried this state while being traced, it would have seen False.

NOTE - Symbols that care about this should query the state during tracing. After tracing, is_grad_enabled will reflect the last state it was updated to.
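
Here is a minimal sketch of the timing issue, using hypothetical stand-in names rather than thunder's real internals:

class FakeCompileData:
    # Stand-in for thunder's CompileData; only the flag discussed here.
    def __init__(self):
        self.is_grad_enabled = True

def fake_trace(cd):
    # Mirrors the interpreter's view of f above.
    cd.is_grad_enabled = False           # lookaside for _set_grad_enabled(False)
    seen_by_mul = cd.is_grad_enabled     # what a symbol like ltorch.mul would observe
    cd.is_grad_enabled = True            # lookaside for _set_grad_enabled(True)
    return seen_by_mul

cd = FakeCompileData()
print(fake_trace(cd))        # False -> the state while the symbol is traced
print(cd.is_grad_enabled)    # True  -> what thunder.compile_data(jf) reports afterwards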

Also, the reason we see print(torch.is_grad_enabled()) printing True is that our lookaside for _set_grad_enabled only updates the state in CompileData and doesn't actually call torch._C._set_grad_enabled, which would have updated PyTorch's global state. So PyTorch never sees the change while tracing, and hence is_grad_enabled() returns True. (I am not sure if we do/want to support printing during tracing.)
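
For contrast, outside of tracing the intercepted call really does flip PyTorch's global flag, which is why the eager run prints False; a quick check in plain PyTorch:

import torch

torch._C._set_grad_enabled(False)
print(torch.is_grad_enabled())   # False: PyTorch's global state was updated
torch._C._set_grad_enabled(True)
print(torch.is_grad_enabled())   # True: restored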
