From 5b8775dc2fca477f9848f9b5c4f5248a047beafa Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 11 Oct 2024 21:30:30 -0700
Subject: [PATCH] [sharding_in_types] Add a sharding rule for reduce_sum that
 simply drops the spec entries for the axes being reduced over

PiperOrigin-RevId: 685069065
---
 jax/_src/lax/lax.py |  8 +++++++-
 tests/pjit_test.py  | 30 +++++++++++++++++++++++++++++-
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 7679c0d9ae6a..552fdfd0023c 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -4781,6 +4781,12 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
 def _reduce_sum_shape_rule(operand, *, axes):
   return _reduce_op_shape_rule(operand, axes=axes)
 
+def _reduce_sum_sharding_rule(operand, *, axes):
+  axes = frozenset(axes)
+  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
+                      if i not in axes))
+  return NamedSharding(operand.sharding.mesh, new_spec)
+
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
@@ -4806,7 +4812,7 @@ def _replace_masked_values(x, val, padded_axes):
 
 reduce_sum_p = standard_primitive(
     _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
-    'reduce_sum')
+    'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
 ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p, _get_sum_identity)
 pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index c0cccc1fe6c2..c6519c351d16 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4607,7 +4607,7 @@ def spec_regex(s):
 class ShardingInTypesTest(jtu.JaxTestCase):
 
   def test_basic_mul(self):
-    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -4758,6 +4758,34 @@ def test_aval_repr(self):
     aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
     self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
 
+  @parameterized.named_parameters(
+      ('all', None, P('x', 'y'), P()),
+      ('first', 0, P('x', 'y'), P('y')),
+      ('second', 1, P('x', 'y'), P('x')),
+      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
+  )
+  def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, in_spec)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.sharding.spec, s.spec)
+      y = jnp.sum(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    compiled_text = f.lower(arr).compile().as_text()
+    if reduce and compiled_text is not None:
+      self.assertIn('all-reduce', compiled_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
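
Note (not part of the patch): a minimal standalone sketch of the spec-dropping
behavior that _reduce_sum_sharding_rule implements, operating directly on a
PartitionSpec instead of an abstract value. The helper name drop_reduced_axes
is hypothetical and used only for illustration.

# Illustration only -- mirrors the rule added above: keep the spec entries for
# dimensions that survive the reduction, drop the entries for reduced axes.
from jax.sharding import PartitionSpec as P

def drop_reduced_axes(spec, axes):
  # `spec` is a PartitionSpec; `axes` are the reduced dimension indices.
  axes = frozenset(axes)
  return P(*(s for i, s in enumerate(spec) if i not in axes))

# Summing a P('x', 'y')-sharded operand over axis 0 drops 'x', leaving P('y');
# reducing over every axis leaves an empty spec (a replicated result).
assert drop_reduced_axes(P('x', 'y'), (0,)) == P('y')
assert drop_reduced_axes(P('x', 'y'), (0, 1)) == P()
assert drop_reduced_axes(P(('x', 'y'), None), (1,)) == P(('x', 'y'))

This matches the test cases above: only when the reduced axis is actually
sharded (e.g. the 'all', 'first', 'second', and 'first2' cases) should the
compiled module need a cross-device all-reduce.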