Skip to content

Commit

Permalink
#sdy Support with_sharding_constraint lowering through Shardy.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655905063
  • Loading branch information
bartchr808 authored and jax authors committed Jul 25, 2024
1 parent f15f971 commit b00f978
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
13 changes: 9 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2257,9 +2257,13 @@ def _wrap_with_spmd_op(name: str,
ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
sharding_proto: xc.OpSharding,
sharding: xc.OpSharding | sharding.SdyArraySharding,
unspecified_dims: set[int] | None = None,
has_side_effect: bool = False):
has_side_effect: bool = False,
allow_shardy_lowering: bool = False):
if config.use_shardy_partitioner.value and allow_shardy_lowering:
return dialects.sdy.ShardingConstraintOp(x, sharding.build()).result # type: ignore

# unspecified_dims indicate dimensions whose shardings are not specified and
# XLA sharding propagation can change them.
if unspecified_dims:
Expand All @@ -2280,11 +2284,12 @@ def _wrap_with_spmd_op(name: str,
api_version=1,
result_shapes=result_shapes,
has_side_effect=has_side_effect)
set_sharding(op, sharding_proto)
set_sharding(op, sharding)
return op.result


wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding",
allow_shardy_lowering=True)
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")

Expand Down
9 changes: 7 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2447,6 +2447,8 @@ def with_sharding_constraint(x, shardings):
shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings',
'with_sharding_constraint')
for a in user_shardings_flat]
# TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
# already part of the shardings.
unconstrained_dims = [get_unconstrained_dims(s)
if isinstance(s, NamedSharding) else {}
for s in shardings_flat]
Expand Down Expand Up @@ -2496,9 +2498,12 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and
axis_ctx.manual_axes):
sharding = mlir.add_manual_axes(axis_ctx, sharding, aval.ndim)
if config.use_shardy_partitioner.value:
sharding = sharding._to_sdy_sharding(aval.ndim)
else:
sharding = sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
out = mlir.wrap_with_sharding_op(
ctx, x_node, out_aval, sharding._to_xla_hlo_sharding(aval.ndim).to_proto(),
unspecified_dims=unconstrained_dims)
ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims)
if layout is not None:
out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval)
return [out]
Expand Down
23 changes: 23 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4967,6 +4967,29 @@ def f(x):

self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())

def test_lowering_with_sharding_constraint(self):
  """Checks that `with_sharding_constraint` lowers to `sdy.sharding_constraint`.

  With the Shardy partitioner enabled, a fully specified NamedSharding on a
  rank-3 array should produce an sdy sharding constraint op whose per-dim
  axis bindings are `[{"x"}, {}, {"y"}]` (the `None` dim lowers to `{}`).
  """
  mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  arr = np.arange(16).reshape(4, 2, 2)

  @jax.jit
  def f(x):
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P('x', None, 'y')))

  # `f` is already jitted; lower it directly rather than re-wrapping it in
  # `jax.jit` (the previous `jax.jit(f)` was a redundant double wrap and was
  # inconsistent with the unconstrained-dims test below).
  lowered_str = f.lower(arr).as_text()
  self.assertIn('sdy.sharding_constraint', lowered_str)
  self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str)

def test_lowering_with_sharding_constraint_unconstrained(self):
  """Checks that `P.UNCONSTRAINED` dims lower to `{?}` in the sdy constraint.

  A NamedSharding with an unconstrained middle dimension should produce an
  `sdy.sharding_constraint` op whose axis bindings are `[{"x"}, {?}, {"y"}]`.
  """
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  arr = np.arange(16).reshape(4, 2, 2)

  @jax.jit
  def f(x):
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P('x', P.UNCONSTRAINED, 'y')))

  mlir_text = f.lower(arr).as_text()
  for expected in ('sdy.sharding_constraint', '<@mesh, [{"x"}, {?}, {"y"}]>'):
    self.assertIn(expected, mlir_text)


if __name__ == '__main__':
Expand Down

0 comments on commit b00f978

Please sign in to comment.