From 585d7f75035ee982216172ee95deaa70b0fb7b5b Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 23 Jan 2026 15:37:14 +0000
Subject: [PATCH 1/4] Call post_grad passes when compile=False

---
 autoparallel/api.py |  26 ++++++-
 tests/test_api.py   | 165 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 189 insertions(+), 2 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index e71e043..3969919 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -16,7 +16,6 @@
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
-    boxed_nop_preserve_node_meta,
 )
 from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table
@@ -51,6 +50,29 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
 
+def build_compile_fn(fake_mode):
+    from torch._inductor.compile_fx import (
+        _recursive_post_grad_passes,
+        get_cuda_device_context,
+    )
+    from torch._inductor.virtualized import V
+
+    def boxed_nop_preserve_node_meta(fx_g, example_inputs):
+        with V.set_fake_mode(fake_mode):
+            cuda_context = get_cuda_device_context(fx_g)
+            with cuda_context:
+                _recursive_post_grad_passes(fx_g, is_inference=False)
+
+        def run(args):
+            with torch.fx.traceback.preserve_node_meta():
+                return torch.fx.Interpreter(fx_g).boxed_run(args)
+
+        run._boxed_call = True
+        return run
+
+    return boxed_nop_preserve_node_meta
+
+
 def _assign_attr(
     attr: Any,
     target_module: torch.nn.Module,
@@ -314,7 +336,7 @@ def __init__(
                 debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
             )
         else:
-            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
+            self.compiler_fn = build_compile_fn(self.fake_mode)  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
diff --git a/tests/test_api.py b/tests/test_api.py
index 50d8540..5ecbc59 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -677,3 +677,168 @@ def forward(self, x):
     )
 
     assert parallel_model is not None
+
+
+# Tests for build_compile_fn and _recursive_post_grad_passes
+
+
+def test_recursive_post_grad_passes_called_when_compile_false(device_mesh_1d):
+    """Test that _recursive_post_grad_passes is called when compile=False.
+
+    This verifies the fix that ensures post-grad passes (like layout optimization)
+    are run even when not using the full inductor compilation pipeline.
+    """
+    from unittest.mock import patch
+
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    def input_fn():
+        b = 32
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Track if _recursive_post_grad_passes was called
+    post_grad_passes_called = []
+
+    original_recursive_post_grad_passes = None
+    try:
+        from torch._inductor.compile_fx import _recursive_post_grad_passes
+
+        original_recursive_post_grad_passes = _recursive_post_grad_passes
+    except ImportError:
+        pytest.skip("_recursive_post_grad_passes not available in this PyTorch version")
+
+    def tracking_post_grad_passes(fx_g, is_inference=False):
+        post_grad_passes_called.append((fx_g, is_inference))
+        # Call the original function
+        return original_recursive_post_grad_passes(fx_g, is_inference=is_inference)
+
+    with patch(
+        "torch._inductor.compile_fx._recursive_post_grad_passes",
+        side_effect=tracking_post_grad_passes,
+    ):
+        with AutoParallel(
+            model,
+            input_fn,
+            device_mesh_1d,
+            compile=False,  # This should use build_compile_fn
+        ) as autop:
+            autop.add_input_constraints([(Shard(0),)])
+            sharding_placement = autop.optimize_placement()
+            _ = autop.apply_placement(sharding_placement)
+
+    # Verify _recursive_post_grad_passes was called
+    assert (
+        len(post_grad_passes_called) > 0
+    ), "_recursive_post_grad_passes should be called when compile=False"
+
+    # Verify it was called with is_inference=False (for training mode)
+    for fx_g, is_inference in post_grad_passes_called:
+        assert (
+            is_inference is False
+        ), "_recursive_post_grad_passes should be called with is_inference=False"
+        # Verify a valid graph module was passed
+        assert isinstance(fx_g, torch.fx.GraphModule)
+
+
+def test_recursive_post_grad_passes_not_called_when_compile_true(device_mesh_1d):
+    """Test that build_compile_fn is NOT used when compile=True.
+
+    When compile=True, the full inductor pipeline (compile_fx_inner) should
+    be used instead of build_compile_fn.
+    """
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    def input_fn():
+        b = 32
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    # When compile=True, the compiler_fn should be compile_fx_inner
+    autop = AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+        compile=True,
+    )
+
+    assert autop.compiler_fn is compile_fx_inner
+
+
+def test_compile_false_end_to_end(device_mesh_1d):
+    """Test that compile=False produces a working model with post-grad passes applied.
+
+    This is an end-to-end test that verifies the model can be created and
+    potentially executed (in fake mode).
+    """
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+            self.register_buffer("scale", torch.ones(dim))
+
+        def forward(self, x):
+            return self.linear(x) * self.scale
+
+        def init_weights(self):
+            nn.init.eye_(self.linear.weight)
+            nn.init.zeros_(self.linear.bias)
+
+    def input_fn():
+        b = 512
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+        compile=False,
+    ) as autop:
+        autop.add_input_constraints([(Shard(0),)])
+        sharding_placement = autop.optimize_placement()
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify model structure
+    assert parallel_mod is not None
+    assert hasattr(parallel_mod, "linear")
+
+    # Verify parameters are DTensors
+    from torch.distributed.tensor import DTensor
+
+    for name, param in parallel_mod.named_parameters():
+        assert isinstance(param, DTensor), f"Parameter {name} should be DTensor"
+
+    # Initialize and verify
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
+
+    # Verify weights were initialized correctly
+    weight = parallel_mod.get_parameter("linear.weight").full_tensor()
+    assert torch.allclose(weight, torch.eye(dim, device="cuda"))

From 9a15998a221167fd6b49fcc655ea60fc6f07d23c Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 26 Jan 2026 10:26:25 +0000
Subject: [PATCH 2/4] Fix tests

---
 tests/test_api.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/test_api.py b/tests/test_api.py
index 5ecbc59..af39d11 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -701,7 +701,7 @@ def forward(self, x):
             return self.linear(x)
 
     def input_fn():
-        b = 32
+        b = 512
         return (torch.rand(b, dim, device="cuda"),)
 
     with torch.device("meta"):
@@ -768,7 +768,7 @@ def forward(self, x):
             return self.linear(x)
 
     def input_fn():
-        b = 32
+        b = 512
         return (torch.rand(b, dim, device="cuda"),)
 
     with torch.device("meta"):
@@ -805,7 +805,9 @@ def forward(self, x):
             return self.linear(x) * self.scale
 
         def init_weights(self):
-            nn.init.eye_(self.linear.weight)
+            # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
+            # nn.init.eye_(self.linear.weight)
+            nn.init.ones_(self.linear.weight)
             nn.init.zeros_(self.linear.bias)
 
     def input_fn():
@@ -841,4 +843,6 @@ def input_fn():
 
     # Verify weights were initialized correctly
     weight = parallel_mod.get_parameter("linear.weight").full_tensor()
-    assert torch.allclose(weight, torch.eye(dim, device="cuda"))
+    # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
+    # assert torch.allclose(weight, torch.eye(dim, device="cuda"))
+    assert torch.allclose(weight, torch.ones(dim, dim, device="cuda"))

From a1eba86b3f65624ce66ff4b229448cd641379821 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 26 Jan 2026 10:36:56 +0000
Subject: [PATCH 3/4] Only call schedule_overlap_bucketing_from_inductor_configs instead of full post_grad passes

---
 autoparallel/api.py | 11 +++--------
 tests/test_api.py   | 52 +++++++++++++++++++++++----------------------
 2 files changed, 30 insertions(+), 33 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index 3969919..ffeb9b4 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -51,17 +51,12 @@
 
 
 def build_compile_fn(fake_mode):
-    from torch._inductor.compile_fx import (
-        _recursive_post_grad_passes,
-        get_cuda_device_context,
+    from torch._inductor.fx_passes.overlap_scheduling import (
+        schedule_overlap_bucketing_from_inductor_configs,
     )
-    from torch._inductor.virtualized import V
 
     def boxed_nop_preserve_node_meta(fx_g, example_inputs):
-        with V.set_fake_mode(fake_mode):
-            cuda_context = get_cuda_device_context(fx_g)
-            with cuda_context:
-                _recursive_post_grad_passes(fx_g, is_inference=False)
+        schedule_overlap_bucketing_from_inductor_configs(fx_g)
 
         def run(args):
             with torch.fx.traceback.preserve_node_meta():
diff --git a/tests/test_api.py b/tests/test_api.py
index af39d11..3ceea49 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -679,14 +679,14 @@ def forward(self, x):
     assert parallel_model is not None
 
 
-# Tests for build_compile_fn and _recursive_post_grad_passes
+# Tests for build_compile_fn and schedule_overlap_bucketing_from_inductor_configs
 
 
-def test_recursive_post_grad_passes_called_when_compile_false(device_mesh_1d):
-    """Test that _recursive_post_grad_passes is called when compile=False.
+def test_overlap_bucketing_called_when_compile_false(device_mesh_1d):
+    """Test that schedule_overlap_bucketing_from_inductor_configs is called when compile=False.
 
-    This verifies the fix that ensures post-grad passes (like layout optimization)
-    are run even when not using the full inductor compilation pipeline.
+    This verifies the fix that ensures overlap scheduling passes are run
+    even when not using the full inductor compilation pipeline.
     """
     from unittest.mock import patch
 
@@ -707,25 +707,31 @@ def input_fn():
     with torch.device("meta"):
         model = Model(dim)
 
-    # Track if _recursive_post_grad_passes was called
-    post_grad_passes_called = []
+    # Track if schedule_overlap_bucketing_from_inductor_configs was called
+    overlap_bucketing_called = []
 
-    original_recursive_post_grad_passes = None
+    original_schedule_overlap_bucketing = None
     try:
-        from torch._inductor.compile_fx import _recursive_post_grad_passes
+        from torch._inductor.fx_passes.overlap_scheduling import (
+            schedule_overlap_bucketing_from_inductor_configs,
+        )
 
-        original_recursive_post_grad_passes = _recursive_post_grad_passes
+        original_schedule_overlap_bucketing = (
+            schedule_overlap_bucketing_from_inductor_configs
+        )
     except ImportError:
-        pytest.skip("_recursive_post_grad_passes not available in this PyTorch version")
+        pytest.skip(
+            "schedule_overlap_bucketing_from_inductor_configs not available in this PyTorch version"
+        )
 
-    def tracking_post_grad_passes(fx_g, is_inference=False):
-        post_grad_passes_called.append((fx_g, is_inference))
+    def tracking_overlap_bucketing(fx_g):
+        overlap_bucketing_called.append(fx_g)
         # Call the original function
-        return original_recursive_post_grad_passes(fx_g, is_inference=is_inference)
+        return original_schedule_overlap_bucketing(fx_g)
 
     with patch(
-        "torch._inductor.compile_fx._recursive_post_grad_passes",
-        side_effect=tracking_post_grad_passes,
+        "torch._inductor.fx_passes.overlap_scheduling.schedule_overlap_bucketing_from_inductor_configs",
+        side_effect=tracking_overlap_bucketing,
    ):
         with AutoParallel(
             model,
@@ -737,17 +743,13 @@ def tracking_post_grad_passes(fx_g, is_inference=False):
             sharding_placement = autop.optimize_placement()
             _ = autop.apply_placement(sharding_placement)
 
-    # Verify _recursive_post_grad_passes was called
+    # Verify schedule_overlap_bucketing_from_inductor_configs was called
     assert (
-        len(post_grad_passes_called) > 0
-    ), "_recursive_post_grad_passes should be called when compile=False"
+        len(overlap_bucketing_called) > 0
+    ), "schedule_overlap_bucketing_from_inductor_configs should be called when compile=False"
 
-    # Verify it was called with is_inference=False (for training mode)
-    for fx_g, is_inference in post_grad_passes_called:
-        assert (
-            is_inference is False
-        ), "_recursive_post_grad_passes should be called with is_inference=False"
-        # Verify a valid graph module was passed
+    # Verify a valid graph module was passed
+    for fx_g in overlap_bucketing_called:
         assert isinstance(fx_g, torch.fx.GraphModule)
 
 

From f9105d1272977f67a3a6e0c9d08b0608686aba2a Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 26 Jan 2026 13:10:05 +0000
Subject: [PATCH 4/4] Disable inductor-specific configs

---
 autoparallel/api.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index ffeb9b4..04a99f8 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -56,7 +56,15 @@ def build_compile_fn(fake_mode):
     )
 
     def boxed_nop_preserve_node_meta(fx_g, example_inputs):
-        schedule_overlap_bucketing_from_inductor_configs(fx_g)
+        # when not running with inductor, disable additional flags which are
+        # inductor-specific
+        with torch._inductor.config.patch(
+            {
+                "aten_distributed_optimizations.insert_overlap_deps": (False),
+                "aten_distributed_optimizations.enable_fusion_regions": (False),
+            }
+        ):
+            schedule_overlap_bucketing_from_inductor_configs(fx_g)
 
         def run(args):
             with torch.fx.traceback.preserve_node_meta():
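
For reference, a minimal sketch of how the compile=False path touched by this series is driven from user code, mirroring the tests above. The 8-rank 1-D mesh, the CUDA device, and the "from autoparallel.api import AutoParallel" import path are illustrative assumptions, not something these patches establish.

import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from autoparallel.api import AutoParallel  # assumed import path


def input_fn():
    # Batch dimension gets sharded across the mesh via the Shard(0) constraint below.
    return (torch.rand(512, 128, device="cuda"),)


with torch.device("meta"):
    model = nn.Linear(128, 128)

# Assumes torch.distributed is already initialized with 8 ranks.
mesh = init_device_mesh("cuda", (8,))

# With compile=False, AutoParallel now routes through build_compile_fn, which only
# runs schedule_overlap_bucketing_from_inductor_configs (with the two
# aten_distributed_optimizations flags patched to False) instead of the full
# inductor post-grad pipeline.
with AutoParallel(model, input_fn, mesh, compile=False) as autop:
    autop.add_input_constraints([(Shard(0),)])
    sharding_placement = autop.optimize_placement()
    parallel_mod = autop.apply_placement(sharding_placement)

parallel_mod.to_empty(device="cuda")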