diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 08eafedb6792..254aee3e8e28 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -2257,9 +2257,13 @@ def _wrap_with_spmd_op(name: str,
                        ctx: LoweringRuleContext,
                        x: ir.Value,
                        aval_out: core.AbstractValue,
-                       sharding_proto: xc.OpSharding,
+                       sharding: xc.OpSharding | sharding.SdyArraySharding,
                        unspecified_dims: set[int] | None = None,
-                       has_side_effect: bool = False):
+                       has_side_effect: bool = False,
+                       allow_shardy_lowering: bool = False):
+  if config.use_shardy_partitioner.value and allow_shardy_lowering:
+    return dialects.sdy.ShardingConstraintOp(x, sharding.build()).result  # type: ignore
+
   # unspecified_dims indicate dimensions whose shardings are not specified and
   # XLA sharding propagation can change them.
   if unspecified_dims:
@@ -2280,11 +2284,12 @@ def _wrap_with_spmd_op(name: str,
       api_version=1,
       result_shapes=result_shapes,
       has_side_effect=has_side_effect)
-  set_sharding(op, sharding_proto)
+  set_sharding(op, sharding)
   return op.result
 
 
-wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
+wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding",
+                                allow_shardy_lowering=True)
 wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
 wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 648ab3168d01..169f98433b1b 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2447,6 +2447,8 @@ def with_sharding_constraint(x, shardings):
   shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings',
                                                'with_sharding_constraint')
                     for a in user_shardings_flat]
+  # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
+  # already part of the shardings.
   unconstrained_dims = [get_unconstrained_dims(s)
                         if isinstance(s, NamedSharding) else {}
                         for s in shardings_flat]
@@ -2496,9 +2498,12 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
   if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and
       axis_ctx.manual_axes):
     sharding = mlir.add_manual_axes(axis_ctx, sharding, aval.ndim)
+  if config.use_shardy_partitioner.value:
+    sharding = sharding._to_sdy_sharding(aval.ndim)
+  else:
+    sharding = sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
   out = mlir.wrap_with_sharding_op(
-      ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
-      unspecified_dims=unconstrained_dims)
+      ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims)
   if layout is not None:
     out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval)
   return [out]
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 97ba67bb9c2e..dae3265a5425 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4967,6 +4967,29 @@ def f(x):
     self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())
 
+  def test_lowering_with_sharding_constraint(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    arr = np.arange(16).reshape(4, 2, 2)
+
+    @jax.jit
+    def f(x):
+      return jax.lax.with_sharding_constraint(
+          x, NamedSharding(mesh, P('x', None, 'y')))
+    lowered_str = jax.jit(f).lower(arr).as_text()
+    self.assertIn('sdy.sharding_constraint', lowered_str)
+    self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str)
+
+  def test_lowering_with_sharding_constraint_unconstrained(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    arr = np.arange(16).reshape(4, 2, 2)
+
+    @jax.jit
+    def f(x):
+      return jax.lax.with_sharding_constraint(
+          x, NamedSharding(mesh, P('x', P.UNCONSTRAINED, 'y')))
+    lowered_str = f.lower(arr).as_text()
+    self.assertIn('sdy.sharding_constraint', lowered_str)
+    self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str)
 
 if __name__ == '__main__':
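Not part of the diff: a minimal standalone sketch of exercising the new lowering path, modeled on the tests above. It assumes a jax build that exposes the config referenced as config.use_shardy_partitioner under the flag name 'jax_use_shardy_partitioner', and at least four devices for the 2x2 mesh; the flag name and the expected 'sdy.sharding_constraint' text are taken from the change, not from a stable public API.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: the boolean config referenced as config.use_shardy_partitioner
# in the diff is exposed under the name 'jax_use_shardy_partitioner'.
jax.config.update('jax_use_shardy_partitioner', True)

# Build a 2x2 device mesh (requires at least four devices, e.g. CPU with
# --xla_force_host_platform_device_count=4).
devices = np.asarray(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))

@jax.jit
def f(x):
  # With Shardy enabled, this constraint should lower to an
  # sdy.sharding_constraint op instead of a @Sharding custom call.
  return jax.lax.with_sharding_constraint(
      x, NamedSharding(mesh, P('x', None, 'y')))

arr = np.arange(16).reshape(4, 2, 2)
print(f.lower(arr).as_text())  # look for 'sdy.sharding_constraint' in the module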