From 7fcaf221bda8c20c3172e96d8d5e1af6b98224eb Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 17 Oct 2024 13:55:53 -0700
Subject: [PATCH] [sharding_in_types] Add reduce max, integer_pow and
 standard_unop sharding rules

PiperOrigin-RevId: 687034219
---
 jax/_src/lax/lax.py | 58 +++++++++++++++++++----------------
 tests/pjit_test.py  | 73 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 26 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 00db56bf563a..06b47c6ca572 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1995,11 +1995,14 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):
 
 def unop(result_dtype, accepted_dtypes, name):
   dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
-  prim = standard_primitive(_attrgetter('shape'), dtype_rule, name)
+  prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
+                            sharding_rule=_attrgetter('sharding'))
   batching.defvectorized(prim)
   pe.def_trivial_padding(prim)
   return prim
+
 standard_unop = partial(unop, _identity)
+
 _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
 
 
@@ -2584,7 +2587,8 @@ def _integer_pow_jvp(g, x, *, y):
   return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))
 
 integer_pow_p = standard_primitive(
-    _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
+    _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
+    sharding_rule=_attrgetter('sharding'))
 batching.defvectorized(integer_pow_p)
 ad.defjvp(integer_pow_p, _integer_pow_jvp)
 pe.def_trivial_padding(integer_pow_p)
@@ -2611,9 +2615,9 @@ def _integer_pow_lowering(ctx, x, *, y):
   # These cases are subsumed by the general case, but it's faster to emit these
   # common cases directly.
   if y == 2:
-    return (hlo.multiply(x, x),)
+    out = hlo.multiply(x, x)
   elif y == 3:
-    return (hlo.multiply(hlo.multiply(x, x), x),)
+    out = hlo.multiply(hlo.multiply(x, x), x)
   else:
     lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
     # TODO(b/217551391): emitting an out-of-line call leads to a large
@@ -2621,7 +2625,13 @@ def _integer_pow_lowering(ctx, x, *, y):
     # clones the callee. Consider unconditionally caching when the MLIR->HLO
     # lowering doesn't expand the program.
     lowering = mlir.cache_lowering(lowering)
-    return lowering(ctx, x, y=y)
+    out = lowering(ctx, x, y=y)
+  if config.sharding_in_types.value:
+    aval_out, = ctx.avals_out
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    out = out[0] if isinstance(out, list) else out
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return out if isinstance(out, list) else [out]
 
 mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
 
@@ -4846,15 +4856,6 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
                     "of number.".format(name, dtype_to_string(operand.dtype)))
   return dtypes.canonicalize_dtype(operand.dtype)
 
-def _reduce_sum_shape_rule(operand, *, axes):
-  return _reduce_op_shape_rule(operand, axes=axes)
-
-def _reduce_sum_sharding_rule(operand, *, axes):
-  axes = frozenset(axes)
-  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
-                      if i not in axes))
-  return NamedSharding(operand.sharding.mesh, new_spec)
-
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
@@ -4877,16 +4878,6 @@ def _replace_masked_values(x, val, padded_axes):
   masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
   return select(_reduce(operator.and_, masks), x, full_like(x, val))
 
-
-reduce_sum_p = standard_primitive(
-    _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
-    'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
-ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
-batching.defreducer(reduce_sum_p, _get_sum_identity)
-pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
-                                         _get_sum_identity)
-
-
 def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
   del input_shape  # Unused.
   if len(axes) != len(set(axes)):
@@ -4896,6 +4887,20 @@ def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
   axes = frozenset(axes)
   return tuple(d for i, d in enumerate(operand.shape) if i not in axes)
 
+def _reduce_op_sharding_rule(operand, *, axes):
+  axes = frozenset(axes)
+  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
+                      if i not in axes))
+  return NamedSharding(operand.sharding.mesh, new_spec)
+
+reduce_sum_p = standard_primitive(
+    _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
+    'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
+ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
+batching.defreducer(reduce_sum_p, _get_sum_identity)
+pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
+                                         _get_sum_identity)
+
 def _reduce_prod_jvp_rule(primals, tangents, *, axes):
   reducer = lambda x, y: [mul(x, y)]
   primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
@@ -4922,8 +4927,9 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
   return div(_reduce_sum(mul(g, location_indicators), axes), counts)
 
 
-reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
-                                  'reduce_max')
+reduce_max_p = standard_primitive(
+    _reduce_op_shape_rule, _input_dtype, 'reduce_max',
+    sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_max_p, _get_max_identity)
 pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5fe89750e216..2585c41711f7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4784,6 +4784,37 @@ def f(x):
     if reduce and compiled_text is not None:
       self.assertIn('all-reduce', compiled_text)
 
+  @parameterized.named_parameters(
+      ('all', None, P('x', 'y'), P()),
+      ('first', 0, P('x', 'y'), P('y')),
+      ('second', 1, P('x', 'y'), P('x')),
+      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
+  )
+  def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, in_spec)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.sharding.spec, s.spec)
+      y = jnp.max(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.max(np_inp, axis=axis))
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    lowered = f.lower(arr)
+    self.assertIn('@Sharding', lowered.as_text())
+
+    compiled_text = lowered.compile().as_text()
+    if reduce and compiled_text is not None:
+      self.assertIn('all-reduce', compiled_text)
+
   @parameterized.named_parameters(
       ('0', 0, P(None, 'x', 'y')),
       ('1', 1, P('x', None, 'y')),
@@ -4811,6 +4842,48 @@ def f(x):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)
 
+  @parameterized.named_parameters(
+      ('2', 2),
+      ('3', 3),
+      ('4', 4),
+  )
+  def test_integer_pow(self, pow):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x ** pow
+      self.assertEqual(y.sharding.spec, s.spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np_inp ** pow)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
+  def test_sin_unop(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = lax.sin(x)
+      self.assertEqual(y.sharding.spec, s.spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
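
Reviewer note (illustrative, not part of the patch): the sketch below mirrors test_reduce_max and test_sin_unop above and shows the intent of the new rules: unops such as sin carry the operand sharding through unchanged, while reduce_max drops the reduced dimensions from the output spec. It assumes four available devices and assumes the experimental option behind config.sharding_in_types is toggled via the config name 'jax_sharding_in_types'; neither assumption is established by this patch.

    # Illustrative sketch only; assumes 4 devices and that the experimental
    # flag is exposed as 'jax_sharding_in_types' (an assumption).
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import lax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    jax.config.update('jax_sharding_in_types', True)  # assumed flag name

    mesh = Mesh(np.asarray(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    arr = jax.device_put(np.arange(16.).reshape(8, 2),
                         NamedSharding(mesh, P('x', 'y')))

    @jax.jit
    def f(x):
      y = lax.sin(x)             # unop rule: keeps the operand's P('x', 'y')
      return jnp.max(y, axis=0)  # reduce_max rule: drops the reduced 'x' dim

    print(f(arr).sharding)       # expected spec: P('y')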