Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,21 @@ def to_value_metadata(x):


def _to_variable(node):
# import here to avoid circular imports
from flax.nnx.spmd import get_var_pspec

def to_variable(x):
  """Turn a ``ValueMetadata`` leaf back into a variable; pass others through.

  When a mesh is available, the rebuilt variable's value is replaced by a
  ``jax.ShapeDtypeStruct`` that carries a ``NamedSharding`` built from that
  mesh and the variable's partition spec.
  """
  if not isinstance(x, ValueMetadata):
    return x

  variable = x.var_type._new(x.value, x.metadata)

  # A context mesh with no axes is treated the same as "no mesh set".
  ambient_mesh = jax.sharding.get_mesh()
  fallback_mesh = None if ambient_mesh.axis_sizes == () else ambient_mesh

  # A mesh stored in the variable's own metadata wins over the context mesh.
  target_mesh = variable.get_metadata("mesh", None) or fallback_mesh
  if target_mesh is None:
    return variable

  partition_spec = get_var_pspec(variable)
  named_sharding = jax.sharding.NamedSharding(
    mesh=target_mesh, spec=partition_spec
  )
  abstract_value = jax.ShapeDtypeStruct(
    shape=variable.shape, dtype=variable.dtype, sharding=named_sharding
  )
  variable.set_value(abstract_value)
  return variable

Expand Down
30 changes: 30 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,36 @@ def __init__(self, rngs):
assert jax.tree.leaves(abs_state)[0].sharding.is_equivalent_to(
NamedSharding(mesh, P(None, 'model')), ndim=2)

def test_eval_shape_with_sharding0(self):
  # based on https://github.com/google/flax/issues/5110
  # Each kernel names its own mesh in its metadata; the abstract model
  # returned by nnx.eval_shape is expected to report that mesh and spec.
  auto_axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)
  mesh1 = jax.make_mesh((2, 2), ("a", "b"), auto_axis_types)
  mesh2 = jax.make_mesh((1, 4), ("c", "d"), auto_axis_types)

  class Model(nnx.Module):
    def __init__(self):
      self.p1 = nnx.Linear(
        16,
        16,
        rngs=nnx.Rngs(0),
        kernel_metadata={"sharding": ("a", "b"), "mesh": mesh1},
      )
      self.p2 = nnx.Linear(
        16,
        16,
        rngs=nnx.Rngs(0),
        kernel_metadata={"sharding": ("c", "d"), "mesh": mesh2},
      )

  abs_model = nnx.eval_shape(lambda: Model())

  expected_by_layer = (
    ("p1", mesh1, jax.P("a", "b")),
    ("p2", mesh2, jax.P("c", "d")),
  )
  for layer_name, expected_mesh, expected_spec in expected_by_layer:
    kernel_sharding = getattr(abs_model, layer_name).kernel.sharding
    assert isinstance(kernel_sharding, jax.sharding.NamedSharding)
    assert kernel_sharding.mesh is expected_mesh
    assert kernel_sharding.spec == expected_spec

def test_eval_shape_with_sharding1(self):
  # The kernel only names its sharding axes; the mesh is supplied through
  # the jax.set_mesh context rather than through variable metadata.
  class Model(nnx.Module):
    def __init__(self):
      self.linear = nnx.Linear(
        10,
        10,
        rngs=nnx.Rngs(0),
        kernel_metadata={"sharding": ("a", "b")},
      )

  auto_axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)
  mesh = jax.make_mesh((2, 2), ("a", "b"), auto_axis_types)
  with jax.set_mesh(mesh):
    abs_model = nnx.eval_shape(lambda: Model())
    kernel_sharding = abs_model.linear.kernel.sharding
    assert isinstance(kernel_sharding, jax.sharding.NamedSharding)
    assert kernel_sharding.mesh is mesh
    assert kernel_sharding.spec == jax.P("a", "b")

def test_explicit_sharding(self):
mesh = jax.make_mesh(
(2, 2),
Expand Down
Loading