From c0aa1d821a7e49891536a1163df2411363759fe5 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 21 Nov 2025 21:54:26 +0000
Subject: [PATCH 1/4] Add support for non-tensor inputs in AutoParallel

---
 autoparallel/apply_sharding.py    |  8 ++++-
 autoparallel/graph_utils.py       |  9 ++++--
 autoparallel/optimize_sharding.py | 27 +++++++++++++++--
 autoparallel/propagation_rules.py | 25 ----------------
 tests/test_api.py                 | 49 +++++++++++++++++++++++++++++++
 5 files changed, 87 insertions(+), 31 deletions(-)

diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py
index 1f50abc1..51de0f16 100644
--- a/autoparallel/apply_sharding.py
+++ b/autoparallel/apply_sharding.py
@@ -250,6 +250,12 @@ def shard_node_given_placements(node, sharding_placement, *, meta: bool):
     mesh = tgt_spec.mesh
     # all tensors start as replicated
     curr_placement = (Replicate(),) * mesh.ndim
+    if "val" not in node.meta:
+        # non-tensor inputs are considered to be baked into the graph,
+        # so we don't need to do anything here and can just return
+        # a dummy value
+        assert len(node.users) == 0
+        return "arbitrary value"
     tensor = node.meta["val"]
 
     ctx: Any
@@ -303,7 +309,7 @@ def _get_inductor_decomp_table():
 
 def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     args = shard_nodes_given_placements(gm, sharding_placement)
-    local_args = [arg.to_local() for arg in args]
+    local_args = tree_map_only(DTensor, lambda x: x.to_local(), args)
     decomp_table = _get_inductor_decomp_table()
 
     # run with DTensor to apply the collectives given the graph
diff --git a/autoparallel/graph_utils.py b/autoparallel/graph_utils.py
index 1e708aea..bc8df925 100644
--- a/autoparallel/graph_utils.py
+++ b/autoparallel/graph_utils.py
@@ -52,7 +52,10 @@ def update_joint_with_descriptors(
     """
     # TODO: should we upstream a util like this?
     placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
-    new_local_args = [n.meta["val"] for n in placeholders]
+    # if "val" is not present in meta, assume it's a non-tensor input;
+    # there is no sharding associated with it, so we can just forward
+    # the original input
+    new_local_args = [n.meta.get("val", None) for n in placeholders]
 
     joint_with_descriptors.graph_module = updated_gm
     joint_with_descriptors._aot_graph_capture.graph_module = updated_gm
@@ -60,8 +63,10 @@
     for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
         if isinstance(orig, torch.nn.Parameter):
             new_flat_args.append(torch.nn.Parameter(new))
-        else:
+        elif new is not None:
             new_flat_args.append(new)
+        else:
+            new_flat_args.append(orig)
 
     tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
     new_local_tangents = new_local_args[tangent_idx:]
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 5e1dca52..c0e20920 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -157,6 +157,18 @@ def build_sharding_metadata(self):
         strats = {}
         for node in self.graph.nodes:
             if node.op == "placeholder":
+                if node.meta.get("val", None) is None:
+                    # Non-tensor inputs are considered to be replicated
+                    # across all ranks. Given that those inputs seem to
+                    # have been baked into the graph, we won't actually
+                    # use this OpStrategy
+                    strats[node] = _create_all_options(self.mesh, ())
+                    # for now, it seems that non-tensor inputs are baked
+                    # into the graph, so let's assert that this is indeed the case
+                    assert (
+                        len(node.users) == 0
+                    ), f"{node} has {len(node.users)} users, expected 0"
+                    continue
                 strats[node] = _create_all_options(
                     self.mesh, node.meta["val"].shape, tensor=node.meta["val"]
                 )
@@ -828,9 +840,14 @@ def add_sharded_input_constraint(
         if input_placements is not None:
             mut_ips = {i: p for i, p in enumerate(input_placements)}
 
-        for desc, (node, grad_node) in get_plain_input_and_grad_nodes(
-            self.graph
-        ).items():
+        inputs_and_grads = get_plain_input_and_grad_nodes(self.graph)
+        if mut_ips is not None and len(mut_ips) != len(inputs_and_grads):
+            raise ValueError(
+                f"Expected to have {len(inputs_and_grads)} "
+                f"input placements, got {len(mut_ips)}"
+            )
+
+        for desc, (node, grad_node) in inputs_and_grads.items():
             if input_placements is None:
                 placement = None
             else:
@@ -838,6 +855,10 @@
                 assert mut_ips is not None
                 placement = mut_ips.pop(desc.idx)
 
+            if placement is None and "val" not in node.meta:
+                # this is a non-tensor input, so we don't need to constrain it
+                continue
+
             self.add_node_constraint(
                 node, placement, constraint_name="input_constraint"
             )
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index 4a8e52cc..8b3a7361 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -646,31 +646,6 @@ def convert_element_type_rule(mesh, op_schema):
     return out_strat
 
 
-@register_opschema_rule(torch.ops.aten.split.Tensor)
-def split_rule(mesh, op_schema):
-    strat = op_schema.args_schema
-    op = torch.ops.aten.split.Tensor
-    from torch.distributed.tensor._ops._tensor_ops import split_rule
-
-    res = []
-    oo = []
-    for i, ss in enumerate(strat[0].strategies):
-        ispec = ss.input_specs[0]
-        assert ss.output_spec == ispec
-        o = split_rule(OpSchema(op, (ispec, strat[1], strat[2]), {}))
-        # res.append(o)
-        oo.append(o)
-        if o.output_spec is not None:
-            s = OpSpec(o.output_spec, input_specs=(ispec,))
-            s.redistribute_cost = [[math.inf] * len(ss.redistribute_cost[0])]
-            # s.redistribute_cost = [[0.0] * len(ss.redistribute_cost[0])]
-            s.redistribute_cost[0][i] = 0.0
-            res.append(s)
-
-    out_strat = OpStrategy(res)
-    return out_strat
-
-
 @register_opschema_rule(torch.ops.aten._unsafe_index.Tensor)
 def _unsafe_index_rule(mesh, op_schema):
     raise NotImplementedError()
diff --git a/tests/test_api.py b/tests/test_api.py
index 2369ba97..87348f0e 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -117,6 +117,55 @@ def input_fn():
     )
 
 
+def test_non_tensor_input(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x, input_dim: int):
+            return self.linear(x).chunk(2, dim=input_dim)
+
+        def init_weights(self):
+            dim = self.dim
+            self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
+            with torch.no_grad():
+                self.linear.bias.fill_(98.6)
+
+    def input_fn():
+        b = 512
+        inputs = torch.rand(b, dim, device="cuda")
+        input_dim = 1
+        return (inputs, input_dim)
+
+    with torch.device("meta"):
+        model = Model(dim)
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding, None])
+        sharding_placement = autop.optimize_placement()
+
+    # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+    parallel_mod = autop.apply_placement(sharding_placement)
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
+    assert torch.equal(
+        parallel_mod.get_parameter("linear.weight").full_tensor(),
+        torch.full((dim, dim), 9.0, device="cuda"),
+    )
+    assert torch.equal(
+        parallel_mod.get_parameter("linear.bias").full_tensor(),
+        torch.full((dim,), 98.6, device="cuda"),
+    )
+
+
 def test_fx_graph_annotate(device_mesh_1d):
     dim = 128
 

From b75288254cbf323f888e18961a6db3b0a65d67b8 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 21 Nov 2025 22:06:09 +0000
Subject: [PATCH 2/4] Minor test improvement

---
 tests/test_api.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/tests/test_api.py b/tests/test_api.py
index 87348f0e..5e8f9078 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -152,17 +152,13 @@ def input_fn():
         autop.add_input_constraints([x_sharding, None])
         sharding_placement = autop.optimize_placement()
 
-    # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
     parallel_mod = autop.apply_placement(sharding_placement)
     parallel_mod.to_empty(device="cuda")
     parallel_mod.init_weights()
-    assert torch.equal(
-        parallel_mod.get_parameter("linear.weight").full_tensor(),
-        torch.full((dim, dim), 9.0, device="cuda"),
-    )
-    assert torch.equal(
-        parallel_mod.get_parameter("linear.bias").full_tensor(),
-        torch.full((dim,), 98.6, device="cuda"),
+    placeholders = autop.gm.graph.find_nodes(op="placeholder")
+    non_tensor_input = placeholders[3]
+    assert sharding_placement[non_tensor_input].output_specs.placements == (
+        Replicate(),
     )
 
 

From c035d82e8761d6c56f3ee365e97a392dbc428080 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 28 Nov 2025 16:55:38 +0000
Subject: [PATCH 3/4] Fix cast_parametrization

---
 autoparallel/cast_parametrization.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/autoparallel/cast_parametrization.py b/autoparallel/cast_parametrization.py
index 36574ed7..a9b82057 100644
--- a/autoparallel/cast_parametrization.py
+++ b/autoparallel/cast_parametrization.py
@@ -187,6 +187,8 @@ def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy):
     class DTypeCastModule(torch.nn.Module):
         def forward(self, *args, **kwargs):
             def cast_fn(x):
+                if not isinstance(x, torch.Tensor):
+                    return x
                 if not torch.is_floating_point(x):
                     return x
                 return x.to(self._mp_policy.param_dtype)

From 609a56531a194fb9681942ba7b64c8e5b0be780c Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 28 Nov 2025 16:56:32 +0000
Subject: [PATCH 4/4] Fix cast_parametrization for output now

---
 autoparallel/cast_parametrization.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/autoparallel/cast_parametrization.py b/autoparallel/cast_parametrization.py
index a9b82057..24c1fdf0 100644
--- a/autoparallel/cast_parametrization.py
+++ b/autoparallel/cast_parametrization.py
@@ -198,6 +198,8 @@ def cast_fn(x):
             output = super().forward(*args, **kwargs)
 
             def cast_out_fn(x):
+                if not isinstance(x, torch.Tensor):
+                    return x
                 return x.to(self._mp_policy.output_dtype)
 
             output = tree_map(cast_out_fn, output)
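
Below is a minimal usage sketch (not part of the patch series) of what the series enables: passing a plain Python int as a model input through AutoParallel. It reuses the toy Model, input_fn and device_mesh_1d fixtures from test_non_tensor_input above; the final forward call is illustrative only and is not exercised by the test.

# Usage sketch; assumes the fixtures defined in test_non_tensor_input.
with torch.device("meta"):
    model = Model(dim)  # forward(x, input_dim: int) takes a non-tensor argument

with AutoParallel(model, input_fn, device_mesh_1d) as autop:
    # None marks the non-tensor input: no sharding constraint is attached to it,
    # and the solver treats it as replicated (see the assert added in patch 2/4)
    autop.add_input_constraints([(Shard(0),), None])
    sharding_placement = autop.optimize_placement()

parallel_mod = autop.apply_placement(sharding_placement)
parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()

# The int is baked into the traced graph, so it should match what input_fn returned.
out = parallel_mod(torch.rand(512, dim, device="cuda"), 1)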