Pipe tiled through all_to_all primitive
The `_all_to_all_transpose_rule` calls `all_to_all`, which accepts a `tiled`
argument. For the transpose to pass the right value of `tiled`, we need to
plumb the `tiled` argument through the primitive and the various interpreters,
even though it is a no-op there, because `tiled` is handled outside the
primitive. It would be cleaner to handle `tiled` inside the primitive, but I
will leave that for follow-up work.

Fixes #15982.

PiperOrigin-RevId: 607775078
ppham27 authored and jax authors committed Mar 5, 2024
1 parent 40038d6 commit 981ca06
Showing 2 changed files with 57 additions and 11 deletions.
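
To illustrate what the change enables, here is a minimal usage sketch (not part of the commit): differentiating through a tiled `all_to_all` under `shard_map`, the case tracked in #15982. The mesh size, array shapes, and the four-device assumption (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=4` on CPU) are illustrative; the new test below exercises the same path.

```python
# Sketch (not from the commit): differentiate a tiled all_to_all under shard_map.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))  # assumes at least 4 devices
a = jax.device_put(jnp.arange(64, dtype=jnp.float32).reshape(8, 8),
                   NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x'))
def fwd(x):
  # Tiled all_to_all: per-shard (2, 8) -> (8, 2); globally this is a resharding.
  return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

# Differentiating through fwd exercises _all_to_all_transpose_rule with tiled=True.
grads = jax.grad(lambda x: fwd(x).sum())(a)
print(grads.shape)  # (8, 8); the gradient of the sum is all ones
```

Previously the transpose rule bound `all_to_all` without a `tiled` argument, so the backward pass of a tiled call fell back to the untiled behavior.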
jax/_src/lax/parallel.py (23 additions, 11 deletions)

@@ -396,7 +396,8 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
       split_axis += 1  # we have a new axis before split_axis now
     result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                                axis_name=axis_name,
-                               axis_index_groups=axis_index_groups)
+                               axis_index_groups=axis_index_groups,
+                               tiled=tiled)
     if not tiled and split_axis != concat_axis:
       result = lax.squeeze(result, (split_axis,))
     return result
@@ -954,8 +955,10 @@ def _index_in_group(axis_name, axis_index_groups):
       slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])


-def _all_to_all_lowering(ctx, x, *,
-                         split_axis, concat_axis, axis_name, axis_index_groups):
+def _all_to_all_lowering(
+    ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
+):
+  del tiled  # expand_dims and squeeze is done in `all_to_all` if `True`
   # Workaround for AllToAll not being implemented on CPU.
   replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                    axis_index_groups)
@@ -985,28 +988,34 @@ def _all_to_all_lowering(ctx, x, *,
         replica_groups=_replica_groups_hlo(replica_groups),
         **other_args).results

-def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
+def _all_to_all_transpose_rule(
+    cts, x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
+):
   return (all_to_all(
       cts,
       axis_name=axis_name,
       split_axis=concat_axis,
       concat_axis=split_axis,
-      axis_index_groups=axis_index_groups),)
+      axis_index_groups=axis_index_groups,
+      tiled=tiled),)

-def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups):
+def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups,
+                        tiled):
   x, = vals_in
   d, = dims_in
   result = all_to_all_p.bind(
       x,
       axis_name=axis_name,
       split_axis=split_axis + (d <= split_axis),
       concat_axis=concat_axis + (d <= concat_axis),
-      axis_index_groups=axis_index_groups)
+      axis_index_groups=axis_index_groups,
+      tiled=tiled,
+  )
   return result, d

 def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
                                    axis_name, split_axis, concat_axis,
-                                   axis_index_groups):
+                                   axis_index_groups, tiled):
   if axis_index_groups is not None:
     raise NotImplementedError("Please open a feature request!")
   x, = vals_in
@@ -1041,7 +1050,8 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
   if major_axes:
     x = all_to_all_p.bind(x, axis_name=major_axes,
                           split_axis=split_axis, concat_axis=0,
-                          axis_index_groups=axis_index_groups)
+                          axis_index_groups=axis_index_groups,
+                          tiled=tiled)
   # Split out the local part into axis new_d (NOTE: d is already in axis 1)
   x = _splitaxis(split_axis, axis_size, x)
   new_d = split_axis
@@ -1050,7 +1060,8 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
   if minor_axes:
     x = all_to_all_p.bind(x, axis_name=minor_axes,
                           split_axis=split_axis, concat_axis=2,
-                          axis_index_groups=axis_index_groups)
+                          axis_index_groups=axis_index_groups,
+                          tiled=tiled)

   # Fold the chunk axes into a single one
   x = _foldaxis(0, _foldaxis(0, x))
@@ -1062,8 +1073,9 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,


 def _all_to_all_effectful_abstract_eval(
-    x, axis_name, split_axis, concat_axis, axis_index_groups
+    x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
 ):
+  del tiled
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   input_aval = raise_to_shaped(x)
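
As a sanity check on the transpose rule above (again not part of the commit), the following sketch compares the cotangent that `jax.vjp` produces through a tiled `all_to_all` with a tiled `all_to_all` whose `split_axis` and `concat_axis` are swapped, which is the relationship `_all_to_all_transpose_rule` encodes. The names `f` and `f_t`, the four-device mesh, and the shapes are assumptions for illustration.

```python
# Sketch: verify that the adjoint of a tiled all_to_all is the axis-swapped call.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))  # assumes at least 4 devices
smap = dict(mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None))

@jax.jit
@partial(shard_map, **smap)
def f(x):
  # Per shard (2, 8) -> (8, 2); globally a permutation of entries, (8, 8) -> (32, 2).
  return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

@jax.jit
@partial(shard_map, **smap)
def f_t(y):
  # split_axis and concat_axis swapped: the expected transpose of f.
  return lax.all_to_all(y, 'x', split_axis=0, concat_axis=1, tiled=True)

x = jnp.arange(64, dtype=jnp.float32).reshape(8, 8)
y, f_vjp = jax.vjp(f, x)
ct = jnp.arange(y.size, dtype=y.dtype).reshape(y.shape)  # arbitrary cotangent
np.testing.assert_allclose(f_vjp(ct)[0], f_t(ct))
```

Because `f` only permutes array entries, its adjoint equals its inverse, and that inverse is exactly the call with `split_axis` and `concat_axis` exchanged.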
tests/shard_map_test.py (34 additions, 0 deletions)

@@ -359,6 +359,40 @@ def fwd(a):
       for j, block in enumerate(np.split(row_block, 2, axis=-1)):
         self.assertAllClose(block, c.addressable_data(2 * i + j))

+  def test_all_to_all_grad(self):
+    mesh_axes = dict(x=4)
+    devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
+    mesh = Mesh(
+        devices.reshape(tuple(mesh_axes.values())),
+        axis_names=tuple(mesh_axes.keys()),
+    )
+    a = jax.device_put(
+        jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
+        jax.sharding.NamedSharding(mesh, P('x', None)),
+    )
+    self.assertEqual(a.addressable_data(0).shape, (2, 8))
+
+    @jax.jit
+    @partial(
+        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x')
+    )
+    def fwd(x):
+      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)
+
+    c = fwd(a)
+    self.assertEqual(c.addressable_data(0).shape, (8, 2))
+    self.assertAllClose(a, c)
+
+    @jax.jit
+    @partial(jax.grad, has_aux=True)
+    def loss_and_grad(x):
+      loss = fwd(x).sum() * 2
+      return loss, loss
+
+    grad, loss = loss_and_grad(a)
+    self.assertEqual(loss, 2 * sum(range(64)))
+    self.assertAllClose(grad, 2 * np.ones_like(a))
+
   def test_eager_repr(self):
     mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
     s = None
