Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register torch._C._functorch.unwrap_if_dead as identity #661

Merged
merged 3 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2819,3 +2819,35 @@ def foo(x):
expected = foo(TestDataclass(t, s))

torch.testing.assert_close(actual, expected)


@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
def test_custom_autograd_function():
from torch.autograd.gradcheck import GradcheckError
from torch.testing._internal.common_utils import gradcheck

class MyFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x * 2.0

# this is wrong on purpose.
@staticmethod
def backward(ctx, grad_output) -> torch.Tensor:
return grad_output

class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x) -> torch.Tensor:
return MyFunction.apply(x)

x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
model = Model().to(dtype=torch.float64)
jitted = thunder.jit(model, skip_inplace_functionalization=True)

gradcheck(jitted, (x,))
with pytest.raises(GradcheckError):
gradcheck(model, (x,))
8 changes: 8 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4936,6 +4936,14 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)


def _unwrap_if_dead(tensor):
warnings.warn("torch._C._functorch.unwrap_if_dead have no effect under thunder.jit")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably better as warn_once, but hey.

t-vi marked this conversation as resolved.
Show resolved Hide resolved
return tensor


register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)


#
# Distributed operations
#
Expand Down
Loading