From cc6b232fc7ecaacfeca35853ffc6ab2af00b9e4b Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 27 Jun 2024 12:56:26 +0900
Subject: [PATCH 1/3] register `torch._C._functorch.unwrap_if_dead` as identity

Signed-off-by: Masaki Kozuki
---
 thunder/torch/__init__.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 04ccd74628..437bfe314e 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -4936,6 +4936,14 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
 
 
+def _unwrap_if_dead(tensor):
+    warnings.warn("torch._C._functorch.unwrap_if_dead has no effect under thunder.jit")
+    return tensor
+
+
+register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
+
+
 #
 # Distributed operations
 #

From dc18d002c9264fe73821fc5fab90c374909f31e3 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 27 Jun 2024 15:05:27 +0900
Subject: [PATCH 2/3] test

Signed-off-by: Masaki Kozuki
---
 thunder/tests/test_core.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index fd751939b8..09c5e81baa 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2819,3 +2819,35 @@ def foo(x):
     expected = foo(TestDataclass(t, s))
 
     torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # this is wrong on purpose; the correct backward would return 2.0 * grad_output.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model, skip_inplace_functionalization=True)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))

From 63f8dc07f0305c0bbda75ace311ed135234ffb5d Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 27 Jun 2024 17:00:23 +0900
Subject: [PATCH 3/3] updated warn

Signed-off-by: Masaki Kozuki
---
 thunder/torch/__init__.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 437bfe314e..613443d942 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -17,6 +17,7 @@
 
 # Initializes the language context
 from thunder.torch.langctx import register_method, register_property
+from thunder.core.baseutils import run_once
 import thunder.clang as clang
 import thunder.core.devices as devices
 from thunder.core.devices import to_device, device_from_string
@@ -4936,8 +4937,18 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
 
 
+@run_once
+def _warn_for_unwrap_if_dead():
+    warnings.warn(
+        "torch._C._functorch.unwrap_if_dead has no effect under thunder.jit. "
+        "`torch.autograd.Function.backward` is not respected by `thunder.jit`. "
+        "`thunder.jit` generates the backward based on the forward."
+    )
+
+
 def _unwrap_if_dead(tensor):
-    warnings.warn("torch._C._functorch.unwrap_if_dead has no effect under thunder.jit")
+
+    _warn_for_unwrap_if_dead()
     return tensor