Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register torch._C._functorch.unwrap_if_dead as identity #661

Merged
merged 3 commits into from
Jun 27, 2024

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Jun 27, 2024

What does this PR do?

Fixes #660.
The function is called inside of custom torch.autograd.Function. If one uses only torch functions in their custom Function, then it should be supported.

@t-vi
Copy link
Collaborator

t-vi commented Jun 27, 2024

So I'm not opposed, but as far as I can tell, #660 is about the TensorParallelism implementation in NeMo . So I wonder if we would take on this problem a bit further up the stack (e.g. at the level of the custom autograd.Function, Module, ...)?

@crcrpar crcrpar marked this pull request as ready for review June 27, 2024 06:07
crcrpar added 2 commits June 27, 2024 15:07
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar force-pushed the crpa/unwrap_if_dead-for-custom-autograd-function branch from 94096d0 to dc18d00 Compare June 27, 2024 06:07
thunder/torch/__init__.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @crcrpar . Great analysis that we want to not fall over in autograd.Function.
To my mind we might want to warn in autograd.Function.apply, but we can do that elsewhere (probably wants to be in jit_ext to have lookaside with warning and then trace through the original apply), but let's have this for now.

@@ -4936,6 +4936,14 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)


def _unwrap_if_dead(tensor):
warnings.warn("torch._C._functorch.unwrap_if_dead have no effect under thunder.jit")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably better as warn_once, but hey.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing as always. Thank you @crcrpar !

@t-vi t-vi merged commit 8a5016b into main Jun 27, 2024
27 of 31 checks passed
@t-vi t-vi deleted the crpa/unwrap_if_dead-for-custom-autograd-function branch June 27, 2024 10:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TypeError: unwrap_if_dead() incompatible args in NeVA model
2 participants