Skip to content

Commit dc18d00

Browse files
committed
test
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent cc6b232 commit dc18d00

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

thunder/tests/test_core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2819,3 +2819,35 @@ def foo(x):
28192819
expected = foo(TestDataclass(t, s))
28202820

28212821
torch.testing.assert_close(actual, expected)
2822+
2823+
2824+
@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
2825+
def test_custom_autograd_function():
2826+
from torch.autograd.gradcheck import GradcheckError
2827+
from torch.testing._internal.common_utils import gradcheck
2828+
2829+
class MyFunction(torch.autograd.Function):
2830+
2831+
@staticmethod
2832+
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
2833+
return x * 2.0
2834+
2835+
# this is wrong on purpose.
2836+
@staticmethod
2837+
def backward(ctx, grad_output) -> torch.Tensor:
2838+
return grad_output
2839+
2840+
class Model(torch.nn.Module):
2841+
def __init__(self):
2842+
super().__init__()
2843+
2844+
def forward(self, x) -> torch.Tensor:
2845+
return MyFunction.apply(x)
2846+
2847+
x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
2848+
model = Model().to(dtype=torch.float64)
2849+
jitted = thunder.jit(model, skip_inplace_functionalization=True)
2850+
2851+
gradcheck(jitted, (x,))
2852+
with pytest.raises(GradcheckError):
2853+
gradcheck(model, (x,))

0 commit comments

Comments
 (0)