diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py
index 7fc9d13bf4..9ac8440048 100644
--- a/torchtitan/experiments/simple_fsdp/backend.py
+++ b/torchtitan/experiments/simple_fsdp/backend.py
@@ -38,7 +38,7 @@ def get_compile_backend_with_passes(
     # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
     from torch._inductor.config import aten_distributed_optimizations as dist_opts
     from torch._inductor.fx_passes.overlap_scheduling import (
-        schedule_overlap_bucketing,
+        schedule_overlap_bucketing_from_inductor_configs,
     )
 
     dist_opts.collective_bucketing = True
@@ -52,7 +52,7 @@ def get_compile_backend_with_passes(
         def aot_eager_autobucketing_reordering_pass(
             gm: torch.fx.GraphModule, example_inputs: Any
         ) -> torch.fx.GraphModule:
-            schedule_overlap_bucketing(gm)
+            schedule_overlap_bucketing_from_inductor_configs(gm)
             gm.recompile()
             return gm
 
@@ -67,7 +67,7 @@ def aot_eager_autobucketing_reordering_pass(
         def inductor_autobucketing_reordering_pass(
             gm: torch.fx.Graph,
         ) -> torch.fx.GraphModule:
-            return schedule_overlap_bucketing(gm.owning_module)
+            return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
 
         dist_opts.insert_overlap_deps = True
         torch._inductor.config.reorder_for_peak_memory = False