From 89fcd9f1f19e3c09b6297966b551f3ed85d1b7a7 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 11 Oct 2024 16:47:43 -0700 Subject: [PATCH] Better repr of aval when shardings are present Example (for an array of shape (8, 2) with dtype float32): ``` P('x', 'y') -- float32[8@x,2@y] P('x', None) -- float32[8@x,2] P(('x', 'y'), None) -- float32[8@xy,2] P(None, None) -- float32[8,2] ``` PiperOrigin-RevId: 684996577 --- jax/_src/array.py | 2 +- jax/_src/core.py | 21 ++++++++++++++++++--- jax/_src/sharding_impls.py | 2 +- tests/pjit_test.py | 19 +++++++++++++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 4e0cd3d16875..750963873ec8 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1030,7 +1030,7 @@ def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( self.sharding.mesh.abstract_mesh, - self.sharding.normalized_spec(self.ndim))) + self.sharding._normalized_spec(self.ndim))) else: return self.aval api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array diff --git a/jax/_src/core.py b/jax/_src/core.py index a2d243de9ea5..592c425b584a 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1755,6 +1755,8 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None): self.dtype = _dtype_object(dtype) self.weak_type = weak_type if config.sharding_in_types.value: + if sharding is not None: + assert len(sharding.spec) == len(self.shape) self.sharding = sharding def update(self, shape=None, dtype=None, weak_type=None, sharding=None): @@ -1805,12 +1807,14 @@ def join(self, other): raise TypeError(self, other) def str_short(self, short_dtypes=False): - dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else + self.dtype.name) dt_str = dt_str.replace('void', 'float0') - shapestr = 
','.join(map(str, self.shape)) if hasattr(self, 'sharding'): - return f'{dt_str}[{shapestr}]({self.sharding})' + shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec)) + return f'{dt_str}[{shapestr}]' else: + shapestr = ','.join(map(str, self.shape)) return f'{dt_str}[{shapestr}]' def _len(self, ignored_tracer): @@ -1820,6 +1824,17 @@ def _len(self, ignored_tracer): raise TypeError("len() of unsized object") from err # same as numpy error +def _get_shape_sharding_str(shape, spec): + for s1, s2 in zip(shape, spec): + if s2 is None: + yield f"{s1}" + elif isinstance(s2, tuple): + ss = ''.join(s for s in s2) + yield f"{s1}@{ss}" + else: + yield f"{s1}@{s2}" + + def _forward_to_value(self, fun, ignored_tracer, *args): return fun(self.val, *args) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b43cad745f3a..e53a80b98465 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -307,7 +307,7 @@ def is_fully_replicated(self) -> bool: def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) - def normalized_spec(self, ndim: int) -> PartitionSpec: + def _normalized_spec(self, ndim: int) -> PartitionSpec: out = [] # type: ignore for p in self._parsed_pspec: if p is None: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 442c67b87ff7..c0cccc1fe6c2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4739,6 +4739,25 @@ def test_dot_general_batch_error(self): ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) + def test_aval_repr(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + aval = core.ShapedArray((8, 2), np.float32, + sharding=NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(aval.str_short(), 'float32[8@x,2@y]') + + aval = aval.update(sharding=NamedSharding(mesh, P('x', None))) + self.assertEqual(aval.str_short(), 'float32[8@x,2]') + + aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y'))) + 
self.assertEqual(aval.str_short(), 'float32[8,2@y]') + + aval = aval.update(sharding=NamedSharding(mesh, P(None, None))) + self.assertEqual(aval.str_short(), 'float32[8,2]') + + aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None))) + self.assertEqual(aval.str_short(), 'float32[8@xy,2]') + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):