29 changes: 27 additions & 2 deletions autoparallel/api.py
@@ -16,7 +16,6 @@
from torch._functorch.aot_autograd import (
    aot_compile_joint_with_descriptors,
    aot_export_joint_with_descriptors,
-    boxed_nop_preserve_node_meta,
)
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table
@@ -51,6 +50,32 @@
_APPLY_VIEW_MM_VIEW_PATTERN = False


def build_compile_fn(fake_mode):
    # Build a compiler fn for the non-inductor (compile=False) path that still
    # runs the overlap-scheduling/bucketing passes before interpreting the graph.
    from torch._inductor.fx_passes.overlap_scheduling import (
        schedule_overlap_bucketing_from_inductor_configs,
    )

    def boxed_nop_preserve_node_meta(fx_g, example_inputs):
        # When not running with inductor, disable additional flags that are
        # inductor-specific.
        with torch._inductor.config.patch(
            {
                "aten_distributed_optimizations.insert_overlap_deps": False,
                "aten_distributed_optimizations.enable_fusion_regions": False,
            }
        ):
            schedule_overlap_bucketing_from_inductor_configs(fx_g)

        def run(args):
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(fx_g).boxed_run(args)

        run._boxed_call = True
        return run

    return boxed_nop_preserve_node_meta


def _assign_attr(
    attr: Any,
    target_module: torch.nn.Module,
@@ -314,7 +339,7 @@ def __init__(
                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
            )
        else:
-            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
+            self.compiler_fn = build_compile_fn(self.fake_mode)  # type: ignore[assignment]
        self.enable_ac = enable_ac
        self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
        self.reshard_after_forward = reshard_after_forward
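For orientation, here is a hedged sketch of how the compiler fn returned by build_compile_fn is exercised. The toy function, graph, and inputs are hypothetical stand-ins for the joint graph and fake tensors AOTAutograd would normally supply, and it is an assumption that the overlap-scheduling pass no-ops cleanly on a graph containing no collectives; the point is the boxed calling convention that `run._boxed_call = True` opts into (a single list argument that boxed_run consumes in place).

# Hedged sketch: exercising the compiler fn from build_compile_fn directly.
# `f`, `gm`, and `example_inputs` are hypothetical stand-ins, and we assume
# the overlap-scheduling pass tolerates a graph with no collectives.
import torch


def f(x):
    return x.relu() + 1


gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(4)]

compiler_fn = build_compile_fn(fake_mode=None)  # fake_mode is unused in the body above
compiled = compiler_fn(gm, example_inputs)

# Boxed convention: pass one list; Interpreter.boxed_run consumes it in place.
out = compiled(list(example_inputs))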
171 changes: 171 additions & 0 deletions tests/test_api.py
@@ -677,3 +677,174 @@ def forward(self, x):
    )

    assert parallel_model is not None


# Tests for build_compile_fn and schedule_overlap_bucketing_from_inductor_configs


def test_overlap_bucketing_called_when_compile_false(device_mesh_1d):
    """Test that schedule_overlap_bucketing_from_inductor_configs is called when compile=False.

    This verifies the fix that ensures overlap scheduling passes are run
    even when not using the full inductor compilation pipeline.
    """
    from unittest.mock import patch

    dim = 128

    class Model(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear(x)

    def input_fn():
        b = 512
        return (torch.rand(b, dim, device="cuda"),)

    with torch.device("meta"):
        model = Model(dim)

    # Track whether schedule_overlap_bucketing_from_inductor_configs was called
    overlap_bucketing_called = []

    original_schedule_overlap_bucketing = None
    try:
        from torch._inductor.fx_passes.overlap_scheduling import (
            schedule_overlap_bucketing_from_inductor_configs,
        )

        original_schedule_overlap_bucketing = (
            schedule_overlap_bucketing_from_inductor_configs
        )
    except ImportError:
        pytest.skip(
            "schedule_overlap_bucketing_from_inductor_configs not available in this PyTorch version"
        )

    def tracking_overlap_bucketing(fx_g):
        overlap_bucketing_called.append(fx_g)
        # Call the original function
        return original_schedule_overlap_bucketing(fx_g)

    with patch(
        "torch._inductor.fx_passes.overlap_scheduling.schedule_overlap_bucketing_from_inductor_configs",
        side_effect=tracking_overlap_bucketing,
    ):
        with AutoParallel(
            model,
            input_fn,
            device_mesh_1d,
            compile=False,  # This should use build_compile_fn
        ) as autop:
            autop.add_input_constraints([(Shard(0),)])
            sharding_placement = autop.optimize_placement()
            _ = autop.apply_placement(sharding_placement)

    # Verify schedule_overlap_bucketing_from_inductor_configs was called
    assert (
        len(overlap_bucketing_called) > 0
    ), "schedule_overlap_bucketing_from_inductor_configs should be called when compile=False"

    # Verify a valid graph module was passed
    for fx_g in overlap_bucketing_called:
        assert isinstance(fx_g, torch.fx.GraphModule)
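
One subtlety worth noting: the patch above only takes effect because build_compile_fn performs its import at call time, inside AutoParallel's setup, which runs under the patch context. A small self-contained sketch of that mechanism, where `late_import` is a hypothetical stand-in for build_compile_fn's import-at-call-time pattern:

# Hedged sketch of why patching the source module works here: a function-local
# `from module import name` resolves against the module's current attribute,
# so it picks up a mock installed before the function runs.
from unittest import mock

import torch._inductor.fx_passes.overlap_scheduling as overlap_scheduling


def late_import():
    from torch._inductor.fx_passes.overlap_scheduling import (
        schedule_overlap_bucketing_from_inductor_configs,
    )

    return schedule_overlap_bucketing_from_inductor_configs


with mock.patch.object(
    overlap_scheduling, "schedule_overlap_bucketing_from_inductor_configs"
) as mocked:
    assert late_import() is mocked  # the late import sees the patched attribute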


def test_build_compile_fn_not_used_when_compile_true(device_mesh_1d):
    """Test that build_compile_fn is NOT used when compile=True.

    When compile=True, the full inductor pipeline (compile_fx_inner) should
    be used instead of build_compile_fn.
    """
    dim = 128

    class Model(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x):
            return self.linear(x)

    def input_fn():
        b = 512
        return (torch.rand(b, dim, device="cuda"),)

    with torch.device("meta"):
        model = Model(dim)

    from torch._inductor.compile_fx import compile_fx_inner

    # When compile=True, the compiler_fn should be compile_fx_inner
    autop = AutoParallel(
        model,
        input_fn,
        device_mesh_1d,
        compile=True,
    )

    assert autop.compiler_fn is compile_fx_inner


def test_compile_false_end_to_end(device_mesh_1d):
    """Test that compile=False produces a working model with post-grad passes applied.

    This is an end-to-end test that verifies the model can be created and
    potentially executed (in fake mode).
    """
    dim = 128

    class Model(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.linear = nn.Linear(dim, dim)
            self.register_buffer("scale", torch.ones(dim))

        def forward(self, x):
            return self.linear(x) * self.scale

        def init_weights(self):
            # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
            # nn.init.eye_(self.linear.weight)
            nn.init.ones_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def input_fn():
        b = 512
        return (torch.rand(b, dim, device="cuda"),)

    with torch.device("meta"):
        model = Model(dim)

    with AutoParallel(
        model,
        input_fn,
        device_mesh_1d,
        compile=False,
    ) as autop:
        autop.add_input_constraints([(Shard(0),)])
        sharding_placement = autop.optimize_placement()
        parallel_mod = autop.apply_placement(sharding_placement)

    # Verify model structure
    assert parallel_mod is not None
    assert hasattr(parallel_mod, "linear")

    # Verify parameters are DTensors
    from torch.distributed.tensor import DTensor

    for name, param in parallel_mod.named_parameters():
        assert isinstance(param, DTensor), f"Parameter {name} should be DTensor"

    # Initialize and verify
    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()

    # Verify weights were initialized correctly
    weight = parallel_mod.get_parameter("linear.weight").full_tensor()
    # TODO: can't use eye_ because of https://github.com/pytorch/pytorch/issues/173357
    # assert torch.allclose(weight, torch.eye(dim, device="cuda"))
    assert torch.allclose(weight, torch.ones(dim, dim, device="cuda"))
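
These tests all take a device_mesh_1d fixture that this diff does not show. A hedged sketch of what such a conftest fixture might look like follows; the fixture body is a guess at the repository's shape, not its actual definition.

# Hedged sketch of the device_mesh_1d fixture assumed by the tests above;
# the real conftest.py is not part of this diff, so this is only a guess.
import pytest
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


@pytest.fixture
def device_mesh_1d():
    # Assumes the test suite has already initialized the process group
    # (e.g. with a fake or gloo backend for single-process runs).
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return init_device_mesh("cuda", (world_size,))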