-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ThunderFX: handles the callable input of fx.Node #1548
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the fix, @kiya00 . The fix itself looks good, however the test failures look real.
Also, once everything works fine, I'd expect the following test to fail as it shouldn't find any module which was passed to inductor
. (So it will need an update)
lightning-thunder/thunder/tests/test_dynamo.py
Lines 268 to 270 in 673bdb9
def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): | |
x = torch.ones(2, device=device, dtype=dtype, requires_grad=True) | |
This line should raise an assertion error -
assert any(target.startswith("inductor_") for target in targets) |
Hi @kshitij12345 I met 2 problems, do you maybe have any suggestions?
Or do we set a specific pass to delete the
Here is the FX graph structure of
the
cc: @IvanYashchuk |
Regarding1., I am not very familiar with For 2., #1463 should take care of this as regions within |
By cherry-pick #1463, the
the
the augmented fwd trace is:
Note that the lightning-thunder/thunder/core/transforms.py Lines 2744 to 2746 in c6294ac
I think the correct behavior should be to use the grad transformation: lightning-thunder/thunder/core/jit_ext.py Line 753 in c6294ac
torch._C._set_grad_enabled in the fwd_body_0 module?It seems if I let is_constant_for_vjp(symbol of higher_order_autograd_function_apply) return False even if the output of higher_order_autograd_function_apply has tag ProxyTag.DETACHED_AUTOGRAD_GRAPH , the backward trace is expected
|
Thanks for the explanation @kiya00. I was wondering that since the line below goes through autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []); |
yes, it's automatically, the bwd of symbol lightning-thunder/thunder/core/transforms.py Lines 2744 to 2764 in 2d0199e
so I tried to let is_constant_for_vjp always return False for autograd_function_apply then the fwd trace is(corresponds to the test case in the PR):
the bwd trace:
|
thunder/core/transforms.py
Outdated
if isinstance(symbol.sym.id, str) and symbol.sym.id.startswith("higher_order_autograd_function_apply"): | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed? Can you please provide an example of a failing case without these lines?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the reason is #1548 (comment), the test case test_splitter_autograd_function
modified in this PR will fail without it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's most likely a viable workaround for the problems seen in this PR, but I don't think this is a correct fix. There's a problem with thunder.torch._set_grad_enabled_with_warning(False)
called inside the forward function passed to thunder.torch.autograd_function_apply
that causes the system to ignore provided backward which should be fixed first:
import thunder
import torch
def forward(_, x):
saved_for_backward = (x,)
thunder.torch._set_grad_enabled_with_warning(False) # Without this line the specified backward is called as expected
sin = thunder.torch.sin(x)
thunder.torch._set_grad_enabled_with_warning(True)
return sin, saved_for_backward
def backward(_, grad_output, *saved_tensors):
raise NotImplementedError
def my_sin(x):
res = thunder.torch.autograd_function_apply(
forward,
backward,
x,
args_tensor_mask=[True],
non_differentiable_idx=[],
)
return res
jitted = thunder.jit(my_sin)
x = torch.randn((2, 2), requires_grad=True)
out = jitted(x) # Should raise NotImplementedError but it doesn't
out.backward(out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the repro, I think autograd_function_apply
should override any set_grad_enabled
inside. I will try updating #1463 to work with this and discuss the fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thunder.torch.autograd_function_apply
corresponds to the torch.ops.higher_order.autograd_function_apply
that only appears in dynamo, and dynamo will add a no_grad guard around the forward function, so it's ok if we use it in thunderFX, but I agree we need to think if there are other ways to handle it.
if we write it as follows, it raises error
import thunder
import torch
from thunder.dynamo import thunderfx
class Sin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
with torch.no_grad():
return torch.sin(x)
@staticmethod
def backward(ctx, g):
#(x,) = ctx.saved_tensors
#return g * torch.cos(x) * 100
raise NotImplementedError("aaaaa")
def my_sin(x):
return Sin.apply(x)
# jitted = thunder.jit(my_sin)
jitted = thunderfx(my_sin)
x = torch.randn((2, 2), requires_grad=True)
out = jitted(x) # Should raise NotImplementedError but it doesn't
out.backward(out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kshitij12345, could you please create a pull request with the changes from #1548 (comment) and a test from #1548 (comment). With your fix, the code added here to is_constant_for_vjp
shouldn't be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, the lookaside for torch.ops.higher_order.autograd_function_apply
doesn't use thunder.torch.autograd_function_apply
, and the fix proposed in #1548 (comment) needs to be duplicated for forward_result
here:
lightning-thunder/thunder/core/jit_ext.py
Line 863 in 35ca2e9
return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have updated #1463
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to update to consider for comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @kshitij12345 , I've removed the cherry-picked commits and the modification in is_constant_for_vjp
, I think after #1548 (comment) is fixed, the test case test_splitter_autograd_function
in this PR should pass
Before submitting
What does this PR do?
Fixes #1539.
Needs #1463 to pass no_grad regions to thunder in thunderFX; #1568
As the analysis in #1539 (comment), this PR try to fix it by adding a dead code elimination and processing the get_attr node