From 6bd09d23837377518a3421d652e17edb589d3256 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 30 Jan 2026 16:22:42 +0000 Subject: [PATCH 1/3] Add support for arbitrary inputs and outputs in AutoParallel Non-tensor inputs gets baked in the graph. Need to add an assert to ensure they haven't changed from user side --- autoparallel/api.py | 196 ++++++++++++++++++++- autoparallel/shardings/ordered_sharding.py | 8 +- tests/test_api.py | 47 +++++ 3 files changed, 245 insertions(+), 6 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 83e86564..03f6d3f2 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -48,6 +48,169 @@ _APPLY_VIEW_MM_VIEW_PATTERN = False +def normalize_placeholder_name(name: str) -> str: + """ + Normalize a placeholder name to match the format of dynamo_flat_name_to_original_fqn keys. + + Placeholder names come from: re.sub(r"[^a-zA-Z0-9]+", "_", source.name) + dynamo_flat_name_to_original_fqn keys come from: OutputGraph.module_key_name(source.name) + + Both start from the same source name (e.g., L['self']._modules['linear']._parameters['weight']) + but apply different transformations: + - Placeholder: L_self___modules_linear___parameters_weight_ + - module_key_name: self_linear_weight + + We normalize the placeholder name to match the module_key_name format by: + 1. Removing _modules_, _parameters_, _buffers_ patterns + 2. Collapsing consecutive underscores + 3. Removing the guard prefix (l_self_, L_self_, etc.) + 4. Stripping leading/trailing underscores + """ + import re + + # Remove _modules_, _parameters_, _buffers_ patterns + name = re.sub(r"_modules_", "_", name) + name = re.sub(r"_parameters_", "_", name) + name = re.sub(r"_buffers_", "_", name) + # Collapse multiple underscores + name = re.sub(r"_+", "_", name) + # Strip leading/trailing underscores + name = name.strip("_") + # Remove l_self_ or L_self_ prefix (common guard access pattern) + name = re.sub(r"^[lL]_self_", "", name) + return name + + +def create_graph_removing_unused_inputs_and_adding_unused_parameters( + src_gm: torch.fx.GraphModule, + # model: nn.Module, + fake_mode, +) -> tuple[torch.fx.GraphModule, list[torch.Tensor]]: + """ + Create a new GraphModule from src_gm where parameter/buffer placeholders + are replaced with get_attr nodes. + + Uses dynamo_flat_name_to_original_fqn metadata to map placeholder names to FQNs. + + Returns (new_gm, inputs) where inputs are the non-param/buffer placeholder values. 
+ """ + from torch.export.unflatten import _assign_attr, _AttrKind + + # fake_mode = src_gm.meta["fake_mode"] + # Create new GraphModule and register all parameters/buffers from src_gm + gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + for name, mod in src_gm.named_children(): + if isinstance(mod, torch.fx.GraphModule): + gm.add_submodule(name, mod) + # gm.add_submodule(name, mod) + for fqn, param in src_gm.named_parameters(): + _assign_attr(param, gm, fqn, _AttrKind.PARAMETER) + for fqn, buf in src_gm.named_buffers(): + _assign_attr(buf, gm, fqn, _AttrKind.BUFFER) + + # Build lookup from normalized name to FQN using dynamo_flat_name_to_original_fqn + # The keys in dynamo_flat_name_to_original_fqn are created by module_key_name(source.name) + # We normalize these keys to match our normalized placeholder names + flat_name_to_fqn = src_gm.meta.get("dynamo_flat_name_to_original_fqn", {}) + normalized_to_fqn = {} + for flat_name, fqn in flat_name_to_fqn.items(): + # module_key_name output is already mostly normalized, but we apply the same + # normalization to ensure consistency + normalized = normalize_placeholder_name(flat_name) + normalized_to_fqn[normalized] = fqn + + param_fqns = {fqn for fqn, _ in src_gm.named_parameters()} + buffer_fqns = {fqn for fqn, _ in src_gm.named_buffers()} + + graph = gm.graph + val_map = {} + inputs = [] + used_params = set() + used_buffers = set() + + for node in src_gm.graph.nodes: + if node.op == "placeholder": + example_val = node.meta.get("example_value") + + # Try to find FQN by normalizing placeholder name and looking up + normalized_name = normalize_placeholder_name(node.name) + fqn = normalized_to_fqn.get(normalized_name) # type: ignore[assignment] + + is_param = fqn is not None and fqn in param_fqns + is_buffer = fqn is not None and fqn in buffer_fqns + + if is_param: + # Parameter placeholder -> get_attr + get_attr_node = graph.get_attr(fqn) + get_attr_node.meta = node.meta.copy() + val_map[node] = get_attr_node + used_params.add(fqn) + elif is_buffer: + # Buffer placeholder -> get_attr + get_attr_node = graph.get_attr(fqn) + get_attr_node.meta = node.meta.copy() + val_map[node] = get_attr_node + used_buffers.add(fqn) + else: + # Regular input placeholder - copy as-is + val_map[node] = graph.node_copy(node, lambda n: val_map[n]) + if example_val is not None and hasattr(example_val, "shape"): + with fake_mode: + inputs.append( + torch.empty_strided( + example_val.shape, + example_val.stride(), + dtype=example_val.dtype, + device=example_val.device, + requires_grad=example_val.requires_grad, + ) + ) + else: + # Copy all other nodes + val_map[node] = graph.node_copy(node, lambda n: val_map[n]) + + # Add get_attr for unused parameters (not in the original graph) + # Insert before the first non-placeholder/non-get_attr node + insert_point = None + for node in graph.nodes: + if node.op not in ("placeholder", "get_attr"): + insert_point = node + break + + for fqn in param_fqns - used_params: + param = src_gm.get_parameter(fqn) + if insert_point is not None: + with graph.inserting_before(insert_point): + get_attr_node = graph.get_attr(fqn) + with fake_mode: + get_attr_node.meta["example_value"] = torch.empty_strided( + param.shape, + param.stride(), + dtype=param.dtype, + device=param.device, + requires_grad=param.requires_grad, + ) + + for fqn in buffer_fqns - used_buffers: + buf = src_gm.get_buffer(fqn) + if insert_point is not None: + with graph.inserting_before(insert_point): + get_attr_node = graph.get_attr(fqn) + with fake_mode: + 
get_attr_node.meta["example_value"] = torch.empty_strided( + buf.shape, + buf.stride(), + dtype=buf.dtype, + device=buf.device, + requires_grad=buf.requires_grad, + ) + + graph.lint() + gm.recompile() + + return gm, inputs + + def _assign_attr( attr: Any, target_module: torch.nn.Module, @@ -376,13 +539,29 @@ def build_model_graph(self): with set_dtype_cast( True ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): - torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs) + # torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs) + from torch._dynamo.functional_export import dynamo_graph_capture_for_export + + torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)( + *formatted_inputs + ) + self.flatten_fn = torch_ir_with_fqn._dynamo_bytecode_flatten + self.unflatten_fn = torch_ir_with_fqn._dynamo_bytecode_unflatten + # from IPython import embed; embed(); exit() + ( + torch_ir_with_fqn2, + inputs2, + ) = create_graph_removing_unused_inputs_and_adding_unused_parameters( + torch_ir_with_fqn, self.fake_mode + ) + # from IPython import embed; embed(); exit() # TODO Cna't use fake mode here because it clashes with the user level # fake mode. Ideally dynamo should reuse the user level fake mode. self.joint_with_descriptors = aot_export_joint_with_descriptors( self.stack, - torch_ir_with_fqn, - formatted_inputs, + torch_ir_with_fqn2, + # formatted_inputs, + inputs2, decompositions=decomp_table, ) gm = self.joint_with_descriptors.graph_module @@ -607,6 +786,8 @@ def apply_placement(self, sharding_placement=None): bw_compiler=self.compiler_fn, ) + unflatten_fn = self.unflatten_fn + # TODO: this probably belongs in the AOTAutograd API # TODO: pytree handling class AutoParallelModule(torch.nn.Module): @@ -624,10 +805,17 @@ def forward(self, *args): dict(self.named_buffers(remove_duplicate=False)).items(), ) ] - boxed_args = [*params, *args] + from IPython import embed + + embed() + exit() + # new_args = flatten_fn(*args) + filtered_args = [x for x in args if isinstance(x, torch.Tensor)] + boxed_args = [*params, *filtered_args] del params # NB: don't do self.parallel_model_fn work around Dynamo bug out = parallel_model_fn(boxed_args) + out = unflatten_fn(out, args) return out self.parallel_model = AutoParallelModule() diff --git a/autoparallel/shardings/ordered_sharding.py b/autoparallel/shardings/ordered_sharding.py index 5298a505..926c99e0 100644 --- a/autoparallel/shardings/ordered_sharding.py +++ b/autoparallel/shardings/ordered_sharding.py @@ -94,7 +94,7 @@ def get_redistributed_input_placements( x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node) ] num_input_nodes = len(all_input_nodes) - curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]] = [ + curr_specs: list[Union[DTensorSpec, tuple[Optional[DTensorSpec], ...], None]] = [ sharding_placement[n].output_specs for n in all_input_nodes ] # FIXME ? 
if node.target == operator.getitem: @@ -149,7 +149,11 @@ def compute_optimal_placement_order_for_parameters(module, sharding_placement): param_and_grad_users = {} param_grad_chain = {} for param, grad in param_and_grad_nodes: - last_p = list(param.users)[0] + param_users = list(param.users) + if not param_users: + # if unused parameter, don't bother with it + continue + last_p = param_users[0] p_chain = [param] # get all linear chain of users of the parameter while len(last_p.all_input_nodes) == 1: diff --git a/tests/test_api.py b/tests/test_api.py index 50d85402..d92ec77c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -117,6 +117,53 @@ def input_fn(): ) +def test_non_tensor_input(device_mesh_1d): + dim = 128 + + class Model(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.linear = nn.Linear(dim, dim) + self.unused = nn.Parameter(torch.rand(1)) + + def forward(self, x, input_dim: int, has_compute: bool): + if has_compute: + return self.linear(x) + input_dim, x.shape + else: + return x + + def init_weights(self): + dim = self.dim + self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0) + with torch.no_grad(): + self.linear.bias.fill_(98.6) + + def input_fn(divisor=1): + b = 512 // divisor + inputs = torch.rand(b, dim, device="cuda") + input_dim = 1 + has_compute = True + return (inputs, input_dim, has_compute) + + with torch.device("meta"): + model = Model(dim) + with AutoParallel(model, input_fn, device_mesh_1d, compile=True) as autop: + x_sharding = (Shard(0),) + autop.add_input_constraints([x_sharding, None]) + sharding_placement = autop.optimize_placement() + + parallel_mod = autop.apply_placement(sharding_placement) + parallel_mod.to_empty(device="cuda") + parallel_mod.init_weights() + placeholders = autop.gm.graph.find_nodes(op="placeholder") + # 2 used parameters, 1 unused parameter, 1 input and 1 tangent + assert len(placeholders) == 5 + inputs = input_fn(device_mesh_1d.shape[0]) + out, shape = parallel_mod(*inputs) + assert input_fn()[0].shape == shape + + def test_fx_graph_annotate(device_mesh_1d): dim = 128 From fa87740207b2571a3cbb388c50786001dcb778af Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 31 Jan 2026 15:01:43 +0000 Subject: [PATCH 2/3] Remove debug code --- autoparallel/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 03f6d3f2..fb3f356c 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -805,10 +805,6 @@ def forward(self, *args): dict(self.named_buffers(remove_duplicate=False)).items(), ) ] - from IPython import embed - - embed() - exit() # new_args = flatten_fn(*args) filtered_args = [x for x in args if isinstance(x, torch.Tensor)] boxed_args = [*params, *filtered_args] From d0ec2e2b482b04319fe9c9e692e9a1dbead087bd Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 9 Feb 2026 14:31:27 +0000 Subject: [PATCH 3/3] [WIP] --- autoparallel/api.py | 83 ++++++++++++++++++++++---------- tests/test_optimize_placement.py | 7 +++ 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index fb3f356c..0572f88e 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -83,7 +83,6 @@ def normalize_placeholder_name(name: str) -> str: def create_graph_removing_unused_inputs_and_adding_unused_parameters( src_gm: torch.fx.GraphModule, - # model: nn.Module, fake_mode, ) -> tuple[torch.fx.GraphModule, list[torch.Tensor]]: """ @@ -94,19 +93,59 @@ def 
create_graph_removing_unused_inputs_and_adding_unused_parameters( Returns (new_gm, inputs) where inputs are the non-param/buffer placeholder values. """ + from torch._subclasses.fake_tensor import FakeTensor from torch.export.unflatten import _assign_attr, _AttrKind - # fake_mode = src_gm.meta["fake_mode"] + # Helper to create fresh fake tensors in the target fake mode + def to_fake(t: torch.Tensor) -> torch.Tensor: + with fake_mode: + fake_t = torch.empty_strided( + t.shape, + t.stride(), + dtype=t.dtype, + device=t.device, + requires_grad=t.requires_grad, + ) + if isinstance(t, torch.nn.Parameter): + return torch.nn.Parameter(fake_t, requires_grad=t.requires_grad) + return fake_t + + # Helper to convert example_value in node metadata to the target fake mode + def convert_node_meta(meta: dict) -> dict: + new_meta = meta.copy() + example_val = new_meta.get("example_value") + if example_val is not None: + if isinstance(example_val, FakeTensor) and hasattr(example_val, "shape"): + new_meta["example_value"] = to_fake(example_val) + elif isinstance(example_val, (list, tuple)): + # Handle tuple/list of tensors + converted = [] + for v in example_val: + if isinstance(v, FakeTensor) and hasattr(v, "shape"): + converted.append(to_fake(v)) + else: + converted.append(v) + new_meta["example_value"] = type(example_val)(converted) + return new_meta + # Create new GraphModule and register all parameters/buffers from src_gm + # Convert them to the target fake mode to avoid fake mode mixing gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + + # Only copy specific metadata keys we need, avoiding tracing_context and other + # internal state that might hold references to dynamo's fake mode + for key in ["dynamo_flat_name_to_original_fqn", "module_call_specs"]: + if key in src_gm.meta: + gm.meta[key] = src_gm.meta[key] + gm.meta["fake_mode"] = fake_mode + for name, mod in src_gm.named_children(): if isinstance(mod, torch.fx.GraphModule): gm.add_submodule(name, mod) - # gm.add_submodule(name, mod) for fqn, param in src_gm.named_parameters(): - _assign_attr(param, gm, fqn, _AttrKind.PARAMETER) + _assign_attr(to_fake(param), gm, fqn, _AttrKind.PARAMETER) for fqn, buf in src_gm.named_buffers(): - _assign_attr(buf, gm, fqn, _AttrKind.BUFFER) + _assign_attr(to_fake(buf), gm, fqn, _AttrKind.BUFFER) # Build lookup from normalized name to FQN using dynamo_flat_name_to_original_fqn # The keys in dynamo_flat_name_to_original_fqn are created by module_key_name(source.name) @@ -142,32 +181,27 @@ def create_graph_removing_unused_inputs_and_adding_unused_parameters( if is_param: # Parameter placeholder -> get_attr get_attr_node = graph.get_attr(fqn) - get_attr_node.meta = node.meta.copy() + get_attr_node.meta = convert_node_meta(node.meta) val_map[node] = get_attr_node used_params.add(fqn) elif is_buffer: # Buffer placeholder -> get_attr get_attr_node = graph.get_attr(fqn) - get_attr_node.meta = node.meta.copy() + get_attr_node.meta = convert_node_meta(node.meta) val_map[node] = get_attr_node used_buffers.add(fqn) else: # Regular input placeholder - copy as-is - val_map[node] = graph.node_copy(node, lambda n: val_map[n]) + new_node = graph.node_copy(node, lambda n: val_map[n]) + new_node.meta = convert_node_meta(node.meta) if example_val is not None and hasattr(example_val, "shape"): - with fake_mode: - inputs.append( - torch.empty_strided( - example_val.shape, - example_val.stride(), - dtype=example_val.dtype, - device=example_val.device, - requires_grad=example_val.requires_grad, - ) - ) + 
inputs.append(to_fake(example_val)) + val_map[node] = new_node else: - # Copy all other nodes - val_map[node] = graph.node_copy(node, lambda n: val_map[n]) + # Copy all other nodes and convert their metadata + new_node = graph.node_copy(node, lambda n: val_map[n]) + new_node.meta = convert_node_meta(node.meta) + val_map[node] = new_node # Add get_attr for unused parameters (not in the original graph) # Insert before the first non-placeholder/non-get_attr node @@ -539,7 +573,6 @@ def build_model_graph(self): with set_dtype_cast( True ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): - # torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs) from torch._dynamo.functional_export import dynamo_graph_capture_for_export torch_ir_with_fqn = dynamo_graph_capture_for_export(self.model)( @@ -547,20 +580,18 @@ def build_model_graph(self): ) self.flatten_fn = torch_ir_with_fqn._dynamo_bytecode_flatten self.unflatten_fn = torch_ir_with_fqn._dynamo_bytecode_unflatten - # from IPython import embed; embed(); exit() ( torch_ir_with_fqn2, inputs2, ) = create_graph_removing_unused_inputs_and_adding_unused_parameters( torch_ir_with_fqn, self.fake_mode ) - # from IPython import embed; embed(); exit() - # TODO Cna't use fake mode here because it clashes with the user level - # fake mode. Ideally dynamo should reuse the user level fake mode. + # Clear references to the original dynamo graph to help garbage collection + # and reduce state leakage between uses + del torch_ir_with_fqn self.joint_with_descriptors = aot_export_joint_with_descriptors( self.stack, torch_ir_with_fqn2, - # formatted_inputs, inputs2, decompositions=decomp_table, ) diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py index 6bc1af56..786e91df 100644 --- a/tests/test_optimize_placement.py +++ b/tests/test_optimize_placement.py @@ -16,6 +16,13 @@ from autoparallel.api import AutoParallel +# @pytest.fixture(autouse=True) +# def reset_dynamo(): +# """Reset dynamo state before each test to ensure test isolation.""" +# torch._dynamo.reset() +# yield +# torch._dynamo.reset() + @pytest.fixture(scope="module", autouse=True) def init_pg():
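
Editor's note: the core rewrite in create_graph_removing_unused_inputs_and_adding_unused_parameters (PATCH 1) is easier to follow with a toy version of the same torch.fx transformation. The sketch below is a minimal, self-contained illustration, not the AutoParallel implementation: the graph, module, and parameter names are made up, and it skips the FQN normalization, fake-mode conversion, and unused-parameter handling that the real function performs.

import torch
import torch.fx

# Build a toy graph whose inputs mimic dynamo's flat calling convention, where
# parameters arrive as positional placeholders alongside the real user input.
g = torch.fx.Graph()
x = g.placeholder("x")
w = g.placeholder("weight")
out = g.call_function(torch.matmul, (x, w))
g.output(out)

gm = torch.fx.GraphModule(torch.nn.Module(), g)
# Register the parameter on the new module so a get_attr target exists for it.
gm.register_parameter("weight", torch.nn.Parameter(torch.randn(4, 4)))

# Rewrite: demote the parameter placeholder to a get_attr node, leaving only the
# user input as a runtime placeholder. (The real pass also inserts get_attr nodes
# for parameters the graph never referenced.)
for node in list(gm.graph.nodes):
    if node.op == "placeholder" and node.target == "weight":
        with gm.graph.inserting_after(node):
            attr = gm.graph.get_attr("weight")
        node.replace_all_uses_with(attr)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
y = gm(torch.randn(2, 4))  # the weight now comes from the module's own state

The patch performs this same placeholder-to-get_attr demotion, but resolves which placeholders are parameters or buffers via the dynamo_flat_name_to_original_fqn metadata and the normalize_placeholder_name helper.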
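
Editor's note: a usage sketch of the new non-tensor-input support, adapted from test_non_tensor_input in tests/test_api.py above. It assumes a CUDA device, an initialized process group, and a 1-D DeviceMesh bound to device_mesh_1d (provided by a fixture in the test suite); the Shard import path is assumed, since the test's imports are not shown in this patch. As the commit message notes, the Python values passed for input_dim and has_compute are baked into the captured graph, so they must not change between calls until a runtime assert is added.

import torch
import torch.nn as nn
from torch.distributed.tensor import Shard

from autoparallel.api import AutoParallel


class Model(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, input_dim: int, has_compute: bool):
        # input_dim and has_compute are non-tensor inputs; the values seen at
        # capture time are constant-folded into the traced graph.
        if has_compute:
            return self.linear(x) + input_dim, x.shape
        return x


def input_fn():
    return (torch.rand(512, 128, device="cuda"), 1, True)


with torch.device("meta"):
    model = Model(128)

with AutoParallel(model, input_fn, device_mesh_1d, compile=True) as autop:
    autop.add_input_constraints([(Shard(0),), None])
    sharding_placement = autop.optimize_placement()
    parallel_mod = autop.apply_placement(sharding_placement)

# Weights are left uninitialized in this sketch; a real model would define
# init_weights() and call it here, as the test does.
parallel_mod.to_empty(device="cuda")
out, shape = parallel_mod(*input_fn())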