From 351187d9dac6767e4e08845da87ccb918eb0f5b2 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 9 Oct 2024 21:23:57 -0700
Subject: [PATCH] [sharding_in_types] Add support for nary ops to propagate
 sharding when 1 input is sharded and all others are replicated.

PiperOrigin-RevId: 684289345
---
 jax/_src/lax/lax.py        | 27 +++++++++++++++++----------
 jax/_src/sharding_impls.py | 14 ++++++++++++++
 tests/pjit_test.py         | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index dd28e1ba6fca..fce19b319692 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -66,7 +66,8 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec
+from jax._src.sharding_impls import (PmapSharding, NamedSharding,
+                                     PartitionSpec as P)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
 from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
                            split_list, NumpyComplexWarning)
@@ -2072,7 +2073,7 @@ def broadcasting_sharding_rule(name, *avals):
     msg = '{}: arrays must have same number of dimensions, got {}.'
     raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes)))))
 
-  specs = [a.sharding.spec for a in avals if a.shape]
+  specs = [a.sharding.normalized_spec for a in avals if a.shape]
 
   mesh = None
   for a in avals:
@@ -2084,23 +2085,29 @@
             f' another mesh: {a.sharding.mesh}')
   assert mesh is not None
 
-  result_specs = []
-  for ss, ds in zip(zip(*specs), zip(*shapes)):
+  result_specs = [None] * len(shapes[0])
+  for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))):
     if all(s == ss[0] for s in ss[1:]):
       # if all dimension shardings are same, the resulting dimension sharding is
       # the same.
-      result_specs.append(ss[0])
+      result_specs[i] = ss[0]
     else:
       non_trivial_s = [s for s, d in zip(ss, ds)
                        if not (core.definitely_equal(d, 1) and s is None)]
       if not non_trivial_s:
-        result_specs.append(None)
+        result_specs[i] = None
       elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]):
-        result_specs.append(non_trivial_s[0])
+        result_specs[i] = non_trivial_s[0]
       else:
-        raise TypeError(f'{name} got incompatible shardings for broadcasting: '
-                        f'{", ".join(map(str, map(tuple, specs)))}.')
-  return NamedSharding(mesh, PartitionSpec(*result_specs))
+        for s in ss:
+          if result_specs[i] is None and s is not None:
+            result_specs[i] = s
+          elif (result_specs[i] is not None and s is not None and
+                result_specs[i] != s):
+            raise TypeError(
+                f'{name} got incompatible shardings for broadcasting: '
+                f'{", ".join(map(str, map(tuple, specs)))}.')
+  return NamedSharding(mesh, P(*result_specs))
 
 
 def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index b69e78fe9ddf..3aa8fafdd40a 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -307,6 +307,20 @@ def is_fully_replicated(self) -> bool:
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
+  @functools.cached_property
+  def normalized_spec(self):
+    out = []
+    for p in self._parsed_pspec:
+      if p is None:
+        raise ValueError("UNCONSTRAINED is not supported yet.")
+      if not p:
+        out.append(None)
+      elif isinstance(p, tuple) and len(p) == 1:
+        out.append(p[0])
+      else:
+        out.append(p)
+    return tuple(out)
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 9ca4185bb0da..af9f55333be7 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4631,6 +4631,39 @@ def f(x):
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 2)
 
+  @config.sharding_in_types(True)
+  def test_fully_replicated_array_mul(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp1 = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr1 = jax.device_put(np_inp1, s)
+
+    np_inp2 = np.arange(2).reshape(1, 2)
+    arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, None)))
+
+    @jax.jit
+    def f(x, y):
+      self.assertEqual(x.sharding.spec, s.spec)
+      out = x * y
+      self.assertEqual(out.sharding.spec, s.spec)
+      return out
+
+    out = f(arr1, arr2)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp2))
+
+    out = f(arr1, arr1)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, (np_inp1 * np_inp1))
+
+    @jax.jit
+    def g(x, y):
+      return x * y
+
+    with self.assertRaisesRegex(
+        TypeError, "mul got incompatible shardings for broadcasting"):
+      g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x'))))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
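
A minimal sketch, not part of the patch, of what the NamedSharding.normalized_spec
property added above is intended to return, so that broadcasting_sharding_rule can
compare per-dimension spec entries in a uniform form. It assumes a JAX build with
this patch applied and at least four local devices; the 2x2 mesh, the device
slicing, and the printed values are illustrative assumptions, not taken from the
patch itself.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2x2 mesh built from the first four local devices (assumption).
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

# Per the patch, empty spec entries normalize to None and single-axis tuples
# unwrap to the bare axis name, so shardings that place a dimension on the same
# mesh axis compare equal dimension by dimension.
print(NamedSharding(mesh, P('x', None)).normalized_spec)   # expected: ('x', None)
print(NamedSharding(mesh, P(None, 'y')).normalized_spec)   # expected: (None, 'y')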