[sharding_in_types] Add lax.transpose sharding propagation rule
PiperOrigin-RevId: 687094297
yashk2810 authored and Google-ML-Automation committed Oct 18, 2024
1 parent 57a95a7 commit 3e634d9
Showing 2 changed files with 32 additions and 3 deletions.
jax/_src/lax/lax.py (13 additions, 3 deletions)

@@ -4546,6 +4546,11 @@ def _transpose_shape_rule(operand, *, permutation):
     raise TypeError(msg.format(permutation, operand.shape))
   return tuple(operand.shape[old_idx] for old_idx in permutation)
 
+def _transpose_sharding_rule(operand, *, permutation):
+  o_spec = operand.sharding.spec
+  new_spec = [o_spec[old_idx] for old_idx in permutation]
+  return NamedSharding(operand.sharding.mesh, P(*new_spec))
+
 def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
   operand, = batched_args
   bdim, = batch_dims
@@ -4563,10 +4568,15 @@ def _transpose_lower(ctx, x, *, permutation):
   elt_shape = core.physical_element_aval(aval_out.dtype).shape
   trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
   permutation = [*permutation, *trailing_dims]
-  return [hlo.transpose(x, mlir.dense_int_array(permutation))]
+  out = hlo.transpose(x, mlir.dense_int_array(permutation))
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
-transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
-                                 'transpose')
+transpose_p = standard_primitive(
+    _transpose_shape_rule, _input_dtype, 'transpose',
+    sharding_rule=_transpose_sharding_rule)
 ad.deflinear2(transpose_p,
               lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
 batching.primitive_batchers[transpose_p] = _transpose_batch_rule
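
The new sharding rule mirrors the shape rule directly above it: output dimension i is fed by input dimension permutation[i], so it inherits that dimension's mesh axis. A minimal self-contained sketch of the same logic outside JAX internals (permute_spec is a hypothetical name used only for illustration; spec entries are mesh-axis names or None):

def permute_spec(spec, permutation):
  # Output dim i comes from input dim permutation[i], so it also
  # carries that input dim's mesh axis (None means unsharded).
  return tuple(spec[old_idx] for old_idx in permutation)

assert permute_spec(('x', 'y', 'z'), (1, 2, 0)) == ('y', 'z', 'x')
assert permute_spec(('x', None), (1, 0)) == (None, 'x')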
tests/pjit_test.py (19 additions, 0 deletions)

@@ -4899,6 +4899,25 @@ def f(x):
 
     f(arr)
 
+  def test_lax_transpose_rule(self):
+    mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
+    np_inp = np.arange(16).reshape(4, 2, 2)
+    s = NamedSharding(mesh, P('x', 'y', 'z'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = jnp.transpose(x, (1, 2, 0))
+      self.assertEqual(y.sharding.spec, P('y', 'z', 'x'))
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.transpose(arr, (1, 2, 0)))
+    self.assertEqual(out.aval.sharding.spec, P('y', 'z', 'x'))
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
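
For context, the behavior the test exercises can be reproduced end to end. A sketch, assuming at least four devices and that the config flag is spelled 'jax_sharding_in_types' (inferred from the config.sharding_in_types check in the lowering change above, so treat the flag name as an assumption):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed flag name

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))
arr = jax.device_put(np.arange(16).reshape(4, 4),
                     NamedSharding(mesh, P('x', 'y')))

f = jax.jit(lambda x: jnp.transpose(x, (1, 0)))
out = f(arr)
print(out.sharding.spec)  # expected: PartitionSpec('y', 'x')

# The new lowering wraps the transpose in a sharding annotation,
# visible in the StableHLO as a custom call named @Sharding.
print('@Sharding' in f.lower(arr).as_text())  # expected: True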
