Skip to content

Commit

Permalink
Better repr of aval when shardings are present
Browse files Browse the repository at this point in the history
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Oct 11, 2024
1 parent 18bc354 commit 89fcd9f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
2 changes: 1 addition & 1 deletion jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
self.sharding.normalized_spec(self.ndim)))
self.sharding._normalized_spec(self.ndim)))
else:
return self.aval
api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
Expand Down
21 changes: 18 additions & 3 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,8 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
if config.sharding_in_types.value:
if sharding is not None:
assert len(sharding.spec) == len(self.shape)
self.sharding = sharding

def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
Expand Down Expand Up @@ -1805,12 +1807,14 @@ def join(self, other):
raise TypeError(self, other)

def str_short(self, short_dtypes=False):
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
shapestr = ','.join(map(str, self.shape))
if hasattr(self, 'sharding'):
return f'{dt_str}[{shapestr}]({self.sharding})'
shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'

def _len(self, ignored_tracer):
Expand All @@ -1820,6 +1824,17 @@ def _len(self, ignored_tracer):
raise TypeError("len() of unsized object") from err # same as numpy error


def _get_shape_sharding_str(shape, spec):
for s1, s2 in zip(shape, spec):
if s2 is None:
yield f"{s1}"
elif isinstance(s2, tuple):
ss = ''.join(s for s in s2)
yield f"{s1}@{ss}"
else:
yield f"{s1}@{s2}"


def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def is_fully_replicated(self) -> bool:
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)

def normalized_spec(self, ndim: int) -> PartitionSpec:
def _normalized_spec(self, ndim: int) -> PartitionSpec:
out = [] # type: ignore
for p in self._parsed_pspec:
if p is None:
Expand Down
19 changes: 19 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4739,6 +4739,25 @@ def test_dot_general_batch_error(self):
' have the consistent sharding'):
jnp.einsum('abc,acz->abz', arr1, arr2)

def test_aval_repr(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))

aval = core.ShapedArray((8, 2), np.float32,
sharding=NamedSharding(mesh, P('x', 'y')))
self.assertEqual(aval.str_short(), 'float32[8@x,2@y]')

aval = aval.update(sharding=NamedSharding(mesh, P('x', None)))
self.assertEqual(aval.str_short(), 'float32[8@x,2]')

aval = aval.update(sharding=NamedSharding(mesh, P(None, 'y')))
self.assertEqual(aval.str_short(), 'float32[8,2@y]')

aval = aval.update(sharding=NamedSharding(mesh, P(None, None)))
self.assertEqual(aval.str_short(), 'float32[8,2]')

aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
self.assertEqual(aval.str_short(), 'float32[8@xy,2]')


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 89fcd9f

Please sign in to comment.