From 18bc354305e9c60b90dce681c6dea1be96e932c5 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 11 Oct 2024 16:04:35 -0700
Subject: [PATCH] [sharding_in_types] Add `dot_general` sharding rule.

We only handle the simple cases and rely on XLA to insert the collectives.

Cases where we error:

* batch dimensions not having consistent sharding (ignoring None)
* contracting dimensions not having consistent sharding (ignoring None)
* lhs.mesh != rhs.mesh
* a batch dimension and a tensor dimension having matching sharding

PiperOrigin-RevId: 684983567
---
 jax/_src/lax/lax.py   | 45 +++++++++++++++++++++++++--
 jax/_src/lax/utils.py | 10 +++---
 tests/pjit_test.py    | 71 +++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 117 insertions(+), 9 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index a740285bf250..7679c0d9ae6a 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3213,6 +3213,46 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
   rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
   return batch_shape + lhs_tensored_shape + rhs_tensored_shape
 
+
+def _check_specs_match(lhs_spec, rhs_spec, msg):
+  for l, r in zip(lhs_spec, rhs_spec):
+    if l is not None and r is not None and l != r:
+      raise TypeError(msg)
+
+def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
+                               preferred_element_type: DTypeLike | None):
+  if lhs.sharding.mesh != rhs.sharding.mesh:
+    raise ValueError(
+        'Mesh of both lhs and rhs should match. Got lhs:'
+        f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
+
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
+  rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
+  msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
+         f"to have consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
+  _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
+
+  lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
+  rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
+  msg = ("dot_general requires contracting dimensions to have consistent "
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+  _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
+
+  return _dot_general_sharding_computation(
+      lhs.sharding.spec, rhs.sharding.spec, dimension_numbers, lhs.sharding.mesh)
+
+def _dot_general_sharding_computation(lhs_spec, rhs_spec,
+                                      dimension_numbers, mesh):
+  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
+  batch_spec = tuple(lhs_spec[i] for i in lhs_batch)
+  lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
+  lhs_tensored_spec = tuple_delete(lhs_spec, lhs_contract_or_batch)
+  rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
+  rhs_tensored_spec = tuple_delete(rhs_spec, rhs_contract_or_batch)
+  return NamedSharding(mesh, P(*(batch_spec + lhs_tensored_spec + rhs_tensored_spec)))
+
 def tuple_delete(tup, idx):
   idx_ = set(idx)
   return tuple(tup[i] for i in range(len(tup)) if i not in idx_)
@@ -3419,8 +3459,9 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
       (list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
   return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
 
-dot_general_p = standard_primitive(_dot_general_shape_rule,
-                                   _dot_general_dtype_rule, 'dot_general')
+dot_general_p = standard_primitive(
+    _dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
+    sharding_rule=_dot_general_sharding_rule)
 
 
 def _dot_general_batch_unpack_args(batch_args):
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index a6eeb18b5203..deb3c19c0a61 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -56,11 +56,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
     out = prim.impl(*[x.val for x in avals], **kwargs)
     return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
   elif least_specialized is core.ShapedArray:
-    out_sharding = (sharding_rule(*avals, **kwargs)
-                    if config.sharding_in_types.value else None)
-    return core.ShapedArray(shape_rule(*avals, **kwargs),
-                            dtype_rule(*avals, **kwargs), weak_type=weak_type,
-                            sharding=out_sharding)
+    return core.ShapedArray(
+        shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
+        weak_type=weak_type,
+        sharding=(sharding_rule(*avals, **kwargs)
+                  if config.sharding_in_types.value else None))
   elif least_specialized is core.DShapedArray:
     shape = shape_rule(*avals, **kwargs)
     ty = (core.ShapedArray if all(type(d) is int for d in shape)
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 9d0389a4799f..442c67b87ff7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4603,9 +4603,9 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
 
 
+@jtu.with_config(jax_sharding_in_types=True)
 class ShardingInTypesTest(jtu.JaxTestCase):
 
-  @config.sharding_in_types(True)
   def test_basic_mul(self):
     mesh = jtu.create_mesh((4, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
@@ -4631,7 +4631,6 @@ def f(x):
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 2)
 
-  @config.sharding_in_types(True)
   def test_fully_replicated_array_mul(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp1 = np.arange(16).reshape(8, 2)
@@ -4672,6 +4671,74 @@ def g(x, y):
         TypeError, "mul got incompatible shardings for broadcasting"):
       g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))
 
+  @parameterized.named_parameters(
+      ('x_y', P('x', None), P(None, 'y'), P('x', 'y'), None),
+      ('x_None', P('x', None), P(None, None), P('x', None), None),
+      ('contracting1', P('x', 'y'), P('y', None), P('x', None), 'all-reduce'),
+      ('contracting2', P('x', 'y'), P(None, None), P('x', None), 'all-gather'),
+      ('fsdp', P('x', None), P('x', None), P('x', None), 'all-gather'),
+      ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'),
+      ('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce')
+  )
+  def test_dot_general_basic(self, spec1, spec2, out_spec, collective_name):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp1 = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
+    arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
+
+    @jax.jit
+    def f(x, y):
+      out = x @ y
+      self.assertEqual(out.sharding.spec, out_spec)
+      return out
+
+    out = f(arr1, arr2)
+    self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    compiled_text = f.lower(arr1, arr2).compile().as_text()
+    if collective_name is not None and compiled_text is not None:
+      self.assertIn(collective_name, compiled_text)
+
+  @parameterized.named_parameters(
+      ('fail1', P('x', 'y'), P('y', 'x'),
+       "PartitionSpec.*x.*x.*has duplicate entries", ValueError),
+      ('fail2', P('x', 'y'), P('x', 'y'),
+       "dot_general requires contracting dimensions to have consistent sharding",
+       TypeError),
+  )
+  def test_dot_general_error(self, spec1, spec2, error_msg, error_type):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp1 = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
+    arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
+
+    @jax.jit
+    def f(x, y):
+      return x @ y
+
+    with self.assertRaisesRegex(error_type, error_msg):
+      f(arr1, arr2)
+
+  def test_dot_general_batch_error(self):
+    mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
+    arr1 = jax.device_put(np.ones((8, 4, 2)),
+                          NamedSharding(mesh, P('x', 'y', 'z')))
+    arr2 = jax.device_put(np.ones((8, 2, 4)),
+                          NamedSharding(mesh, P('y', 'z', 'x')))
+    with self.assertRaisesRegex(
+        TypeError,
+        'dot_general requires lhs batch dimensions and rhs batch dimensions to'
+        ' have consistent sharding'):
+      jax.lax.dot_general(
+          arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0])))
+
+    with self.assertRaisesRegex(
+        TypeError,
+        'dot_general requires lhs batch dimensions and rhs batch dimensions to'
+        ' have consistent sharding'):
+      jnp.einsum('abc,acz->abz', arr1, arr2)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):