diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index fd751939b8..09c5e81baa 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2819,3 +2819,35 @@ def foo(x):
     expected = foo(TestDataclass(t, s))
     torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # this is wrong on purpose.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model, skip_inplace_functionalization=True)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 04ccd74628..613443d942 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -17,6 +17,7 @@
 # Initializes the language context
 from thunder.torch.langctx import register_method, register_property
+from thunder.core.baseutils import run_once
 import thunder.clang as clang
 import thunder.core.devices as devices
 from thunder.core.devices import to_device, device_from_string
@@ -4936,6 +4937,24 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
+
+
+@run_once
+def _warn_for_unwrap_if_dead():
+    warnings.warn(
+        "torch._C._functorch.unwrap_if_dead has no effect under thunder.jit. "
+        "`torch.autograd.Function.backward` is not respected by `thunder.jit`. "
+        "`thunder.jit` generates backward based on the forward"
+    )
+
+
+def _unwrap_if_dead(tensor):
+
+    _warn_for_unwrap_if_dead()
+    return tensor
+
+
+register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
+
+
 #
 # Distributed operations
 #
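
For context on what the new test asserts: eager PyTorch runs `MyFunction.backward` exactly as written, so the deliberately wrong backward (missing the factor of 2) fails `gradcheck`, whereas `thunder.jit` traces the forward and generates its own backward, so the jitted module passes. A minimal standalone sketch of the eager failure mode, outside the test suite:

```python
import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately wrong: the true gradient of x -> 2*x is
        # 2 * grad_output, not grad_output.
        return grad_output

x = torch.ones(2, 2, dtype=torch.float64, requires_grad=True)
MyFunction.apply(x).sum().backward()
print(x.grad)  # all ones: eager honors the wrong backward, so gradcheck fails
```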
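
The `run_once` helper imported from `thunder.core.baseutils` ensures `_warn_for_unwrap_if_dead` emits its warning at most once per process rather than on every `unwrap_if_dead` call. Its implementation isn't shown in this patch; a minimal sketch of such a decorator (an assumption, not the actual thunder source) might look like:

```python
import functools

def run_once(fn):
    # Hypothetical sketch; the real thunder.core.baseutils.run_once may differ.
    # The wrapped function executes on its first call and is a no-op afterwards.
    ran = False

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal ran
        if not ran:
            ran = True
            return fn(*args, **kwargs)

    return wrapper
```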