diff --git a/autoparallel/api.py b/autoparallel/api.py
index e71e043..04a99f8 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -16,7 +16,6 @@
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
-    boxed_nop_preserve_node_meta,
 )
 from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.decomposition import select_decomp_table
@@ -51,6 +50,32 @@
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
 
+def build_compile_fn(fake_mode):
+    from torch._inductor.fx_passes.overlap_scheduling import (
+        schedule_overlap_bucketing_from_inductor_configs,
+    )
+
+    def boxed_nop_preserve_node_meta(fx_g, example_inputs):
+        # when not running with inductor, disable additional flags which are
+        # inductor-specific
+        with torch._inductor.config.patch(
+            {
+                "aten_distributed_optimizations.insert_overlap_deps": (False),
+                "aten_distributed_optimizations.enable_fusion_regions": (False),
+            }
+        ):
+            schedule_overlap_bucketing_from_inductor_configs(fx_g)
+
+        def run(args):
+            with torch.fx.traceback.preserve_node_meta():
+                return torch.fx.Interpreter(fx_g).boxed_run(args)
+
+        run._boxed_call = True
+        return run
+
+    return boxed_nop_preserve_node_meta
+
+
 def _assign_attr(
     attr: Any,
     target_module: torch.nn.Module,
@@ -314,7 +339,7 @@ def __init__(
                 debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
             )
         else:
-            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
+            self.compiler_fn = build_compile_fn(self.fake_mode)  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
diff --git a/tests/test_api.py b/tests/test_api.py
index 50d8540..3ceea49 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -677,3 +677,174 @@ def forward(self, x):
         )
 
     assert parallel_model is not None
+
+
+# Tests for build_compile_fn and schedule_overlap_bucketing_from_inductor_configs
+
+
+def test_overlap_bucketing_called_when_compile_false(device_mesh_1d):
+    """Test that schedule_overlap_bucketing_from_inductor_configs is called when compile=False.
+
+    This verifies the fix that ensures overlap scheduling passes are run
+    even when not using the full inductor compilation pipeline.
+    """
+    from unittest.mock import patch
+
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    def input_fn():
+        b = 512
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    # Track if schedule_overlap_bucketing_from_inductor_configs was called
+    overlap_bucketing_called = []
+
+    original_schedule_overlap_bucketing = None
+    try:
+        from torch._inductor.fx_passes.overlap_scheduling import (
+            schedule_overlap_bucketing_from_inductor_configs,
+        )
+
+        original_schedule_overlap_bucketing = (
+            schedule_overlap_bucketing_from_inductor_configs
+        )
+    except ImportError:
+        pytest.skip(
+            "schedule_overlap_bucketing_from_inductor_configs not available in this PyTorch version"
+        )
+
+    def tracking_overlap_bucketing(fx_g):
+        overlap_bucketing_called.append(fx_g)
+        # Call the original function
+        return original_schedule_overlap_bucketing(fx_g)
+
+    with patch(
+        "torch._inductor.fx_passes.overlap_scheduling.schedule_overlap_bucketing_from_inductor_configs",
+        side_effect=tracking_overlap_bucketing,
+    ):
+        with AutoParallel(
+            model,
+            input_fn,
+            device_mesh_1d,
+            compile=False,  # This should use build_compile_fn
+        ) as autop:
+            autop.add_input_constraints([(Shard(0),)])
+            sharding_placement = autop.optimize_placement()
+            _ = autop.apply_placement(sharding_placement)
+
+    # Verify schedule_overlap_bucketing_from_inductor_configs was called
+    assert (
+        len(overlap_bucketing_called) > 0
+    ), "schedule_overlap_bucketing_from_inductor_configs should be called when compile=False"
+
+    # Verify a valid graph module was passed
+    for fx_g in overlap_bucketing_called:
+        assert isinstance(fx_g, torch.fx.GraphModule)
+
+
+def test_recursive_post_grad_passes_not_called_when_compile_true(device_mesh_1d):
+    """Test that build_compile_fn is NOT used when compile=True.
+
+    When compile=True, the full inductor pipeline (compile_fx_inner) should
+    be used instead of build_compile_fn.
+    """
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    def input_fn():
+        b = 512
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    from torch._inductor.compile_fx import compile_fx_inner
+
+    # When compile=True, the compiler_fn should be compile_fx_inner
+    autop = AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+        compile=True,
+    )
+
+    assert autop.compiler_fn is compile_fx_inner
+
+
+def test_compile_false_end_to_end(device_mesh_1d):
+    """Test that compile=False produces a working model with post-grad passes applied.
+
+    This is an end-to-end test that verifies the model can be created and
+    potentially executed (in fake mode).
+    """
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+            self.register_buffer("scale", torch.ones(dim))
+
+        def forward(self, x):
+            return self.linear(x) * self.scale
+
+        def init_weights(self):
+            # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
+            # nn.init.eye_(self.linear.weight)
+            nn.init.ones_(self.linear.weight)
+            nn.init.zeros_(self.linear.bias)
+
+    def input_fn():
+        b = 512
+        return (torch.rand(b, dim, device="cuda"),)
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+        compile=False,
+    ) as autop:
+        autop.add_input_constraints([(Shard(0),)])
+        sharding_placement = autop.optimize_placement()
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    # Verify model structure
+    assert parallel_mod is not None
+    assert hasattr(parallel_mod, "linear")
+
+    # Verify parameters are DTensors
+    from torch.distributed.tensor import DTensor
+
+    for name, param in parallel_mod.named_parameters():
+        assert isinstance(param, DTensor), f"Parameter {name} should be DTensor"
+
+    # Initialize and verify
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
+
+    # Verify weights were initialized correctly
+    weight = parallel_mod.get_parameter("linear.weight").full_tensor()
+    # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
+    # assert torch.allclose(weight, torch.eye(dim, device="cuda"))
+    assert torch.allclose(weight, torch.ones(dim, dim, device="cuda"))