@@ -5609,11 +5609,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
56095609 dim0_x = torch.export.Dim("dim0_x", min=3)
56105610 dim1_x = torch.export.Dim("dim1_x", max=8000)
56115611 dynamic_shapes = {"x": (dim0_x, dim1_x)}
5612- em = torch.export.export(
5612+ em = torch.export._trace._export(
56135613 m,
56145614 (a,),
56155615 dynamic_shapes=dynamic_shapes,
5616- prefer_deferred_runtime_asserts_over_guards=True,
5616+ allow_complex_guards_as_runtime_asserts=True,
56175617 )
56185618 em.module()(torch.randn(4, 3))
56195619 with self.assertRaisesRegex(
@@ -13497,7 +13497,7 @@ def forward(self, x):
1349713497
1349813498 def test_disable_forced_specializations_ok(self):
1349913499 # check that we don't force specialization, and defer to runtime asserts
13500- # with prefer_deferred_runtime_asserts_over_guards=True to successfully export
13500+ # with allow_complex_guards_as_runtime_asserts=True to successfully export
1350113501 # case 1: modulo guards
1350213502 from torch.export import dims
1350313503
@@ -13507,11 +13507,11 @@ def forward(self, x):
1350713507
1350813508 inputs = (torch.randn(10, 72),)
1350913509 dx, dy = dims("dx", "dy")
13510- ep = torch.export.export(
13510+ ep = torch.export._trace._export(
1351113511 Mod4Reshape(),
1351213512 inputs,
1351313513 dynamic_shapes={"x": (dx, dy)},
13514- prefer_deferred_runtime_asserts_over_guards=True,
13514+ allow_complex_guards_as_runtime_asserts=True,
1351513515 )
1351613516 out1 = ep.module()(torch.randn(8, 7))
1351713517 self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@@ -13541,11 +13541,11 @@ def forward(self, x, y, z):
1354113541
1354213542 for private_api in (True, False):
1354313543 if private_api:
13544- ep = torch.export.export(
13544+ ep = torch.export._trace._export(
1354513545 FreeReshape(),
1354613546 inputs,
1354713547 dynamic_shapes=dynamic_shapes,
13548- prefer_deferred_runtime_asserts_over_guards=True,
13548+ allow_complex_guards_as_runtime_asserts=True,
1354913549 )
1355013550 else:
1355113551 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13582,11 +13582,11 @@ def forward(self, x, y):
1358213582 "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
1358313583 "y": (Dim("dy", min=8),),
1358413584 }
13585- ep = torch.export.export(
13585+ ep = torch.export._trace._export(
1358613586 Reshape3d(),
1358713587 inputs,
1358813588 dynamic_shapes=dynamic_shapes,
13589- prefer_deferred_runtime_asserts_over_guards=True,
13589+ allow_complex_guards_as_runtime_asserts=True,
1359013590 )
1359113591 out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
1359213592 self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13708,11 +13708,11 @@ def forward(self, x):
1370813708 model = Model()
1370913709 x = torch.rand(1024, 20, 16)
1371013710 dynamic_shapes = {"x": {0: Dim("batch")}}
13711- ep = torch.export.export(
13711+ ep = torch.export._trace._export(
1371213712 model,
1371313713 (x,),
1371413714 dynamic_shapes=dynamic_shapes,
13715- prefer_deferred_runtime_asserts_over_guards=True,
13715+ allow_complex_guards_as_runtime_asserts=True,
1371613716 )
1371713717 with self.assertRaisesRegex(
1371813718 RuntimeError,
@@ -13785,11 +13785,11 @@ def forward(self, x, y):
1378513785
1378613786 inputs = (torch.randn(6), torch.randn(12))
1378713787 dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
13788- ep = torch.export.export(
13788+ ep = torch.export._trace._export(
1378913789 Foo(),
1379013790 inputs,
1379113791 dynamic_shapes=dynamic_shapes,
13792- prefer_deferred_runtime_asserts_over_guards=True,
13792+ allow_complex_guards_as_runtime_asserts=True,
1379313793 )
1379413794 # check forward pass
1379513795 out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13824,7 +13824,7 @@ def forward(self, x, y):
1382413824 Foo(),
1382513825 inputs,
1382613826 dynamic_shapes=dynamic_shapes,
13827- prefer_deferred_runtime_asserts_over_guards=True,
13827+ allow_complex_guards_as_runtime_asserts=True,
1382813828 ).run_decompositions()
1382913829
1383013830 self.assertEqual(
@@ -14236,11 +14236,11 @@ def forward(self, x, y):
1423614236
1423714237 inputs = (torch.randn(5), torch.randn(3))
1423814238 shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
14239- ep = torch.export.export(
14239+ ep = torch.export._trace._export(
1424014240 Foo(),
1424114241 inputs,
1424214242 dynamic_shapes=shapes,
14243- prefer_deferred_runtime_asserts_over_guards=True,
14243+ allow_complex_guards_as_runtime_asserts=True,
1424414244 )
1424514245 # count 2 pow nodes, 2 sym_size.int nodes
1424614246 self.assertEqual(
@@ -15039,11 +15039,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1503915039
1504015040 for private_api in (True, False):
1504115041 if private_api:
15042- ep = torch.export.export(
15042+ ep = torch.export._trace._export(
1504315043 ModConstraint(),
1504415044 (torch.randn(3, 4),),
1504515045 dynamic_shapes={"x": (dynamic, dynamic)},
15046- prefer_deferred_runtime_asserts_over_guards=True,
15046+ allow_complex_guards_as_runtime_asserts=True,
1504715047 )
1504815048 else:
1504915049 ep = export(
@@ -15057,7 +15057,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
1505715057 for node in ep.graph.nodes
1505815058 ].count(True)
1505915059 if private_api:
15060- self.assertEqual(num_asserts, 6)
15060+ self.assertEqual(num_asserts, 7)
1506115061 with self.assertRaisesRegex(
1506215062 RuntimeError,
1506315063 r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
0 commit comments