2 files changed: +51 −0

First file:

@@ -2819,3 +2819,35 @@ def foo(x):
     expected = foo(TestDataclass(t, s))
 
     torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.filterwarnings("ignore:Please use `torch.vmap`")
+def test_custom_autograd_function():
+    from torch.autograd.gradcheck import GradcheckError
+    from torch.testing._internal.common_utils import gradcheck
+
+    class MyFunction(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+            return x * 2.0
+
+        # This backward is wrong on purpose; the correct gradient is grad_output * 2.0.
+        @staticmethod
+        def backward(ctx, grad_output) -> torch.Tensor:
+            return grad_output
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x) -> torch.Tensor:
+            return MyFunction.apply(x)
+
+    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
+    model = Model().to(dtype=torch.float64)
+    jitted = thunder.jit(model, skip_inplace_functionalization=True)
+
+    gradcheck(jitted, (x,))
+    with pytest.raises(GradcheckError):
+        gradcheck(model, (x,))
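
For context, not part of the diff: torch.autograd.gradcheck compares the analytical gradient computed through backward against a central finite-difference estimate of the forward. The eager model therefore fails (its backward reports a slope of 1 for f(x) = 2x), while the thunder-jitted model passes because thunder derives the backward from the forward. A minimal eager-mode sketch of the mismatch:

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # wrong on purpose; should be grad_output * 2.0

x = torch.ones((), dtype=torch.float64, requires_grad=True)
MyFunction.apply(x).backward()
print(x.grad)  # tensor(1.): the analytical gradient from the buggy backward

eps = 1e-6
with torch.no_grad():
    slope = ((x + eps) * 2.0 - (x - eps) * 2.0) / (2 * eps)
print(slope)  # tensor(2.): the finite-difference slope gradcheck compares against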
Second file:

@@ -17,6 +17,7 @@
 # Initializes the language context
 from thunder.torch.langctx import register_method, register_property
 
+from thunder.core.baseutils import run_once
 import thunder.clang as clang
 import thunder.core.devices as devices
 from thunder.core.devices import to_device, device_from_string
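
The hunk imports run_once from thunder.core.baseutils without showing its definition. As an assumption about that helper, not something visible in this diff, a run-once decorator conventionally swallows every call after the first, which is what makes the warning below fire at most once per process. A sketch of the usual pattern:

import functools

def run_once(fn):
    # Sketch of a conventional run-once decorator; thunder's actual
    # implementation in thunder.core.baseutils may differ.
    has_run = False

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(*args, **kwargs)

    return wrapper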
@@ -4936,6 +4937,24 @@ def _set_grad_enabled_with_warning(enabled: bool) -> None:
 register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
 
 
+@run_once
+def _warn_for_unwrap_if_dead():
+    warnings.warn(
+        "torch._C._functorch.unwrap_if_dead has no effect under thunder.jit. "
+        "`torch.autograd.Function.backward` is not respected by `thunder.jit`; "
+        "`thunder.jit` generates the backward based on the forward."
+    )
+
+
+def _unwrap_if_dead(tensor):
+    # Warn once, then hand the tensor back unchanged.
+    _warn_for_unwrap_if_dead()
+    return tensor
+
+
+register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead)
+
+
 #
 # Distributed operations
 #
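
For orientation: register_function, used twice in this hunk, maps a torch-level callable to a thunder-side replacement, presumably consulted when thunder.jit traces torch code. An illustrative toy version of the pattern, not thunder's actual implementation:

_torch_to_thunder = {}

def register_function(torch_fn, thunder_fn):
    # Record the substitution for the tracer to pick up.
    _torch_to_thunder[torch_fn] = thunder_fn

def lookup(torch_fn):
    # Fall back to the original callable when nothing is registered.
    return _torch_to_thunder.get(torch_fn, torch_fn)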