ThunderFX: handles the callable input of fx.Node #1548
Conversation
Thank you for the fix, @kiya00. The fix itself looks good; however, the test failures look real.
Also, once everything works fine, I'd expect the following test to fail, as it shouldn't find any module which was passed to inductor (so it will need an update):
lightning-thunder/thunder/tests/test_dynamo.py
Lines 268 to 270 in 673bdb9
def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
    x = torch.ones(2, device=device, dtype=dtype, requires_grad=True)
This line should raise an assertion error:
assert any(target.startswith("inductor_") for target in targets)
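A hedged sketch of how that check might be updated after this fix (the "thunder_" prefix and the assumption that targets is collected exactly as in the snippet above are guesses, not this PR's actual change):

# Hedged sketch: once the autograd.Function graph is claimed by Thunder, no split
# submodule should go to inductor, so the existing assertion could be inverted.
assert not any(target.startswith("inductor_") for target in targets)
# The "thunder_" prefix mirrors the "inductor_" naming above (assumption).
assert any(target.startswith("thunder_") for target in targets)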
Hi @kshitij12345, I ran into two problems, do you maybe have any suggestions?
Or do we set a specific pass to delete the
Here is the FX graph structure of the
cc: @IvanYashchuk
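For context, a minimal repro sketch (assumed setup, not the test from this PR) of how such a graph arises: Dynamo traces a torch.autograd.Function into the higher-order op torch.ops.higher_order.autograd_function_apply, and its fwd_body_0 / bwd_body_0 bodies reach the backend as get_attr inputs of that fx.Node (the exact graph shape depends on the PyTorch version).

import torch

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        return grad * torch.cos(x)

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # The printed graph contains a call_function node for
    # torch.ops.higher_order.autograd_function_apply whose first two inputs are
    # get_attr nodes pointing at the fwd_body_0 / bwd_body_0 submodules.
    gm.graph.print_tabular()
    return gm

fn = torch.compile(lambda x: Sin.apply(x), backend=inspect_backend)
fn(torch.randn(3, requires_grad=True))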
Regarding 1., I am not very familiar with
For 2., #1463 should take care of this, as regions within
By cherry-picking #1463, the
the
the augmented fwd trace is:
Note that the
lightning-thunder/thunder/core/transforms.py
Lines 2744 to 2746 in c6294ac
I think the correct behavior should be to use the grad transformation:
lightning-thunder/thunder/core/jit_ext.py
Line 753 in c6294ac
torch._C._set_grad_enabled in the fwd_body_0 module? It seems that if I let is_constant_for_vjp(symbol of higher_order_autograd_function_apply) return False even if the output of higher_order_autograd_function_apply has the tag ProxyTag.DETACHED_AUTOGRAD_GRAPH, the backward trace is as expected.
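For reference, the kind of pass being asked about could look like the following minimal sketch (whether deleting these calls is the right fix is exactly the open question here; drop_set_grad_enabled is a made-up helper name, not part of the PR):

import torch

def drop_set_grad_enabled(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # torch._C._set_grad_enabled calls return None and normally have no users,
    # so they can be erased directly from a submodule such as fwd_body_0.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch._C._set_grad_enabled and not node.users:
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm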
Thanks for the explanation @kiya00. I was wondering, since the line below goes through autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);
Yes, it happens automatically; the bwd of symbol
lightning-thunder/thunder/core/transforms.py
Lines 2744 to 2764 in 2d0199e
So I tried to let is_constant_for_vjp always return False for autograd_function_apply, and then the fwd trace is (this corresponds to the test case in the PR):
the bwd trace:
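For reference, a minimal sketch of the idea described above, assuming thunder's BoundSymbol/Proxy attribute names (bsym.sym.id, flat_outs, tags) and the ProxyTag import path; this is not the actual code in thunder/core/transforms.py:

from thunder.core.proxies import ProxyTag  # assumed import path

def is_constant_for_vjp(bsym) -> bool:
    # autograd_function_apply bundles its own bwd_body_0, so it must stay in the
    # backward trace even when its outputs are tagged DETACHED_AUTOGRAD_GRAPH.
    if bsym.sym.id == "higher_order_autograd_function_apply":
        return False
    # Otherwise, treat the symbol as constant for the VJP when every output comes
    # from a detached autograd graph (sketch of the tag-based check discussed above).
    return all(ProxyTag.DETACHED_AUTOGRAD_GRAPH in getattr(out, "tags", ()) for out in bsym.flat_outs)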
Hi @kshitij12345 @IvanYashchuk, it's ready for review.
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: Kshiteej K <kshitijkalambarkar@gmail.com>
Hi @IvanYashchuk, do you want to take another look before merging?
Let's merge. Thank you all!
@t-vi, could you please approve this one?
Thank you @kiya00 @kshitij12345 @crcrpar @IvanYashchuk
Before submitting
What does this PR do?
Fixes #1539.
Needs #1463 and #1620 to pass no_grad regions to Thunder in ThunderFX; also #1568.
As per the analysis in #1539 (comment), this PR tries to fix it by removing the unused torch.autograd.function.FunctionCtx() and processing the get_attr node.
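A minimal sketch of what processing the get_attr node can look like in an FX backend (illustrative only; resolve_get_attr_args is a made-up helper, not the PR's code):

import operator
import torch

def resolve_get_attr_args(gm: torch.fx.GraphModule, node: torch.fx.Node):
    # Replace get_attr argument nodes (e.g. the fwd_body_0 / bwd_body_0 bodies of
    # autograd_function_apply) with the actual attributes they refer to on the
    # owning GraphModule, so the backend can treat them like ordinary callables.
    resolved = []
    for arg in node.args:
        if isinstance(arg, torch.fx.Node) and arg.op == "get_attr":
            resolved.append(operator.attrgetter(arg.target)(gm))  # handles dotted paths too
        else:
            resolved.append(arg)
    return tuple(resolved)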