File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -2819,3 +2819,35 @@ def foo(x):
2819
2819
expected = foo (TestDataclass (t , s ))
2820
2820
2821
2821
torch .testing .assert_close (actual , expected )
2822
+
2823
+
2824
+ @pytest .mark .filterwarnings ("ignore:Please use `torch.vmap`" )
2825
+ def test_custom_autograd_function ():
2826
+ from torch .autograd .gradcheck import GradcheckError
2827
+ from torch .testing ._internal .common_utils import gradcheck
2828
+
2829
+ class MyFunction (torch .autograd .Function ):
2830
+
2831
+ @staticmethod
2832
+ def forward (ctx , x : torch .Tensor ) -> torch .Tensor :
2833
+ return x * 2.0
2834
+
2835
+ # this is wrong on purpose.
2836
+ @staticmethod
2837
+ def backward (ctx , grad_output ) -> torch .Tensor :
2838
+ return grad_output
2839
+
2840
+ class Model (torch .nn .Module ):
2841
+ def __init__ (self ):
2842
+ super ().__init__ ()
2843
+
2844
+ def forward (self , x ) -> torch .Tensor :
2845
+ return MyFunction .apply (x )
2846
+
2847
+ x = torch .randn ((2 , 2 ), dtype = torch .float64 , requires_grad = True )
2848
+ model = Model ().to (dtype = torch .float64 )
2849
+ jitted = thunder .jit (model , skip_inplace_functionalization = True )
2850
+
2851
+ gradcheck (jitted , (x ,))
2852
+ with pytest .raises (GradcheckError ):
2853
+ gradcheck (model , (x ,))
You can’t perform that action at this time.
0 commit comments