[sharding_in_types] If out_aval.sharding is not None and the user specified out_sharding is None, concretize it with the device assignment available and add it to the final out_shardings that's used for lowering and compilation.

This will allow us to return the exact sharding spec that sharding propagation rules figured out.

PiperOrigin-RevId: 687349015
yashk2810 authored and Google-ML-Automation committed Oct 18, 2024
1 parent f8a3c03 commit 2153de4
Showing 2 changed files with 43 additions and 26 deletions.
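To illustrate the user-visible effect, here is a minimal sketch (not part of the commit; it assumes at least four devices and that the experimental sharding-in-types mode is enabled, mirroring the updated pjit tests): when out_shardings is left unspecified, lowering now concretizes the sharding that propagation inferred, so the returned array carries a real NamedSharding instead of only an abstract spec on its aval.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x2 mesh; assumes >= 4 devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
arr = jax.device_put(jnp.arange(16.0).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

@jax.jit            # no out_shardings given, i.e. UNSPECIFIED
def f(x):
  return x * 2      # propagation keeps the input's P('x', 'y')

out = f(arr)
# Before this change the tests could only check the abstract spec:
#   out.aval.sharding.spec == P('x', 'y')
# Now the unspecified out_sharding is concretized with the device assignment,
# so the output exposes the exact sharding propagation figured out:
assert out.sharding == NamedSharding(mesh, P('x', 'y'))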
jax/_src/interpreters/pxla.py: 54 changes (35 additions, 19 deletions)
@@ -70,7 +70,7 @@
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding)
SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
@@ -1257,7 +1257,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
for token in token_buf:
assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
token_devices.append(token.sharding._device_assignment[0])
s = sharding_impls.PositionalSharding(token_devices)
s = PositionalSharding(token_devices)
global_token_array = jax.make_array_from_single_device_arrays(
(0,), s, token_buf
)
@@ -1608,7 +1608,7 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
sharding_proto = (
sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
NamedSharding(mesh, array_mapping_to_axis_resources(axes))
._to_xla_hlo_sharding(aval_in.ndim).to_proto())
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
@@ -1634,7 +1634,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
unspecified_dims=unspecified_dims)
sharding_proto = (
sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
NamedSharding(mesh, array_mapping_to_axis_resources(axes))
._to_xla_hlo_sharding(aval_out.ndim).to_proto())
return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
unspecified_dims),)
@@ -2117,6 +2117,24 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
donated_invars, out_shardings, out_layouts)

def _concretize_abstract_shardings(shardings, avals, device_assignment):
np_dev = np.vectorize(lambda i: device_assignment[i],
otypes=[object])(np.arange(len(device_assignment)))

@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return mesh_lib.Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)

out = []
for s, a in zip(shardings, avals):
if is_unspecified(s) and a.sharding is not None:
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
out.append(s)
return tuple(out)


@profiler.annotate_function
def lower_sharding_computation(
@@ -2142,7 +2160,6 @@ def lower_sharding_computation(
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings]))

@@ -2189,6 +2206,10 @@
for js, source_info in unique_intermediate_shardings)),
devices_from_context)

if config.sharding_in_types.value:
out_shardings = _concretize_abstract_shardings(
out_shardings, global_out_avals, device_assignment)

platforms = lowering_platforms or (backend.platform,)

committed = bool(
@@ -2220,7 +2241,7 @@
if prim_requires_devices:
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings]):
if isinstance(sharding, sharding_impls.NamedSharding):
if isinstance(sharding, NamedSharding):
if (abstract_mesh is not None and
abstract_mesh != sharding.mesh.abstract_mesh):
raise ValueError(
@@ -2423,13 +2444,12 @@ def _get_in_shardings_from_xla(
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh: Mesh
) -> tuple[Sequence[sharding_impls.NamedSharding],
Sequence[sharding_impls.NamedSharding]]:
) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]:
from jax._src import pjit

in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
return ([NamedSharding(mesh, i) for i in in_pspec],
[NamedSharding(mesh, o) for o in out_pspec])


_orig_out_sharding_handlers = {}
@@ -2439,29 +2459,25 @@ def _get_mesh_pspec_shardings_from_executable(

def _register_out_sharding_handler(
sharding_cls: type[_ShardingT],
handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
handler: Callable[[GSPMDSharding, _ShardingT], _ShardingT],
) -> None:
_orig_out_sharding_handlers[sharding_cls] = handler

def _gspmd_to_named_sharding(
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

_register_out_sharding_handler(
sharding_impls.NamedSharding, _gspmd_to_named_sharding)
_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)


def _gspmd_to_positional_sharding(
out_s: sharding_impls.GSPMDSharding,
orig_in_s: sharding_impls.PositionalSharding
) -> sharding_impls.PositionalSharding:
out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
return sharding_impls._op_sharding_to_pos_sharding(
out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)

_register_out_sharding_handler(
sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
PositionalSharding, _gspmd_to_positional_sharding)

def _gspmd_to_single_device_sharding(
out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
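For reference, a simplified standalone sketch of what the new _concretize_abstract_shardings helper does, using only public jax.sharding APIs (the function and argument names below are illustrative, not the commit's): an abstract sharding on an out_aval only carries mesh axis sizes/names plus a PartitionSpec, so the flat device assignment collected for the computation is reshaped into a concrete Mesh and the propagated spec is re-attached as a NamedSharding; UNSPECIFIED outputs whose aval has no sharding are passed through unchanged.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def concretize_sharding(axis_sizes, axis_names, spec, device_assignment):
  # Reshape the flat device assignment into a mesh with the abstract mesh's
  # axis sizes and names, then re-attach the propagated PartitionSpec.
  devices = np.array(device_assignment, dtype=object).reshape(axis_sizes)
  return NamedSharding(Mesh(devices, axis_names), spec)

# e.g. an output whose abstract sharding was P('x', None) over a 2x2 mesh
# (assumes at least four devices are available):
s = concretize_sharding((2, 2), ('x', 'y'), P('x', None), jax.devices()[:4])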
tests/pjit_test.py: 15 changes (8 additions, 7 deletions)
@@ -4685,7 +4685,7 @@ def f(x, y):

out = f(arr1, arr2)
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
self.assertEqual(out.aval.sharding.spec, out_spec)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

lowered = f.lower(arr1, arr2)
self.assertIn('@Sharding', lowered.as_text())
@@ -4774,7 +4774,7 @@ def f(x):

out = f(arr)
self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
self.assertEqual(out.aval.sharding.spec, out_spec)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

lowered = f.lower(arr)
self.assertIn('@Sharding', lowered.as_text())
@@ -4805,7 +4805,7 @@ def f(x):

out = f(arr)
self.assertArraysEqual(out, np.max(np_inp, axis=axis))
self.assertEqual(out.aval.sharding.spec, out_spec)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

lowered = f.lower(arr)
self.assertIn('@Sharding', lowered.as_text())
@@ -4836,7 +4836,7 @@ def f(x):
return y

out = f(arr)
self.assertEqual(out.aval.sharding.spec, out_spec)
self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
@@ -4913,7 +4913,7 @@ def f(x):

out = f(arr)
self.assertArraysEqual(out, np.transpose(arr, (1, 2, 0)))
self.assertEqual(out.aval.sharding.spec, P('y', 'z', 'x'))
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'z', 'x')))

lowered_text = f.lower(arr).as_text()
self.assertIn('@Sharding', lowered_text)
@@ -4931,13 +4931,14 @@ def f(x):
return y

out = f(arr)
self.assertEqual(out.aval.sharding.spec, P('x', None))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

@jax.jit
def g(x):
x = x * 2
y = jax.lax.broadcasted_iota(
x.dtype, (8, 2), 0, _sharding=NamedSharding(mesh, P('x', 'y')))
x.dtype, (8, 2), 0,
_sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))
self.assertEqual(y.sharding.spec, P('x', 'y'))
return x, y

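Two smaller points from the test diff above are worth calling out. First, the assertions move from the abstract out.aval.sharding.spec to the concrete out.sharding, which this commit makes directly comparable to NamedSharding(mesh, out_spec). Second, the private _sharding argument of broadcasted_iota is now built on mesh.abstract_mesh rather than the concrete mesh, presumably because shardings tracked on avals during tracing use the abstract mesh. Roughly, as a fragment of the test body (mesh and x as in the surrounding test):

y = jax.lax.broadcasted_iota(
    x.dtype, (8, 2), 0,
    _sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))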
