register torch._C._functorch.unwrap_if_dead as identity #661
Conversation
So I'm not opposed, but as far as I can tell, #660 is about the TensorParallelism implementation in NeMo. So I wonder whether we should take on this problem a bit further up the stack (e.g. at the level of the custom autograd.Function, Module, ...)?
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
94096d0 to dc18d00
Thank you @crcrpar. Great analysis; we do want to not fall over in autograd.Function.
To my mind we might want to warn in autograd.Function.apply, but we can do that elsewhere (it probably wants to live in jit_ext as a lookaside that warns and then traces through the original apply); let's have this for now.
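As a rough sketch of that idea (not the actual jit_ext registration API; the hook point and the signature below are assumptions), such a lookaside would warn and then defer to the original apply:

```python
import warnings

# Hypothetical sketch of the lookaside described above: warn that a custom
# torch.autograd.Function is being traced, then fall through to the original
# apply so tracing proceeds as usual. How this gets registered with
# thunder.core.jit_ext is deliberately left out; the signature is assumed.
def autograd_function_apply_lookaside(original_apply, *args, **kwargs):
    warnings.warn(
        "thunder.jit is tracing through a custom torch.autograd.Function; "
        "only torch operations inside it are expected to work."
    )
    # Trace through the original apply unchanged.
    return original_apply(*args, **kwargs)
```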
thunder/torch/__init__.py (outdated)
```diff
@@ -4936,6 +4936,14 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
 
 
+def _unwrap_if_dead(tensor):
+    warnings.warn("torch._C._functorch.unwrap_if_dead have no effect under thunder.jit")
```
probably better as warn_once, but hey.
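For reference, a warn-once variant of the identity registration might look like the sketch below. This is an assumption about the shape of the change (the full diff is truncated above); it assumes the snippet lives in thunder/torch/__init__.py, where the register_function helper seen in the diff is in scope, and uses a plain module-level guard as one way to warn only once.

```python
import warnings

import torch

_unwrap_if_dead_warned = False


def _unwrap_if_dead(tensor):
    # Warn only on the first call; the op itself is treated as identity,
    # so the tensor is returned unchanged.
    global _unwrap_if_dead_warned
    if not _unwrap_if_dead_warned:
        warnings.warn("torch._C._functorch.unwrap_if_dead has no effect under thunder.jit")
        _unwrap_if_dead_warned = True
    return tensor


# Mirrors the registration pattern visible in the diff above; register_function
# is assumed to be the same helper defined in thunder/torch/__init__.py.
register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
```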
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Amazing as always. Thank you @crcrpar!
What does this PR do?
Fixes #660.
The function is called inside of a custom torch.autograd.Function. If one uses only torch functions in their custom Function, then it should be supported.
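For illustration only (this example is not from the PR; the Scale class and f are made up), a custom Function of that kind could look like:

```python
import torch
import thunder


class Scale(torch.autograd.Function):
    # forward/backward use only torch operations, the case the PR
    # description says should be supported under thunder.jit.
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return alpha * x

    @staticmethod
    def backward(ctx, grad_out):
        return ctx.alpha * grad_out, None


def f(x):
    return Scale.apply(x, 2.0)


jf = thunder.jit(f)
x = torch.randn(4, requires_grad=True)
jf(x).sum().backward()
```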