From 74d247ac517b19b78064ddd4cf92adb6a28f6563 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 17 Oct 2024 22:54:55 -0700
Subject: [PATCH] [sharding_in_types] If out_aval.sharding is not None and the
 user-specified out_sharding is None, concretize it with the available device
 assignment and add it to the final out_shardings used for lowering and
 compilation. This allows us to return the exact sharding spec that the
 sharding propagation rules figured out.

PiperOrigin-RevId: 687174276
---
 jax/_src/interpreters/pxla.py | 54 +++++++++++++++++++++++------------
 tests/pjit_test.py            | 15 +++++-----
 2 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index de9f393cf64a..6a65f1bd067f 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -70,7 +70,7 @@
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
     UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
     is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
-    SingleDeviceSharding, GSPMDSharding)
+    SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
 from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
                            tuple_update, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)
@@ -1257,7 +1257,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
         for token in token_buf:
           assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding)
           token_devices.append(token.sharding._device_assignment[0])
-        s = sharding_impls.PositionalSharding(token_devices)
+        s = PositionalSharding(token_devices)
         global_token_array = jax.make_array_from_single_device_arrays(
             (0,), s, token_buf
         )
@@ -1608,7 +1608,7 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
   sharding_proto = (
-      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      NamedSharding(mesh, array_mapping_to_axis_resources(axes))
       ._to_xla_hlo_sharding(aval_in.ndim).to_proto())
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
@@ -1634,7 +1634,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
                                   unspecified_dims=unspecified_dims)
   sharding_proto = (
-      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      NamedSharding(mesh, array_mapping_to_axis_resources(axes))
       ._to_xla_hlo_sharding(aval_out.ndim).to_proto())
   return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
                                           unspecified_dims),)
@@ -2117,6 +2117,24 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
   return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
           donated_invars, out_shardings, out_layouts)

+def _concretize_abstract_shardings(shardings, avals, device_assignment):
+  np_dev = np.vectorize(lambda i: device_assignment[i],
+                        otypes=[object])(np.arange(len(device_assignment)))
+
+  @lru_cache(maxsize=128)
+  def _abstract_to_concrete_mesh(abstract_mesh):
+    return mesh_lib.Mesh(
+        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
+
+  out = []
+  for s, a in zip(shardings, avals):
+    if is_unspecified(s) and a.sharding is not None:
+      out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
+                               a.sharding.spec))
+    else:
+      out.append(s)
+  return tuple(out)
+

 @profiler.annotate_function
 def lower_sharding_computation(
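For readers skimming the patch, here is a minimal standalone sketch of what the new `_concretize_abstract_shardings` helper computes for a single UNSPECIFIED output; `concretize_one` and its parameters are illustrative names, not part of the change:

```python
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def concretize_one(spec: PartitionSpec, axis_sizes, axis_names,
                   device_assignment) -> NamedSharding:
  # Arrange the flat device assignment into the abstract mesh's shape, rebuild
  # a concrete Mesh under the same axis names, and attach the propagated spec.
  devs = np.array(device_assignment).reshape(axis_sizes)
  return NamedSharding(Mesh(devs, axis_names), spec)

# e.g., for an abstract 2x2 ('x', 'y') mesh over the first four devices:
#   concretize_one(PartitionSpec('x', None), (2, 2), ('x', 'y'),
#                  jax.devices()[:4])
```

The helper in the hunk above additionally memoizes the abstract-to-concrete mesh conversion with `lru_cache`, so outputs that share one abstract mesh reuse a single concrete `Mesh` instance.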
@@ -2142,7 +2160,6 @@ def lower_sharding_computation(
   lower_sharding_computation calculates the number of out_avals so it can
   apply the singleton UNSPECIFIED to all out_avals.
   """
-  # 1. Trace to jaxpr and preprocess/verify it
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))

@@ -2189,6 +2206,10 @@ def lower_sharding_computation(
           for js, source_info in unique_intermediate_shardings)),
       devices_from_context)

+  if config.sharding_in_types.value:
+    out_shardings = _concretize_abstract_shardings(
+        out_shardings, global_out_avals, device_assignment)
+
   platforms = lowering_platforms or (backend.platform,)

   committed = bool(
@@ -2220,7 +2241,7 @@ def lower_sharding_computation(
   if prim_requires_devices:
     for sharding in it.chain(unique_in_shardings, unique_out_shardings,
                              [js for js, _ in unique_intermediate_shardings]):
-      if isinstance(sharding, sharding_impls.NamedSharding):
+      if isinstance(sharding, NamedSharding):
         if (abstract_mesh is not None and
             abstract_mesh != sharding.mesh.abstract_mesh):
           raise ValueError(
@@ -2423,13 +2444,12 @@ def _get_in_shardings_from_xla(
 # without mesh.
 def _get_mesh_pspec_shardings_from_executable(
     xla_executable, mesh: Mesh
-) -> tuple[Sequence[sharding_impls.NamedSharding],
-           Sequence[sharding_impls.NamedSharding]]:
+) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]:
   from jax._src import pjit

   in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
-  return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
-          [sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
+  return ([NamedSharding(mesh, i) for i in in_pspec],
+          [NamedSharding(mesh, o) for o in out_pspec])


 _orig_out_sharding_handlers = {}
@@ -2439,29 +2459,25 @@ def _get_mesh_pspec_shardings_from_executable(

 def _register_out_sharding_handler(
     sharding_cls: type[_ShardingT],
-    handler: Callable[[sharding_impls.GSPMDSharding, _ShardingT], _ShardingT],
+    handler: Callable[[GSPMDSharding, _ShardingT], _ShardingT],
 ) -> None:
   _orig_out_sharding_handlers[sharding_cls] = handler


 def _gspmd_to_named_sharding(
-    out_s: sharding_impls.GSPMDSharding,
-    orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
+    out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
   assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
   return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

-_register_out_sharding_handler(
-    sharding_impls.NamedSharding, _gspmd_to_named_sharding)
+_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)


 def _gspmd_to_positional_sharding(
-    out_s: sharding_impls.GSPMDSharding,
-    orig_in_s: sharding_impls.PositionalSharding
-    ) -> sharding_impls.PositionalSharding:
+    out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
   return sharding_impls._op_sharding_to_pos_sharding(
       out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind)

 _register_out_sharding_handler(
-    sharding_impls.PositionalSharding, _gspmd_to_positional_sharding)
+    PositionalSharding, _gspmd_to_positional_sharding)

 def _gspmd_to_single_device_sharding(
     out_s: GSPMDSharding, orig_in_s: SingleDeviceSharding) -> SingleDeviceSharding:
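Before moving to the test changes, it is worth spelling out the abstract-mesh round trip the new helper relies on. A hedged sketch, assuming a backend with at least four devices; `abstract` and `rebuilt` are illustrative names:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# A concrete 2x2 mesh over the first four devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

# AbstractMesh records only axis names and sizes, not devices; under
# sharding-in-types this is what an aval's `sharding.mesh` carries.
abstract = mesh.abstract_mesh

# Re-attaching a device assignment in the recorded shape, under the recorded
# names, recovers an equivalent concrete Mesh, which is the core of the helper.
rebuilt = Mesh(np.array(jax.devices()[:4]).reshape(abstract.axis_sizes),
               abstract.axis_names)
assert rebuilt == mesh
```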
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 8e34b6fb19fe..d3b96676afdc 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4685,7 +4685,7 @@ def f(x, y):

     out = f(arr1, arr2)
     self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
-    self.assertEqual(out.aval.sharding.spec, out_spec)
+    self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr1, arr2)
     self.assertIn('@Sharding', lowered.as_text())
@@ -4774,7 +4774,7 @@ def f(x):

     out = f(arr)
     self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
-    self.assertEqual(out.aval.sharding.spec, out_spec)
+    self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
     self.assertIn('@Sharding', lowered.as_text())
@@ -4805,7 +4805,7 @@ def f(x):

     out = f(arr)
     self.assertArraysEqual(out, np.max(np_inp, axis=axis))
-    self.assertEqual(out.aval.sharding.spec, out_spec)
+    self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
     self.assertIn('@Sharding', lowered.as_text())
@@ -4836,7 +4836,7 @@ def f(x):
       return y

     out = f(arr)
-    self.assertEqual(out.aval.sharding.spec, out_spec)
+    self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)
@@ -4913,7 +4913,7 @@ def f(x):

     out = f(arr)
     self.assertArraysEqual(out, np.transpose(arr, (1, 2, 0)))
-    self.assertEqual(out.aval.sharding.spec, P('y', 'z', 'x'))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'z', 'x')))

     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)
@@ -4931,13 +4931,14 @@ def f(x):
       return y

     out = f(arr)
-    self.assertEqual(out.aval.sharding.spec, P('x', None))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

     @jax.jit
     def g(x):
       x = x * 2
       y = jax.lax.broadcasted_iota(
-          x.dtype, (8, 2), 0, _sharding=NamedSharding(mesh, P('x', 'y')))
+          x.dtype, (8, 2), 0,
+          _sharding=NamedSharding(mesh.abstract_mesh, P('x', 'y')))
       self.assertEqual(y.sharding.spec, P('x', 'y'))
       return x, y
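Finally, a hedged end-to-end illustration of the behavior the test updates above lock in. This sketch is not part of the patch: it assumes a backend with at least four devices and the sharding-in-types mode enabled, and the config-flag spelling `jax_sharding_in_types` is inferred from `config.sharding_in_types` in the pxla.py hunk:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed spelling of the flag behind `config.sharding_in_types`.
jax.config.update('jax_sharding_in_types', True)

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
arr = jax.device_put(jnp.arange(16.0).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
  return x * 2  # elementwise, so propagation keeps the input's sharding

out = f(arr)
# With this patch, the UNSPECIFIED output sharding is concretized from the
# propagated abstract sharding, so `out.sharding` is the exact NamedSharding
# rather than whatever the compiled executable reports back.
assert out.sharding == NamedSharding(mesh, P('x', 'y'))
```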