Lookaside for torch.ops.higher_order.autograd_function_apply #1256
Conversation
I want the lookasides' scope to be limited to the preprocessing of PyTorch code. If the removed code is reused in the updated lookaside, we'll achieve that.
thunder/tests/test_jit_general.py (Outdated)
@@ -1231,6 +1233,9 @@ def my_sin(x):
torch.testing.assert_close(y, y_ref)

initial_computation_trace = thunder.last_traces(jitted)[0]
bsym_str_ids = tuple(
initial_computation_trace is not a valid Python function:
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[2, 2]"
# /home/tv/data/firma/grid/thunder/lightning-thunder/thunder/tests/test_jit_general.py:1216: return grad_output * torch.cos(x)
t6 = ltorch.autograd_function_apply(_function_0, _function_1, x, args_tensor_mask=[True], non_differentiable_idx=[]) # t6: "cpu f32[2, 2]"
# t6 = ltorch.sin(x) # t6: "cpu f32[2, 2]"
# t6 = prims.sin(x) # t6: "cpu f32[2, 2]"
return t6
Is it a contract that any of the traces generated from a callable is executable?
What part of the provided code snippet makes it an invalid Python function?
Is it a contract that any of the traces generated from a callable is executable?
This has been the case up to now, and it is why I have repeatedly said that I want things properly inlined.
Could you please elaborate on what part of the example trace makes it not properly executable?
A valid Python program is generated from a trace using its string representation (trace.python()) and its "context" (trace.python_ctx). The context is passed as the globals= argument to the built-in exec function (https://docs.python.org/3/library/functions.html#exec).
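For illustration, a minimal sketch of that round trip, assuming the TraceCtx surface described above (trace.python() returning the source string and trace.python_ctx() returning a dict of referenced names); the wrapper function fn and the rest of the scaffolding are hypothetical:

import torch
import thunder

def fn(x):
    return torch.sin(x)

jitted = thunder.jit(fn)
x = torch.randn(2, 2)
jitted(x)  # run once so the traces are populated

trace = thunder.last_traces(jitted)[0]
source = trace.python()    # the trace's source code as a string
ctx = trace.python_ctx()   # assumption: a dict mapping names used by the source to objects

namespace = dict(ctx)
exec(compile(source, "<trace>", "exec"), namespace)  # the context plays the role of globals=
computation = namespace["computation"]
torch.testing.assert_close(computation(x), torch.sin(x))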
I pasted the two traces of initial_computation_trace = thunder.last_traces(jitted)[0] below.
The top is this PR, the bottom, main.
I'm not seeing the difference.
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[2, 2]"
# /home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/tests/test_jit_general.py:1218: return grad_output * x.cos()
t6 = ltorch.autograd_function_apply(_function_0, _function_1, x, args_tensor_mask=[True], non_differentiable_idx=[]) # t6: "cpu f32[2, 2]"
# t6 = ltorch.sin(x) # t6: "cpu f32[2, 2]"
# t6 = prims.sin(x) # t6: "cpu f32[2, 2]"
return t6
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[2, 2]"
# /home/mkozuki/ghq/github.com/crcrpar/lightning-thunder/thunder/tests/test_jit_general.py:1217: return torch.ops.higher_order.autograd_function_apply(
t0 = ltorch.autograd_function_apply(_function_0, _function_1, x, args_tensor_mask=[True], non_differentiable_idx=[]) # t0: "cpu f32[2, 2]"
# t0 = ltorch.sin(x) # t0: "cpu f32[2, 2]"
# t0 = prims.sin(x) # t0: "cpu f32[2, 2]"
return t0
Masaki offered #1526 as an alternative without higher-order functions. I merged it to unblock uses of autograd_function_apply.
What does this PR do?
As per #1248, support for torch.ops.higher_order.autograd_function_apply becomes a bit more flexible by tracing into both fwd and bwd.
cc @apaz-cli
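For context, a minimal sketch of the kind of user code this lookaside targets. The fwd/bwd helpers and my_sin are illustrative, modeled on the sin/cos trace shown above and on autograd_function_apply's convention that fwd returns (output, saved_values) and bwd receives the gradients followed by the saved values; this is not the exact test added in the PR:

import torch
import thunder

def fwd(ctx, x):
    # returns (output, saved_values)
    return torch.sin(x), (x,)

def bwd(ctx, grad_output, x):
    return grad_output * torch.cos(x)

def my_sin(x):
    return torch.ops.higher_order.autograd_function_apply(
        fwd, bwd, x, args_tensor_mask=[True], non_differentiable_idx=[]
    )

jitted = thunder.jit(my_sin)
x = torch.randn(2, 2, requires_grad=True)
y = jitted(x)
torch.testing.assert_close(y, torch.sin(x))
# gradients should flow through bwd once the lookaside traces into it
y.backward(torch.ones_like(y))
torch.testing.assert_close(x.grad, torch.cos(x.detach()))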