
Commit 8a5016b

register torch._C._functorch.unwrap_if_dead as identity (#661)
1 parent 390d8e3 commit 8a5016b

File tree

  thunder/tests/test_core.py
  thunder/torch/__init__.py

2 files changed: +51 −0 lines changed

thunder/tests/test_core.py

Lines changed: 32 additions & 0 deletions
@@ -2819,3 +2819,35 @@ def foo(x):
     expected = foo(TestDataclass(t, s))
 
     torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # this is wrong on purpose.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model, skip_inplace_functionalization=True)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))
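
The test checks the behavior this commit enables: thunder.jit traces through MyFunction.apply (which touches torch._C._functorch.unwrap_if_dead internally), derives its own backward from the forward, and therefore ignores the deliberately wrong custom backward, so gradcheck passes for the jitted module but fails in eager mode. Below is a minimal standalone sketch of the same check. It is illustrative, not part of the commit: Double and f are hypothetical names, a plain function is used instead of the test's Module, the skip_inplace_functionalization flag is copied from the test, and the printed gradients assume the behavior the test asserts.

import torch
import thunder


class Double(torch.autograd.Function):
    # Same shape as MyFunction in the test: correct forward, wrong backward.
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # Intentionally wrong: the true gradient is 2 * grad_output.
        return grad_output


def f(x):
    return Double.apply(x)


x = torch.randn(3, dtype=torch.float64, requires_grad=True)

# Eager mode uses the custom (wrong) backward, so x.grad is all ones.
f(x).sum().backward()
print(x.grad)

# Under thunder.jit the backward is generated from the traced forward,
# so x.grad is the analytically correct gradient, all twos.
x.grad = None
jf = thunder.jit(f, skip_inplace_functionalization=True)
jf(x).sum().backward()
print(x.grad)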

thunder/torch/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -17,6 +17,7 @@
 # Initializes the language context
 from thunder.torch.langctx import register_method, register_property
 
+from thunder.core.baseutils import run_once
 import thunder.clang as clang
 import thunder.core.devices as devices
 from thunder.core.devices import to_device, device_from_string
@@ -4936,6 +4937,24 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
 
 
+@run_once
+def _warn_for_unwrap_if_dead():
+    warnings.warn(
+        "torch._C._functorch.unwrap_if_dead has no effect under thunder.jit. "
+        "`torch.autograd.Function.backward` is not respected by `thunder.jit`. "
+        "`thunder.jit` generates backward based on the forward"
+    )
+
+
+def _unwrap_if_dead(tensor):
+
+    _warn_for_unwrap_if_dead()
+    return tensor
+
+
+register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
+
+
 #
 # Distributed operations
 #
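
In other words, register_function maps torch._C._functorch.unwrap_if_dead to an identity that warns once (via run_once) and passes the tensor through unchanged. A rough sketch of how that behaves at trace time, assuming direct calls to the functorch helper are intercepted the same way as the ones made internally by torch.autograd.Function.apply (the function g below is illustrative only):

import warnings

import torch
import thunder


def g(x):
    # Under thunder.jit this call is looked up via register_function and
    # mapped to the identity _unwrap_if_dead registered above.
    return torch._C._functorch.unwrap_if_dead(x) * 3.0


jg = thunder.jit(g)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    jg(torch.ones(2))
    jg(torch.ones(2))

# Expect at most one "has no effect under thunder.jit" warning, since
# _warn_for_unwrap_if_dead is decorated with @run_once and the second call
# reuses the cached trace.
print(sum("unwrap_if_dead" in str(w.message) for w in caught))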
