From 5b3257d90be5f6870e82efd666e64d8788d6d2d8 Mon Sep 17 00:00:00 2001 From: Philip Pham Date: Fri, 16 Feb 2024 12:56:22 -0800 Subject: [PATCH] Pipe `tiled` through `all_to_all` primitive Fixes #15982. PiperOrigin-RevId: 607775078 --- jax/_src/lax/parallel.py | 100 +++++++++++++++++++++++++++++---------- tests/shard_map_test.py | 34 +++++++++++++ 2 files changed, 109 insertions(+), 25 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index eaaf08c16ff6..acb5295599bc 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -396,9 +396,14 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): else: # concat_axis < split_axis x = lax.expand_dims(x, (concat_axis,)) # insert the new axis split_axis += 1 # we have a new axis before split_axis now - result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, - axis_name=axis_name, - axis_index_groups=axis_index_groups) + result = all_to_all_p.bind( + x, + split_axis=split_axis, + concat_axis=concat_axis, + axis_name=axis_name, + axis_index_groups=axis_index_groups, + tiled=tiled, + ) if not tiled and split_axis != concat_axis: result = lax.squeeze(result, (split_axis,)) return result @@ -954,8 +959,10 @@ def _index_in_group(axis_name, axis_index_groups): slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0]) -def _all_to_all_lowering(ctx, x, *, - split_axis, concat_axis, axis_name, axis_index_groups): +def _all_to_all_lowering( + ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled +): + del tiled # Workaround for AllToAll not being implemented on CPU. replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, axis_index_groups) @@ -985,15 +992,32 @@ def _all_to_all_lowering(ctx, x, *, replica_groups=_replica_groups_hlo(replica_groups), **other_args).results -def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups): - return (all_to_all( - cts, - axis_name=axis_name, - split_axis=concat_axis, - concat_axis=split_axis, - axis_index_groups=axis_index_groups),) -def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups): +def _all_to_all_transpose_rule( + cts, x, axis_name, split_axis, concat_axis, axis_index_groups, tiled +): + return ( + all_to_all( + cts, + axis_name=axis_name, + split_axis=concat_axis, + concat_axis=split_axis, + axis_index_groups=axis_index_groups, + tiled=tiled, + ), + ) + + +def _all_to_all_batcher( + vals_in, + dims_in, + *, + axis_name, + split_axis, + concat_axis, + axis_index_groups, + tiled, +): x, = vals_in d, = dims_in result = all_to_all_p.bind( @@ -1001,12 +1025,24 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_name=axis_name, split_axis=split_axis + (d <= split_axis), concat_axis=concat_axis + (d <= concat_axis), - axis_index_groups=axis_index_groups) + axis_index_groups=axis_index_groups, + tiled=tiled, + ) return result, d -def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, - axis_name, split_axis, concat_axis, - axis_index_groups): + +def _all_to_all_batched_collective( + axis_size, + frame_name, + _, + vals_in, + dims_in, + axis_name, + split_axis, + concat_axis, + axis_index_groups, + tiled, +): if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") x, = vals_in @@ -1039,18 +1075,28 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, split_axis += 3; concat_axis += 3 # Offset by 
extra three leading dims if major_axes: - x = all_to_all_p.bind(x, axis_name=major_axes, - split_axis=split_axis, concat_axis=0, - axis_index_groups=axis_index_groups) + x = all_to_all_p.bind( + x, + axis_name=major_axes, + split_axis=split_axis, + concat_axis=0, + axis_index_groups=axis_index_groups, + tiled=tiled, + ) # Split out the local part into axis new_d (NOTE: d is already in axis 1) x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis split_axis += 1 if minor_axes: - x = all_to_all_p.bind(x, axis_name=minor_axes, - split_axis=split_axis, concat_axis=2, - axis_index_groups=axis_index_groups) + x = all_to_all_p.bind( + x, + axis_name=minor_axes, + split_axis=split_axis, + concat_axis=2, + axis_index_groups=axis_index_groups, + tiled=tiled, + ) # Fold the chunk axes into a single one x = _foldaxis(0, _foldaxis(0, x)) @@ -1060,7 +1106,11 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, new_d -= 1 # We've removed 0th dimension, so new_d needs to be adjusted return x, new_d -def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups): + +def _all_to_all_abstract_eval( + x, axis_name, split_axis, concat_axis, axis_index_groups, tiled +): + del tiled input_aval = raise_to_shaped(x) shape = list(input_aval.shape) axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) @@ -1069,6 +1119,7 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_ shape[concat_axis] *= axis_size return input_aval.update(shape=tuple(shape), weak_type=False) + all_to_all_p = core.AxisPrimitive('all_to_all') all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval) mlir.register_lowering(all_to_all_p, _all_to_all_lowering) @@ -1327,7 +1378,6 @@ def _reduce_scatter_lowering( return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)] - def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled): if not isinstance(axis_name, (list, tuple)): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f1dcd60f4fe5..b112221bc9b6 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -359,6 +359,40 @@ def fwd(a): for j, block in enumerate(np.split(row_block, 2, axis=-1)): self.assertAllClose(block, c.addressable_data(2 * i + j)) + def test_all_to_all_grad(self): + mesh_axes = dict(x=4) + devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) + mesh = Mesh( + devices.reshape(tuple(mesh_axes.values())), + axis_names=tuple(mesh_axes.keys()), + ) + a = jax.device_put( + jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + self.assertEqual(a.addressable_data(0).shape, (2, 8)) + + @jax.jit + @partial( + shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x') + ) + def fwd(x): + return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True) + + c = fwd(a) + self.assertEqual(c.addressable_data(0).shape, (8, 2)) + self.assertAllClose(a, c) + + @jax.jit + @partial(jax.grad, has_aux=True) + def loss_and_grad(x): + loss = fwd(x).sum() * 2 + return loss, loss + + grad, loss = loss_and_grad(a) + self.assertEqual(loss, 2 * sum(range(64))) + self.assertAllClose(grad, 2 * np.ones_like(a)) + def test_eager_repr(self): mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) s = None
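
Illustrative usage sketch (not part of the patch): with `tiled` now threaded through to the transpose and batching rules, `jax.grad` can flow through a tiled `all_to_all` under `shard_map`, which is what #15982 asked for. The snippet below is a minimal sketch mirroring the new test, assuming at least four devices are available (on CPU they can be faked with `XLA_FLAGS=--xla_force_host_platform_device_count=4`); everything in it other than `lax.all_to_all`, `shard_map`, and the `tiled` parameter is illustrative rather than taken from the diff.

  # Sketch: grad through a tiled all_to_all under shard_map.
  # Assumes >= 4 devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU).
  from functools import partial

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax import lax
  from jax.experimental.shard_map import shard_map
  from jax.sharding import Mesh, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))

  @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x'))
  def fwd(x):
    # Each device holds a (2, 8) block; the tiled all_to_all returns an (8, 2) block.
    return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

  x = jnp.arange(64, dtype=jnp.float32).reshape(8, 8)
  grads = jax.grad(lambda x: fwd(x).sum())(x)  # all ones; broken before this fix per #15982

Note the design choice visible in the diff: the lowering and abstract-eval rules simply `del tiled`, since the public `all_to_all` wrapper already handles the untiled case via `expand_dims`/`squeeze` around the primitive bind; the parameter is piped through mainly so the transpose and batching rules can rebuild an equivalent tiled call.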