@@ -8647,64 +8647,15 @@ def global_context_capture_fn(frame_summary):
86478647 self .assertEqual (seen_frames [1 ].name , "uwu_inline_me" )
86488648 self .assertEqual (seen_frames [2 ].line , "r2 = uwu_inline_me_deep(y, z)" )
86498649
8650- def test_recompile_on_disable_1 (self ):
8651- # fix https://github.com/pytorch/pytorch/issues/157399
8650+ def test_error_on_recompile (self ):
86528651 @torch .compile (backend = "eager" )
8653- def fn (x ):
8654- @torch ._dynamo .disable
8655- def inner (x ):
8656- return x + 10
8657-
8658- return inner (x ) + 1
8659-
8660- with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
8661- try :
8662- for i in range (5 ):
8663- fn (torch .rand (2 , 3 ))
8664- except torch ._dynamo .exc .RecompileError as e :
8665- self .fail ("RecompileError raised unexpectedly: " + str (e ))
8666-
8667- def test_recompile_on_disable_2 (self ):
8668- def outer (x , cond ):
8669- @torch ._dynamo .disable ()
8670- def fn0 (y ):
8671- return y + 1
8672-
8673- @torch ._dynamo .disable ()
8674- def fn1 (y ):
8675- return y + 2
8676-
8677- if cond :
8678- f = fn0
8679- else :
8680- f = fn1
8681-
8682- torch ._dynamo .graph_break ()
8683- # there will be a resume function here
8684- return f (x )
8685-
8686- with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
8687- with self .assertRaises (torch ._dynamo .exc .RecompileError ):
8688- x = torch .rand (2 , 3 )
8689- self .assertEqual (outer (x , True ), torch .compile (outer )(x , True ))
8690- self .assertEqual (outer (x , False ), torch .compile (outer )(x , False ))
8691-
8692- def test_create_nested_fn_cache_clear (self ):
8693- def outer (x ):
8694- @torch ._dynamo .disable ()
8695- def f (y ):
8696- return y + 2
8697-
8698- return f (x ) + 1
8652+ def fn (a , b ):
8653+ return a + b
86998654
8700- outer = torch .compile (outer )
87018655 with unittest .mock .patch ("torch._dynamo.config.error_on_recompile" , True ):
87028656 with self .assertRaises (torch ._dynamo .exc .RecompileError ):
8703- outer (torch .randn (3 , 3 ))
8704- from torch ._dynamo .utils import create_nested_fn_cache
8705-
8706- create_nested_fn_cache .clear ()
8707- outer (torch .randn (3 , 3 ))
8657+ fn (torch .rand (2 , 3 ), torch .rand (2 , 3 ))
8658+ fn (torch .rand (2 , 3 ), (1 , 2 , 3 ))
87088659
87098660 def test_guards_strip_function_call (self ):
87108661 from torch ._dynamo .guards import strip_function_call
0 commit comments