
[sharding_in_types] Add sharding rule for reduce_sum, which just drops the specs for the axes we are reducing over.

PiperOrigin-RevId: 685069065
yashk2810 authored and Google-ML-Automation committed Oct 12, 2024
1 parent 89fcd9f commit 5b8775d
Showing 2 changed files with 36 additions and 2 deletions.
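
The rule is straightforward: a reduce-sum removes the dimensions named in `axes`, so the output `PartitionSpec` is the input spec with those entries dropped. A minimal sketch of the idea on plain tuples (`drop_reduced_axes` is a hypothetical helper for illustration, not a JAX API):

```python
# Hypothetical helper: reducing over an axis drops that axis's spec entry.
def drop_reduced_axes(spec, axes):
  axes = frozenset(axes)
  return tuple(s for i, s in enumerate(spec) if i not in axes)

# These mirror the parameterized test cases added below.
assert drop_reduced_axes(('x', 'y'), {0}) == ('y',)   # sum over rows
assert drop_reduced_axes(('x', 'y'), {1}) == ('x',)   # sum over columns
assert drop_reduced_axes(('x', 'y'), {0, 1}) == ()    # full reduction
```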
8 changes: 7 additions & 1 deletion jax/_src/lax/lax.py
@@ -4781,6 +4781,12 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
 def _reduce_sum_shape_rule(operand, *, axes):
   return _reduce_op_shape_rule(operand, axes=axes)
 
+def _reduce_sum_sharding_rule(operand, *, axes):
+  axes = frozenset(axes)
+  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
+                      if i not in axes))
+  return NamedSharding(operand.sharding.mesh, new_spec)
+
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
@@ -4806,7 +4812,7 @@ def _replace_masked_values(x, val, padded_axes):
 
 reduce_sum_p = standard_primitive(
     _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
-    'reduce_sum')
+    'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
 ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p, _get_sum_identity)
 pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
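
A hedged usage sketch of the new rule (assumptions: at least four devices are available, and whether the output spec is observable this way depends on the experimental sharding-in-types mode being active):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes >= 4 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.0).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))

out = jax.jit(lambda a: jnp.sum(a, axis=0))(x)
# Axis 0 carried 'x'; reducing over it drops 'x' and leaves P('y').
print(out.sharding.spec)  # expected: PartitionSpec('y')
```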
30 changes: 29 additions & 1 deletion tests/pjit_test.py
@@ -4607,7 +4607,7 @@ def spec_regex(s):
 class ShardingInTypesTest(jtu.JaxTestCase):
 
   def test_basic_mul(self):
-    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -4758,6 +4758,34 @@ def test_aval_repr(self):
     aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
     self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
 
+  @parameterized.named_parameters(
+      ('all', None, P('x', 'y'), P()),
+      ('first', 0, P('x', 'y'), P('y')),
+      ('second', 1, P('x', 'y'), P('x')),
+      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
+  )
+  def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, in_spec)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.sharding.spec, s.spec)
+      y = jnp.sum(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    compiled_text = f.lower(arr).compile().as_text()
+    if reduce and compiled_text is not None:
+      self.assertIn('all-reduce', compiled_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
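
The `all-reduce` assertion at the end of the test reflects the communication this lowering implies: summing over a dimension that is sharded across devices needs a cross-device reduction, which shows up as an `all-reduce` op in the compiled HLO. A self-contained sketch of that check, under the same 2x2-mesh assumption as above:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.0).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))

f = jax.jit(lambda a: jnp.sum(a, axis=0))
hlo = f.lower(x).compile().as_text()
print('all-reduce' in hlo)  # expected: True, since axis 0 is sharded over 'x'
```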
