Skip to content

Commit f717ff6

Browse files
committed
Added sharding propagation support in nnx.eval_shape
1 parent 697f4e5 commit f717ff6

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

flax/nnx/transforms/transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,22 @@ def to_value_metadata(x):
231231

232232

233233
def _to_variable(node):
234+
# import here to avoid circular imports
235+
from flax.nnx.spmd import get_var_pspec
236+
234237
def to_variable(x):
235238
if isinstance(x, ValueMetadata):
236239
var = x.var_type._new(x.value, x.metadata)
240+
241+
# global_mesh = jax.sharding.get_concrete_mesh()
242+
global_mesh = jax._src.mesh.get_concrete_mesh()
243+
if global_mesh.axis_sizes == ():
244+
global_mesh = None
245+
mesh = var.get_metadata("mesh", None) or global_mesh
246+
if mesh is not None:
247+
pspec = get_var_pspec(var)
248+
sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec)
249+
var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding))
237250
return var
238251
return x
239252

tests/nnx/spmd_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,36 @@ def __init__(self, rngs):
269269
assert jax.tree.leaves(abs_state)[0].sharding.is_equivalent_to(
270270
NamedSharding(mesh, P(None, 'model')), ndim=2)
271271

272+
def test_eval_shape_with_sharding0(self):
273+
# based on https://github.com/google/flax/issues/5110
274+
mesh1 = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
275+
mesh2 = jax.make_mesh((1, 4), ("c", "d"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
276+
277+
class Model(nnx.Module):
278+
def __init__(self):
279+
self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("a", "b"), "mesh": mesh1})
280+
self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("c", "d"), "mesh": mesh2})
281+
282+
abs_model = nnx.eval_shape(lambda: Model())
283+
assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding)
284+
assert abs_model.p1.kernel.sharding.mesh is mesh1
285+
assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b")
286+
assert isinstance(abs_model.p2.kernel.sharding, jax.sharding.NamedSharding)
287+
assert abs_model.p2.kernel.sharding.mesh is mesh2
288+
assert abs_model.p2.kernel.sharding.spec == jax.P("c", "d")
289+
290+
def test_eval_shape_with_sharding1(self):
291+
class Model(nnx.Module):
292+
def __init__(self):
293+
self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0), kernel_metadata={"sharding": ("a", "b")})
294+
295+
mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
296+
with jax.set_mesh(mesh):
297+
abs_model = nnx.eval_shape(lambda: Model())
298+
assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding)
299+
assert abs_model.linear.kernel.sharding.mesh is mesh
300+
assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b")
301+
272302
def test_explicit_sharding(self):
273303
mesh = jax.make_mesh(
274304
(2, 2),

0 commit comments

Comments
 (0)