From 47626d88004f7b40298bff897168b9f0e46003dd Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 24 Nov 2025 15:36:08 +0000
Subject: [PATCH] Added sharding propagation support in nnx.eval_shape

---
 flax/nnx/transforms/transforms.py | 12 ++++++++++++
 tests/nnx/spmd_test.py            | 30 ++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py
index 47dfe9517..d81b35f9a 100644
--- a/flax/nnx/transforms/transforms.py
+++ b/flax/nnx/transforms/transforms.py
@@ -231,9 +231,21 @@ def to_value_metadata(x):
 
 
 def _to_variable(node):
+  # import here to avoid circular imports
+  from flax.nnx.spmd import get_var_pspec
+
   def to_variable(x):
     if isinstance(x, ValueMetadata):
       var = x.var_type._new(x.value, x.metadata)
+
+      global_mesh = jax.sharding.get_mesh()
+      if global_mesh.axis_sizes == ():
+        global_mesh = None
+      mesh = var.get_metadata("mesh", None) or global_mesh
+      if mesh is not None:
+        pspec = get_var_pspec(var)
+        sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec)
+        var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding))
       return var
     return x
 
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index b13334e0f..e31d5faf5 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -269,6 +269,36 @@ def __init__(self, rngs):
     assert jax.tree.leaves(abs_state)[0].sharding.is_equivalent_to(
       NamedSharding(mesh, P(None, 'model')), ndim=2)
 
+  def test_eval_shape_with_sharding0(self):
+    # based on https://github.com/google/flax/issues/5110
+    mesh1 = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
+    mesh2 = jax.make_mesh((1, 4), ("c", "d"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
+
+    class Model(nnx.Module):
+      def __init__(self):
+        self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("a", "b"), "mesh": mesh1})
+        self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("c", "d"), "mesh": mesh2})
+
+    abs_model = nnx.eval_shape(lambda: Model())
+    assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding)
+    assert abs_model.p1.kernel.sharding.mesh is mesh1
+    assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b")
+    assert isinstance(abs_model.p2.kernel.sharding, jax.sharding.NamedSharding)
+    assert abs_model.p2.kernel.sharding.mesh is mesh2
+    assert abs_model.p2.kernel.sharding.spec == jax.P("c", "d")
+
+  def test_eval_shape_with_sharding1(self):
+    class Model(nnx.Module):
+      def __init__(self):
+        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("a", "b")})
+
+    mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
+    with jax.set_mesh(mesh):
+      abs_model = nnx.eval_shape(lambda: Model())
+    assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding)
+    assert abs_model.linear.kernel.sharding.mesh is mesh
+    assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b")
+
   def test_explicit_sharding(self):
     mesh = jax.make_mesh(
       (2, 2),
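
Note (not part of the patch): a minimal usage sketch of the behavior this change enables, adapted from test_eval_shape_with_sharding1 above. It assumes the patch is applied, a runtime with at least four devices, a JAX version that provides jax.make_mesh, jax.set_mesh, jax.sharding.AxisType, and jax.P, and the kernel_metadata argument exactly as used in the tests.

# Sketch only: mirrors the tests above; names and shapes are illustrative.
import jax
from flax import nnx

class Model(nnx.Module):
  def __init__(self):
    # Annotate the kernel with logical axis names ("a", "b").
    self.linear = nnx.Linear(
      10, 10, rngs=nnx.Rngs(0),
      kernel_metadata={"sharding": ("a", "b")},
    )

mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
with jax.set_mesh(mesh):
  # With the patch, the abstract kernel returned by nnx.eval_shape carries a
  # NamedSharding built from the current mesh and the variable's PartitionSpec.
  abs_model = nnx.eval_shape(lambda: Model())

print(abs_model.linear.kernel.sharding)  # expected: NamedSharding(mesh, P("a", "b"))

Per the transforms.py hunk, a per-variable "mesh" metadata entry takes precedence; the mesh set via jax.set_mesh is only used as a fallback when the variable carries no mesh of its own.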