@@ -269,6 +269,36 @@ def __init__(self, rngs):
269269 assert jax .tree .leaves (abs_state )[0 ].sharding .is_equivalent_to (
270270 NamedSharding (mesh , P (None , 'model' )), ndim = 2 )
271271
def test_eval_shape_with_sharding0(self):
    """Check that ``nnx.eval_shape`` keeps a per-parameter mesh.

    Two ``nnx.Linear`` layers are annotated through ``kernel_metadata`` with
    different meshes and axis names. Each abstract kernel must expose a
    ``NamedSharding`` tied to its own mesh and partition spec.
    """
    # based on https://github.com/google/flax/issues/5110
    auto_axes = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)
    mesh_ab = jax.make_mesh((2, 2), ("a", "b"), auto_axes)
    mesh_cd = jax.make_mesh((1, 4), ("c", "d"), auto_axes)

    class Model(nnx.Module):
        def __init__(self):
            self.p1 = nnx.Linear(
                16,
                16,
                rngs=nnx.Rngs(0),
                kernel_metadata={"sharding": ("a", "b"), "mesh": mesh_ab},
            )
            self.p2 = nnx.Linear(
                16,
                16,
                rngs=nnx.Rngs(0),
                kernel_metadata={"sharding": ("c", "d"), "mesh": mesh_cd},
            )

    def build():
        return Model()

    abstract = nnx.eval_shape(build)

    # (attribute name, expected mesh, expected axis names), checked in order.
    expectations = (
        ("p1", mesh_ab, ("a", "b")),
        ("p2", mesh_cd, ("c", "d")),
    )
    for layer_name, expected_mesh, axis_names in expectations:
        sharding = getattr(abstract, layer_name).kernel.sharding
        assert isinstance(sharding, jax.sharding.NamedSharding)
        # Identity, not equality: the sharding must carry the very mesh
        # object that was passed in through the metadata.
        assert sharding.mesh is expected_mesh
        assert sharding.spec == jax.P(*axis_names)
289+
def test_eval_shape_with_sharding1(self):
    """Check that ``nnx.eval_shape`` picks up the mesh set via ``jax.set_mesh``.

    The layer's ``kernel_metadata`` names only the sharding axes, with no
    explicit mesh. The abstract kernel must still report a ``NamedSharding``
    bound to the mesh that was active while the model was built.
    """
    auto_axes = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)
    active_mesh = jax.make_mesh((2, 2), ("a", "b"), auto_axes)

    class Model(nnx.Module):
        def __init__(self):
            self.linear = nnx.Linear(
                10,
                10,
                rngs=nnx.Rngs(0),
                kernel_metadata={"sharding": ("a", "b")},
            )

    def build():
        return Model()

    with jax.set_mesh(active_mesh):
        abstract = nnx.eval_shape(build)

    kernel_sharding = abstract.linear.kernel.sharding
    assert isinstance(kernel_sharding, jax.sharding.NamedSharding)
    assert kernel_sharding.mesh is active_mesh
    assert kernel_sharding.spec == jax.P("a", "b")
301+
272302 def test_explicit_sharding (self ):
273303 mesh = jax .make_mesh (
274304 (2 , 2 ),
0 commit comments