[sharding_in_types] Add lax.transpose sharding propagation rule
PiperOrigin-RevId: 687057762
yashk2810 authored and Google-ML-Automation committed Oct 17, 2024
1 parent e92e119 commit 1aeb146
Showing 4 changed files with 161 additions and 32 deletions.
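For orientation, the sketch below (not part of the commit) shows the end-to-end behavior the new rule supports: with the experimental sharding-in-types mode enabled, transposing a sharded array carries the input PartitionSpec through the permutation. The flag name and the 4-device mesh are assumptions, not taken from the diff.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: the experimental flag is spelled 'jax_sharding_in_types'.
jax.config.update('jax_sharding_in_types', True)

# Assumption: at least four devices are available
# (e.g. CPU with --xla_force_host_platform_device_count=4).
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(a):
  return jnp.transpose(a, (1, 0))

out = f(x)
print(out.sharding.spec)  # expected: PartitionSpec('y', 'x')
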
jax/_src/lax/lax.py: 74 changes (45 additions, 29 deletions)
@@ -1995,11 +1995,14 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):

def unop(result_dtype, accepted_dtypes, name):
dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
prim = standard_primitive(_attrgetter('shape'), dtype_rule, name)
prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
sharding_rule=_attrgetter('sharding'))
batching.defvectorized(prim)
pe.def_trivial_padding(prim)
return prim

standard_unop = partial(unop, _identity)

_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
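Aside (not part of the diff): with sharding_rule=_attrgetter('sharding'), elementwise unops simply forward the operand's sharding, exactly as the shape rule forwards its shape. A toy illustration with a stand-in object, not JAX internals:

class FakeAval:
  def __init__(self, shape, sharding):
    self.shape = shape
    self.sharding = sharding

_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)

aval = FakeAval(shape=(8, 2), sharding="NamedSharding(mesh, P('x', 'y'))")
print(_attrgetter('shape')(aval))     # (8, 2)   -- what the shape rule returns
print(_attrgetter('sharding')(aval))  # the operand's sharding, passed through unchanged
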


@@ -2584,7 +2587,8 @@ def _integer_pow_jvp(g, x, *, y):
return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))

integer_pow_p = standard_primitive(
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
sharding_rule=_attrgetter('sharding'))
batching.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
pe.def_trivial_padding(integer_pow_p)
@@ -2611,17 +2615,23 @@ def _integer_pow_lowering(ctx, x, *, y):
# These cases are subsumed by the general case, but it's faster to emit these
# common cases directly.
if y == 2:
return (hlo.multiply(x, x),)
out = hlo.multiply(x, x)
elif y == 3:
return (hlo.multiply(hlo.multiply(x, x), x),)
out = hlo.multiply(hlo.multiply(x, x), x)
else:
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
# TODO(b/217551391): emitting an out-of-line call leads to a large
# expansion when the MLIR is lowered to HLO, because the HLO lowering
# clones the callee. Consider unconditionally caching when the MLIR->HLO
# lowering doesn't expand the program.
lowering = mlir.cache_lowering(lowering)
return lowering(ctx, x, y=y)
out = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
out = out[0] if isinstance(out, list) else out
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return out if isinstance(out, list) else [out]

mlir.register_lowering(integer_pow_p, _integer_pow_lowering)

@@ -4536,6 +4546,11 @@ def _transpose_shape_rule(operand, *, permutation):
raise TypeError(msg.format(permutation, operand.shape))
return tuple(operand.shape[old_idx] for old_idx in permutation)

def _transpose_sharding_rule(operand, *, permutation):
o_spec = operand.sharding.spec
new_spec = [o_spec[old_idx] for old_idx in permutation]
return NamedSharding(operand.sharding.mesh, P(*new_spec))
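Illustration (not part of the diff): the rule permutes the operand's PartitionSpec with the same permutation as the dimensions. PartitionSpec indexes like a tuple, so the comprehension above behaves like this:

from jax.sharding import PartitionSpec as P

spec = P('x', 'y', None)                       # spec of a rank-3 operand
permutation = (1, 2, 0)
new_spec = P(*[spec[old_idx] for old_idx in permutation])
print(new_spec)                                # PartitionSpec('y', None, 'x')
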

def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
operand, = batched_args
bdim, = batch_dims
@@ -4553,10 +4568,15 @@ def _transpose_lower(ctx, x, *, permutation):
elt_shape = core.physical_element_aval(aval_out.dtype).shape
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
permutation = [*permutation, *trailing_dims]
return [hlo.transpose(x, mlir.dense_int_array(permutation))]
out = hlo.transpose(x, mlir.dense_int_array(permutation))
if config.sharding_in_types.value:
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return [out]

transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
transpose_p = standard_primitive(
_transpose_shape_rule, _input_dtype, 'transpose',
sharding_rule=_transpose_sharding_rule)
ad.deflinear2(transpose_p,
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
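Aside (not part of the diff): the cotangent rule registered with ad.deflinear2 above transposes back with the inverse permutation, which np.argsort yields when applied to a permutation:

import numpy as np

perm = (1, 2, 0)
inv = tuple(np.argsort(perm))
print(inv)   # (2, 0, 1): transposing by perm and then by inv restores the original layout
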
@@ -4846,15 +4866,6 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
"of number.".format(name, dtype_to_string(operand.dtype)))
return dtypes.canonicalize_dtype(operand.dtype)

def _reduce_sum_shape_rule(operand, *, axes):
return _reduce_op_shape_rule(operand, axes=axes)

def _reduce_sum_sharding_rule(operand, *, axes):
axes = frozenset(axes)
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
if i not in axes))
return NamedSharding(operand.sharding.mesh, new_spec)

def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
@@ -4877,16 +4888,6 @@ def _replace_masked_values(x, val, padded_axes):
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))


reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)


def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
del input_shape # Unused.
if len(axes) != len(set(axes)):
@@ -4896,6 +4897,20 @@ def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
axes = frozenset(axes)
return tuple(d for i, d in enumerate(operand.shape) if i not in axes)

def _reduce_op_sharding_rule(operand, *, axes):
axes = frozenset(axes)
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
if i not in axes))
return NamedSharding(operand.sharding.mesh, new_spec)
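Illustration (not part of the diff): reducing over an axis drops that axis's entry from the spec, so any mesh axes it was partitioned over vanish from the output sharding (the compiler then inserts an all-reduce, which the tests below check for):

from jax.sharding import PartitionSpec as P

spec = P('x', 'y')                 # operand sharded over both dimensions
axes = frozenset({1})              # reduce over dimension 1
new_spec = P(*[s for i, s in enumerate(spec) if i not in axes])
print(new_spec)                    # PartitionSpec('x',)
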

reduce_sum_p = standard_primitive(
_reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)

def _reduce_prod_jvp_rule(primals, tangents, *, axes):
reducer = lambda x, y: [mul(x, y)]
primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
@@ -4922,8 +4937,9 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
return div(_reduce_sum(mul(g, location_indicators), axes), counts)


reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_max')
reduce_max_p = standard_primitive(
_reduce_op_shape_rule, _input_dtype, 'reduce_max',
sharding_rule=_reduce_op_sharding_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
jax/_src/numpy/lax_numpy.py: 6 changes (5 additions, 1 deletion)
@@ -5299,7 +5299,11 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# whenever x is weak, but avoids introducing weak types with something like
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)
sharding = canonicalize_device_to_sharding(device)
if (config.sharding_in_types.value and device is None and
isinstance(object, Array)):
sharding = object.sharding
else:
sharding = canonicalize_device_to_sharding(device)

# Use device_put to avoid a copy for ndarray inputs.
if (not copy and isinstance(object, np.ndarray) and
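Sketch (not part of the diff) of the lax_numpy.py change: jnp.array on an already-sharded Array, with no device= argument, now keeps the input's sharding rather than falling back to an unspecified one. As above, the flag name and the 4-device mesh are assumptions.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)   # assumed flag name
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16, dtype=np.int32).reshape(8, 2),
                   NamedSharding(mesh, P('x', 'y')))

y = jax.jit(lambda a: jnp.array(a, dtype=jnp.float32))(x)
print(y.dtype, y.sharding.spec)   # float32 PartitionSpec('x', 'y')
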
jax/_src/numpy/util.py: 4 changes (3 additions, 1 deletion)
@@ -251,7 +251,9 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]:
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
return [lax._convert_element_type(x, to_dtype, weak_type,
getattr(x, "sharding", None))
for x in args]
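Aside (not part of the diff): the getattr fallback lets non-Array inputs (NumPy arrays, Python scalars) pass through with no sharding, while JAX Arrays forward theirs:

import numpy as np
import jax.numpy as jnp

for x in (np.ones((4,), np.float32), jnp.ones((4,), jnp.float32)):
  print(type(x).__name__, getattr(x, "sharding", None))
# ndarray None                             -- no sharding constraint is forwarded
# ArrayImpl SingleDeviceSharding(...)      -- the Array's own sharding is forwarded
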


def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
tests/pjit_test.py: 109 changes (108 additions, 1 deletion)
@@ -4420,7 +4420,6 @@ def f():
return jnp.array(inp, dtype=np.float32, device=s)

out = f()
print(f.trace().jaxpr)
self.assertArraysEqual(out, inp.astype('float32'))
self.assertEqual(out.sharding, s)
self.assertEqual(out.dtype, np.float32)
@@ -4784,6 +4783,37 @@ def f(x):
if reduce and compiled_text is not None:
self.assertIn('all-reduce', compiled_text)

@parameterized.named_parameters(
('all', None, P('x', 'y'), P()),
('first', 0, P('x', 'y'), P('y')),
('second', 1, P('x', 'y'), P('x')),
('first2', 0, P(('x', 'y'), None), P(None)),
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
)
def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, in_spec)
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
y = jnp.max(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
return y

out = f(arr)
self.assertArraysEqual(out, np.max(np_inp, axis=axis))
self.assertEqual(out.aval.sharding.spec, out_spec)

lowered = f.lower(arr)
self.assertIn('@Sharding', lowered.as_text())

compiled_text = lowered.compile().as_text()
if reduce and compiled_text is not None:
self.assertIn('all-reduce', compiled_text)

@parameterized.named_parameters(
('0', 0, P(None, 'x', 'y')),
('1', 1, P('x', None, 'y')),
@@ -4811,6 +4841,83 @@ def f(x):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)

@parameterized.named_parameters(
('2', 2),
('3', 3),
('4', 4),
)
def test_integer_pow(self, pow):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = x ** pow
self.assertEqual(y.sharding.spec, s.spec)
return y

out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, np_inp ** pow)

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)

def test_sin_unop(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = lax.sin(x)
self.assertEqual(y.sharding.spec, s.spec)
return y

out = f(arr)
self.assertEqual(out.sharding, s)

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)

def test_jnp_array(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
assert x.dtype == jnp.int32
y = jnp.array(x, dtype=jnp.float32)
self.assertEqual(y.dtype, jnp.float32)
self.assertEqual(y.sharding.spec, s.spec)
return y

f(arr)

def test_lax_transpose_rule(self):
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
np_inp = np.arange(16).reshape(4, 2, 2)
s = NamedSharding(mesh, P('x', 'y', 'z'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = jnp.transpose(x, (1, 2, 0))
self.assertEqual(y.sharding.spec, P('y', 'z', 'x'))
return y

out = f(arr)
self.assertArraysEqual(out, np.transpose(arr, (1, 2, 0)))
self.assertEqual(out.aval.sharding.spec, P('y', 'z', 'x'))

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
