From 5c8784457ede61fd5bded79fc29763f7d91353eb Mon Sep 17 00:00:00 2001
From: Philip Pham
Date: Fri, 16 Feb 2024 12:56:22 -0800
Subject: [PATCH] Pipe `tiled` through `all_to_all` primitive

`_all_to_all_transpose_rule` calls `all_to_all`, which accepts a `tiled`
argument. For the transpose rule to pass the right value of `tiled`, we
need to plumb `tiled` through the primitive and its various
interpreters, even though the primitive itself ignores it: `tiled` is
currently handled outside the primitive, in the `all_to_all` wrapper.

It would be cleaner to handle `tiled` inside the primitive; that is
left for follow-up work.

Fixes #15982.

PiperOrigin-RevId: 607775078
---
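Notes for reviewers (below the fold, so not part of the commit
message): a minimal sketch of the failing case this addresses. It
assumes a backend with at least four devices and adapts the new
regression test below; it is not the verbatim reproducer from #15982,
and `fwd` is an illustrative name.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import lax
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # One-dimensional mesh over the first four devices (an assumption
    # of this sketch; the new test builds the same shape of mesh).
    mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))
    a = jax.device_put(
        jnp.arange(64, dtype=jnp.float32).reshape(8, 8),
        NamedSharding(mesh, P('x', None)),
    )

    @partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
             out_specs=P(None, 'x'))
    def fwd(x):
      # Tiled all_to_all: split local axis 1 across 'x' and concatenate
      # the received chunks along axis 0, with no extra axis inserted.
      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

    # Before this change, the transpose rule rebuilt all_to_all with
    # the default tiled=False, so differentiating through `fwd` failed.
    print(jax.grad(lambda x: fwd(x).sum())(a))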
 jax/_src/lax/parallel.py | 34 +++++++++++++++++++++++-----------
 tests/shard_map_test.py  | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 7226b26494db..e76641890598 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -396,7 +396,8 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
     split_axis += 1   # we have a new axis before split_axis now
     result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                                axis_name=axis_name,
-                               axis_index_groups=axis_index_groups)
+                               axis_index_groups=axis_index_groups,
+                               tiled=tiled)
     if not tiled and split_axis != concat_axis:
       result = lax.squeeze(result, (split_axis,))
     return result
@@ -954,8 +955,10 @@ def _index_in_group(axis_name, axis_index_groups):
       slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])
 
 
-def _all_to_all_lowering(ctx, x, *,
-                         split_axis, concat_axis, axis_name, axis_index_groups):
+def _all_to_all_lowering(
+    ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
+):
+  del tiled  # expand_dims and squeeze is done in `all_to_all` if `True`
   # Workaround for AllToAll not being implemented on CPU.
   replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                    axis_index_groups)
@@ -985,15 +988,19 @@ def _all_to_all_lowering(ctx, x, *,
       replica_groups=_replica_groups_hlo(replica_groups),
       **other_args).results
 
-def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
+def _all_to_all_transpose_rule(
+    cts, x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
+):
   return (all_to_all(
       cts,
       axis_name=axis_name,
       split_axis=concat_axis,
       concat_axis=split_axis,
-      axis_index_groups=axis_index_groups),)
+      axis_index_groups=axis_index_groups,
+      tiled=tiled),)
 
-def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups):
+def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups,
+                        tiled):
   x, = vals_in
   d, = dims_in
   result = all_to_all_p.bind(
@@ -1001,12 +1008,14 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis,
       axis_name=axis_name,
       split_axis=split_axis + (d <= split_axis),
       concat_axis=concat_axis + (d <= concat_axis),
-      axis_index_groups=axis_index_groups)
+      axis_index_groups=axis_index_groups,
+      tiled=tiled,
+  )
   return result, d
 
 def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
                                    axis_name, split_axis, concat_axis,
-                                   axis_index_groups):
+                                   axis_index_groups, tiled):
   if axis_index_groups is not None:
     raise NotImplementedError("Please open a feature request!")
   x, = vals_in
@@ -1041,7 +1050,8 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
   if major_axes:
     x = all_to_all_p.bind(x, axis_name=major_axes,
                           split_axis=split_axis, concat_axis=0,
-                          axis_index_groups=axis_index_groups)
+                          axis_index_groups=axis_index_groups,
+                          tiled=tiled)
   # Split out the local part into axis new_d (NOTE: d is already in axis 1)
   x = _splitaxis(split_axis, axis_size, x)
   new_d = split_axis
@@ -1050,7 +1060,8 @@
   if minor_axes:
     x = all_to_all_p.bind(x, axis_name=minor_axes,
                           split_axis=split_axis, concat_axis=2,
-                          axis_index_groups=axis_index_groups)
+                          axis_index_groups=axis_index_groups,
+                          tiled=tiled)
   # Fold the chunk axes into a single one
   x = _foldaxis(0, _foldaxis(0, x))
 
@@ -1062,8 +1073,9 @@
 
 
 def _all_to_all_effectful_abstract_eval(
-    x, axis_name, split_axis, concat_axis, axis_index_groups
+    x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
 ):
+  del tiled  # expand_dims and squeeze is done in `all_to_all` if `True`
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   input_aval = raise_to_shaped(x)
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index f1dcd60f4fe5..b112221bc9b6 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -359,6 +359,40 @@ def fwd(a):
       for j, block in enumerate(np.split(row_block, 2, axis=-1)):
         self.assertAllClose(block, c.addressable_data(2 * i + j))
 
+  def test_all_to_all_grad(self):
+    mesh_axes = dict(x=4)
+    devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
+    mesh = Mesh(
+        devices.reshape(tuple(mesh_axes.values())),
+        axis_names=tuple(mesh_axes.keys()),
+    )
+    a = jax.device_put(
+        jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
+        jax.sharding.NamedSharding(mesh, P('x', None)),
+    )
+    self.assertEqual(a.addressable_data(0).shape, (2, 8))
+
+    @jax.jit
+    @partial(
+        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x')
+    )
+    def fwd(x):
+      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)
+
+    c = fwd(a)
+    self.assertEqual(c.addressable_data(0).shape, (8, 2))
+    self.assertAllClose(a, c)
+
+    @jax.jit
+    @partial(jax.grad, has_aux=True)
+    def loss_and_grad(x):
+      loss = fwd(x).sum() * 2
+      return loss, loss
+
+    grad, loss = loss_and_grad(a)
+    self.assertEqual(loss, 2 * sum(range(64)))
+    self.assertAllClose(grad, 2 * np.ones_like(a))
+
   def test_eager_repr(self):
     mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
     s = None
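A closing note on the transpose rule itself: all_to_all permutes array
elements, so its linear transpose is the all_to_all with `split_axis`
and `concat_axis` swapped and the same `tiled` value, which is exactly
what `_all_to_all_transpose_rule` now forwards. The sketch below checks
this numerically via `jax.linear_transpose`; it again assumes at least
four devices, and `f`/`f_t` are illustrative names, not part of the
patch.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import lax
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))

    @partial(shard_map, mesh=mesh, in_specs=(P('x', None),),
             out_specs=P(None, 'x'))
    def f(x):
      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

    # The claimed transpose: the same collective with the axes swapped
    # and `tiled` carried through, under the transposed in/out specs.
    @partial(shard_map, mesh=mesh, in_specs=(P(None, 'x'),),
             out_specs=P('x', None))
    def f_t(y):
      return lax.all_to_all(y, 'x', split_axis=0, concat_axis=1, tiled=True)

    x = jnp.zeros((8, 8), jnp.float32)   # only shape/dtype matter here
    ct = jnp.arange(64, dtype=jnp.float32).reshape(8, 8)
    # Before this patch, this transpose step raised an error.
    (xt,) = jax.linear_transpose(f, x)(ct)
    np.testing.assert_allclose(np.asarray(xt), np.asarray(f_t(ct)))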