Pipe tiled through all_to_all primitive
Fixes #15982.

PiperOrigin-RevId: 607775078
ppham27 authored and jax authors committed Mar 1, 2024
1 parent 1615e7a commit 5b3257d
Showing 2 changed files with 109 additions and 25 deletions.
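
For context, a minimal usage sketch of what this change enables: calling lax.all_to_all with tiled=True inside shard_map, mirroring the test_all_to_all_grad test added below. The 4-device mesh and the array shapes are illustrative assumptions, not part of the commit.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes at least 4 devices are available.
mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))

a = jax.device_put(
    jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
    NamedSharding(mesh, P('x', None)),
)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x'))
def fwd(x):
  # Each shard has shape (2, 8); the tiled all_to_all splits axis 1 and
  # concatenates along axis 0, giving shards of shape (8, 2).
  return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

c = fwd(a)  # same global values as `a`, now sharded along the second axis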
100 changes: 75 additions & 25 deletions jax/_src/lax/parallel.py
@@ -396,9 +396,14 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis):
else: # concat_axis < split_axis
x = lax.expand_dims(x, (concat_axis,)) # insert the new axis
split_axis += 1 # we have a new axis before split_axis now
result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
axis_name=axis_name,
axis_index_groups=axis_index_groups)
result = all_to_all_p.bind(
x,
split_axis=split_axis,
concat_axis=concat_axis,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
tiled=tiled,
)
if not tiled and split_axis != concat_axis:
result = lax.squeeze(result, (split_axis,))
return result
@@ -954,8 +959,10 @@ def _index_in_group(axis_name, axis_index_groups):
slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])


def _all_to_all_lowering(ctx, x, *,
split_axis, concat_axis, axis_name, axis_index_groups):
def _all_to_all_lowering(
ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
):
del tiled
# Workaround for AllToAll not being implemented on CPU.
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
@@ -985,28 +992,57 @@ def _all_to_all_lowering(ctx, x, *,
replica_groups=_replica_groups_hlo(replica_groups),
**other_args).results

def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
return (all_to_all(
cts,
axis_name=axis_name,
split_axis=concat_axis,
concat_axis=split_axis,
axis_index_groups=axis_index_groups),)

def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, axis_index_groups):
def _all_to_all_transpose_rule(
cts, x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
):
return (
all_to_all(
cts,
axis_name=axis_name,
split_axis=concat_axis,
concat_axis=split_axis,
axis_index_groups=axis_index_groups,
tiled=tiled,
),
)


def _all_to_all_batcher(
vals_in,
dims_in,
*,
axis_name,
split_axis,
concat_axis,
axis_index_groups,
tiled,
):
x, = vals_in
d, = dims_in
result = all_to_all_p.bind(
x,
axis_name=axis_name,
split_axis=split_axis + (d <= split_axis),
concat_axis=concat_axis + (d <= concat_axis),
axis_index_groups=axis_index_groups)
axis_index_groups=axis_index_groups,
tiled=tiled,
)
return result, d

def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
axis_name, split_axis, concat_axis,
axis_index_groups):

def _all_to_all_batched_collective(
axis_size,
frame_name,
_,
vals_in,
dims_in,
axis_name,
split_axis,
concat_axis,
axis_index_groups,
tiled,
):
if axis_index_groups is not None:
raise NotImplementedError("Please open a feature request!")
x, = vals_in
@@ -1039,18 +1075,28 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
split_axis += 3; concat_axis += 3 # Offset by extra three leading dims

if major_axes:
x = all_to_all_p.bind(x, axis_name=major_axes,
split_axis=split_axis, concat_axis=0,
axis_index_groups=axis_index_groups)
x = all_to_all_p.bind(
x,
axis_name=major_axes,
split_axis=split_axis,
concat_axis=0,
axis_index_groups=axis_index_groups,
tiled=tiled,
)
# Split out the local part into axis new_d (NOTE: d is already in axis 1)
x = _splitaxis(split_axis, axis_size, x)
new_d = split_axis
concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis
split_axis += 1
if minor_axes:
x = all_to_all_p.bind(x, axis_name=minor_axes,
split_axis=split_axis, concat_axis=2,
axis_index_groups=axis_index_groups)
x = all_to_all_p.bind(
x,
axis_name=minor_axes,
split_axis=split_axis,
concat_axis=2,
axis_index_groups=axis_index_groups,
tiled=tiled,
)

# Fold the chunk axes into a single one
x = _foldaxis(0, _foldaxis(0, x))
@@ -1060,7 +1106,11 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
new_d -= 1 # We've removed 0th dimension, so new_d needs to be adjusted
return x, new_d

def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):

def _all_to_all_abstract_eval(
x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
):
del tiled
input_aval = raise_to_shaped(x)
shape = list(input_aval.shape)
axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
@@ -1069,6 +1119,7 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):
shape[concat_axis] *= axis_size
return input_aval.update(shape=tuple(shape), weak_type=False)


all_to_all_p = core.AxisPrimitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
@@ -1327,7 +1378,6 @@ def _reduce_scatter_lowering(
return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]



def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
axis_index_groups, axis_size, tiled):
if not isinstance(axis_name, (list, tuple)):
34 changes: 34 additions & 0 deletions tests/shard_map_test.py
@@ -359,6 +359,40 @@ def fwd(a):
for j, block in enumerate(np.split(row_block, 2, axis=-1)):
self.assertAllClose(block, c.addressable_data(2 * i + j))

def test_all_to_all_grad(self):
mesh_axes = dict(x=4)
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
mesh = Mesh(
devices.reshape(tuple(mesh_axes.values())),
axis_names=tuple(mesh_axes.keys()),
)
a = jax.device_put(
jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)),
)
self.assertEqual(a.addressable_data(0).shape, (2, 8))

@jax.jit
@partial(
shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x')
)
def fwd(x):
return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

c = fwd(a)
self.assertEqual(c.addressable_data(0).shape, (8, 2))
self.assertAllClose(a, c)

@jax.jit
@partial(jax.grad, has_aux=True)
def loss_and_grad(x):
loss = fwd(x).sum() * 2
return loss, loss

grad, loss = loss_and_grad(a)
self.assertEqual(loss, 2 * sum(range(64)))
self.assertAllClose(grad, 2 * np.ones_like(a))

def test_eager_repr(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = None
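A companion sketch (same illustrative 4-device mesh as above, not part of the commit) of the differentiation path this change fixes: the transpose rule now forwards tiled and swaps split_axis with concat_axis, so reverse-mode AD of a tiled all_to_all under shard_map keeps the flag instead of dropping it.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), axis_names=('x',))

@partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x'))
def fwd(x):
  return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

x = jnp.arange(64, dtype=jnp.float32).reshape(8, 8)
# On the backward pass the cotangent flows through another tiled all_to_all
# with split_axis and concat_axis exchanged (_all_to_all_transpose_rule).
grad = jax.jit(jax.grad(lambda v: fwd(v).sum()))(x)
# fwd only permutes elements, so the gradient of the sum is all ones.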
