[sharding_in_types] Add reduce max, integer_pow and standard_unop sharding rules

PiperOrigin-RevId: 687073144
yashk2810 authored and Google-ML-Automation committed Oct 17, 2024
1 parent e92e119 commit 5df4878
Showing 2 changed files with 105 additions and 26 deletions.
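
The sharding rules added here are trace-time rules: alongside a primitive's shape and dtype rules, they compute the output sharding from the operand's sharding. Below is a minimal, self-contained sketch of the two rule shapes this commit adds (the helper names are illustrative, not JAX's actual dispatch machinery): element-wise ops forward the operand's spec unchanged, while reductions drop the reduced axes from the spec.

from jax.sharding import PartitionSpec as P

def elementwise_sharding_rule(in_spec):
  # Element-wise ops (sin, integer_pow, ...) keep the operand's layout:
  # the output spec is the input spec, unchanged.
  return in_spec

def reduce_sharding_rule(in_spec, axes):
  # Reductions drop the reduced dimensions, so their entries disappear
  # from the spec (mirrors _reduce_op_sharding_rule in the diff below).
  axes = frozenset(axes)
  return P(*(s for i, s in enumerate(in_spec) if i not in axes))

print(elementwise_sharding_rule(P('x', 'y')))                # spec is preserved
print(reduce_sharding_rule(P('x', 'y'), axes=(0,)))          # only 'y' survives
print(reduce_sharding_rule(P(('x', 'y'), None), axes=(1,)))  # (('x', 'y'),) survives
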
58 changes: 32 additions & 26 deletions jax/_src/lax/lax.py
@@ -1995,11 +1995,14 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):

def unop(result_dtype, accepted_dtypes, name):
dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
prim = standard_primitive(_attrgetter('shape'), dtype_rule, name)
prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
sharding_rule=_attrgetter('sharding'))
batching.defvectorized(prim)
pe.def_trivial_padding(prim)
return prim

standard_unop = partial(unop, _identity)

_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)


@@ -2584,7 +2587,8 @@ def _integer_pow_jvp(g, x, *, y):
return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))

integer_pow_p = standard_primitive(
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
_attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
sharding_rule=_attrgetter('sharding'))
batching.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
pe.def_trivial_padding(integer_pow_p)
@@ -2611,17 +2615,23 @@ def _integer_pow_lowering(ctx, x, *, y):
# These cases are subsumed by the general case, but it's faster to emit these
# common cases directly.
if y == 2:
return (hlo.multiply(x, x),)
out = hlo.multiply(x, x)
elif y == 3:
return (hlo.multiply(hlo.multiply(x, x), x),)
out = hlo.multiply(hlo.multiply(x, x), x)
else:
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
# TODO(b/217551391): emitting an out-of-line call leads to a large
# expansion when the MLIR is lowered to HLO, because the HLO lowering
# clones the callee. Consider unconditionally caching when the MLIR->HLO
# lowering doesn't expand the program.
lowering = mlir.cache_lowering(lowering)
return lowering(ctx, x, y=y)
out = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
out = out[0] if isinstance(out, list) else out
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
return out if isinstance(out, list) else [out]

mlir.register_lowering(integer_pow_p, _integer_pow_lowering)

@@ -4846,15 +4856,6 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
"of number.".format(name, dtype_to_string(operand.dtype)))
return dtypes.canonicalize_dtype(operand.dtype)

def _reduce_sum_shape_rule(operand, *, axes):
return _reduce_op_shape_rule(operand, axes=axes)

def _reduce_sum_sharding_rule(operand, *, axes):
axes = frozenset(axes)
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
if i not in axes))
return NamedSharding(operand.sharding.mesh, new_spec)

def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
@@ -4877,16 +4878,6 @@ def _replace_masked_values(x, val, padded_axes):
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))


reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)


def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
del input_shape # Unused.
if len(axes) != len(set(axes)):
@@ -4896,6 +4887,20 @@ def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
axes = frozenset(axes)
return tuple(d for i, d in enumerate(operand.shape) if i not in axes)

def _reduce_op_sharding_rule(operand, *, axes):
axes = frozenset(axes)
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
if i not in axes))
return NamedSharding(operand.sharding.mesh, new_spec)

reduce_sum_p = standard_primitive(
_reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)

def _reduce_prod_jvp_rule(primals, tangents, *, axes):
reducer = lambda x, y: [mul(x, y)]
primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
@@ -4922,8 +4927,9 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
return div(_reduce_sum(mul(g, location_indicators), axes), counts)


reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_max')
reduce_max_p = standard_primitive(
_reduce_op_shape_rule, _input_dtype, 'reduce_max',
sharding_rule=_reduce_op_sharding_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
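For the unary ops above, the new sharding rule is simply _attrgetter('sharding'): the same tiny helper that already serves as the shape rule, pointed at a different attribute of the single operand aval. A small illustration with a stand-in aval (the Aval namedtuple below is hypothetical; the real rules receive a JAX ShapedArray):

from collections import namedtuple

_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)

# Stand-in operand abstract value, just for illustration.
Aval = namedtuple('Aval', ['shape', 'sharding'])
aval = Aval(shape=(8, 2), sharding="NamedSharding(mesh, P('x', 'y'))")

shape_rule = _attrgetter('shape')        # passed to standard_primitive as the shape rule
sharding_rule = _attrgetter('sharding')  # passed as the sharding rule
print(shape_rule(aval))      # (8, 2)
print(sharding_rule(aval))   # the operand's sharding, forwarded unchanged
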
73 changes: 73 additions & 0 deletions tests/pjit_test.py
@@ -4784,6 +4784,37 @@ def f(x):
if reduce and compiled_text is not None:
self.assertIn('all-reduce', compiled_text)

@parameterized.named_parameters(
('all', None, P('x', 'y'), P()),
('first', 0, P('x', 'y'), P('y')),
('second', 1, P('x', 'y'), P('x')),
('first2', 0, P(('x', 'y'), None), P(None)),
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
)
def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, in_spec)
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
y = jnp.max(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
return y

out = f(arr)
self.assertArraysEqual(out, np.max(np_inp, axis=axis))
self.assertEqual(out.aval.sharding.spec, out_spec)

lowered = f.lower(arr)
self.assertIn('@Sharding', lowered.as_text())

compiled_text = lowered.compile().as_text()
if reduce and compiled_text is not None:
self.assertIn('all-reduce', compiled_text)

@parameterized.named_parameters(
('0', 0, P(None, 'x', 'y')),
('1', 1, P('x', None, 'y')),
@@ -4811,6 +4842,48 @@ def f(x):
lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)

@parameterized.named_parameters(
('2', 2),
('3', 3),
('4', 4),
)
def test_integer_pow(self, pow):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = x ** pow
self.assertEqual(y.sharding.spec, s.spec)
return y

out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, np_inp ** pow)

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)

def test_sin_unop(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)

@jax.jit
def f(x):
y = lax.sin(x)
self.assertEqual(y.sharding.spec, s.spec)
return y

out = f(arr)
self.assertEqual(out.sharding, s)

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
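The new tests exercise these rules end to end: the lowered StableHLO carries @Sharding annotations, and the compiled module contains an all-reduce only when the reduction removes a dimension that was actually sharded (the 'second2' case reduces an unsharded dimension, so no collective is expected). A user-level sketch of the same checks, assuming four devices are available (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU) and assuming the experimental flag is exposed as jax_sharding_in_types (the flag name may change):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed flag name; experimental

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))

f = jax.jit(lambda a: jnp.max(a, axis=0))   # reduces the dimension sharded on 'x'
print('@Sharding' in f.lower(x).as_text())  # sharding annotations present

hlo = f.lower(x).compile().as_text()
if hlo is not None:             # some backends do not expose compiled text
  print('all-reduce' in hlo)    # cross-device max is required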
