diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 1e4bc06634..12f9b85e1d 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -921,13 +921,9 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
 
     custom_autograd_function_cls = unwrap(obj)
     custom_forward = custom_autograd_function_cls.forward
-    args_, kwargs_ = tree_map(unwrap, (args, kwargs))
     ctx = torch.autograd.function.FunctionCtx()
-
-    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(custom_forward).provenance])
-    wrapped_ctx = wrap(ctx, provenance=pr)
-    args_, kwargs_ = tree_map(lambda a: wrap(a, provenance=pr), (args_, kwargs_))
-    return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)
+    wrapped_ctx = wrap_const(ctx)
+    return _interpret_call(custom_forward, wrapped_ctx, *args, **kwargs)
 
 
 @register_general_jit_lookaside(torch.finfo)
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 426a5fdad6..094e17b0e5 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2850,64 +2850,6 @@ def foo(x):
     torch.testing.assert_close(actual, expected)
 
 
-@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
-def test_custom_autograd_function():
-    from torch.autograd.gradcheck import GradcheckError
-    from torch.testing._internal.common_utils import gradcheck
-
-    class MyFunction(torch.autograd.Function):
-
-        @staticmethod
-        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-            return x * 2.0
-
-        # this is wrong on purpose.
-        @staticmethod
-        def backward(ctx, grad_output) -> torch.Tensor:
-            return grad_output
-
-    class Model(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
-        def forward(self, x) -> torch.Tensor:
-            return MyFunction.apply(x)
-
-    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
-    model = Model().to(dtype=torch.float64)
-    jitted = thunder.jit(model)
-
-    gradcheck(jitted, (x,))
-    with pytest.raises(GradcheckError):
-        gradcheck(model, (x,))
-
-    class MyLinear(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
-            ctx.save_for_backward(x)
-            ctx.pretty_attr = 100
-            return torch.matmul(x, weight.t())
-
-        @staticmethod
-        def backward(ctx, grad_output):
-            (x,) = ctx.saved_tensors
-            return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x)
-
-    class Model(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.l1 = torch.nn.Linear(2, 2, bias=False)
-
-        def forward(self, x):
-            return MyLinear.apply(x, self.l1.weight)
-
-    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
-    model = Model().to(dtype=torch.float64)
-    jitted = thunder.jit(model)
-
-    gradcheck(jitted, (x,))
-
-
 def test_proxy_repr():
     # Verify that we can call `__repr__` on different proxy subclasses.
     t = thunder.core.trace.TraceCtx()
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 6e01a37035..f722c1ff5f 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1135,3 +1135,63 @@ def foo(t, batch_size):
     assert_close(expected, actual)
     assert thunder.cache_misses(jfoo) == 1
     assert thunder.cache_hits(jfoo) == 1
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # this is wrong on purpose.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))
+
+    class MyLinear(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x: torch.Tensor, weight: torch.Tensor, shape: tuple) -> torch.Tensor:
+            ctx.shape = shape
+            ctx.save_for_backward(x, weight)
+            ctx.pretty_attr = 100
+            return torch.matmul(x, weight.t())
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (x, weight) = ctx.saved_tensors
+            assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
+            return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x)
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.l1 = torch.nn.Linear(2, 2, bias=False)
+
+        def forward(self, x):
+            return MyLinear.apply(x, self.l1.weight, self.l1.weight.shape)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model)
+
+    gradcheck(jitted, (x,))