[pallas:mosaic_gpu] Support indexing barriers
A barrier must be indexed via `.at` and not directly. I wish we could emit
an instructive error for the latter case, but I couldn't find a good place
to put it.

PiperOrigin-RevId: 681857034
superbobry authored and Google-ML-Automation committed Oct 3, 2024
1 parent 5a2e5a5 commit 905c83c
Showing 2 changed files with 116 additions and 26 deletions.
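For illustration, here is a minimal end-to-end sketch of the pattern this commit enables, adapted from the test added below in mosaic_gpu_test.py: a Barrier scratch holding several sub-barriers, where one sub-barrier is selected with `.at[...]` before being passed to `copy_gmem_to_smem` and `wait_barrier`. The kernel name `add_one` is hypothetical, the import aliases (`pl`, `plgpu`) are assumed to match the ones used in the test module, and running it requires a GPU supported by the Mosaic GPU backend.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
    in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
    scratch_shapes=[
        plgpu.SMEM((128,), jnp.float32),
        # Four independent barriers; the copy below uses one of them.
        plgpu.Barrier(num_arrivals=1, num_barriers=4),
    ],
)
def add_one(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
  # The barrier must be indexed via `.at[...]`, not `barrier_ref[2]`.
  plgpu.copy_gmem_to_smem(
      x_ref_gmem, scratch_ref, barrier=barrier_ref.at[2]
  )
  plgpu.wait_barrier(barrier_ref.at[2])
  o_ref[...] = scratch_ref[...] + 1


x = jnp.arange(128).astype(jnp.float32)
np.testing.assert_array_equal(add_one(x), x + 1.0)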
121 changes: 95 additions & 26 deletions jax/_src/pallas/mosaic_gpu/primitives.py
@@ -57,14 +57,14 @@ def _copy_smem_to_gmem_lowering(
   src = lowering._handle_indexing(
       src, src_transforms_treedef.unflatten(flat_src_transforms)
   )
-  copy_params = parse_copy_params(
+  copy_params = _extract_copy_params(
       dst_transforms_treedef.unflatten(flat_dst_transforms)
   )
   ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
   return ()


-def parse_copy_params(transforms):
+def _extract_copy_params(transforms):
   if not transforms:
     return {}
   if any(
@@ -132,17 +132,28 @@ def _copy_gmem_to_smem_lowering(
     *flat_transforms,
     src_transforms_treedef,
     dst_transforms_treedef,
+    barrier_transforms_treedef,
 ):
-  flat_src_transforms, flat_dst_transforms = util.split_list(
-      flat_transforms,
-      [src_transforms_treedef.num_leaves],
+  flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
+      util.split_list(
+          flat_transforms,
+          [
+              src_transforms_treedef.num_leaves,
+              dst_transforms_treedef.num_leaves,
+          ],
+      )
   )
-  copy_params = parse_copy_params(
+  copy_params = _extract_copy_params(
       src_transforms_treedef.unflatten(flat_src_transforms)
   )
   dst = lowering._handle_indexing(
       dst, dst_transforms_treedef.unflatten(flat_dst_transforms)
   )
+  barrier_indexer = _extract_barrier_indexer(
+      barrier_transforms_treedef.unflatten(flat_barrier_transforms)
+  )
+  if barrier_indexer is not None:
+    barrier = barrier.__getitem__(*barrier_indexer.indices)
   ctx.launch_ctx.async_copy(
       src_ref=src, dst_ref=dst, barrier=barrier, **copy_params
   )
@@ -171,55 +182,113 @@ def copy_gmem_to_smem(
   flat_dst_transforms, dst_transforms_treedef = tree_util.tree_flatten(
       dst_transforms
   )
+  barrier, barrier_transforms = state_primitives.get_ref_and_transforms(
+      barrier, None, "copy_gmem_to_smem"
+  )
+  flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten(
+      barrier_transforms
+  )
   copy_gmem_to_smem_p.bind(
       src,
       dst,
       barrier,
       *flat_src_transforms,
       *flat_dst_transforms,
+      *flat_barrier_transforms,
       src_transforms_treedef=src_transforms_treedef,
       dst_transforms_treedef=dst_transforms_treedef,
+      barrier_transforms_treedef=barrier_transforms_treedef,
   )
   return None


+def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
+  if not transforms:
+    return None
+  match transforms:
+    case [indexing.NDIndexer(indices=[idx]) as indexer]:
+      if not isinstance(idx, indexing.Slice):
+        return indexer
+      if indexing.Slice.from_slice(slice(None), *indexer.shape) == idx:
+        # Special-case: the whole slice.
+        return None
+      else:
+        raise ValueError(
+            f"Barrier can only be indexed with an integer, got {idx}"
+        )
+    case [indexing.NDIndexer()]:
+      raise NotImplementedError("Barrier does not support multiple indices")
+    case []:
+      return None
+    case _:
+      raise ValueError("Barrier does not support arbirary transforms")
+
+
 class WaitEffect(jax_core.Effect):
   ...


-wait_effect = WaitEffect()
+_wait_effect = WaitEffect()


-wait_p = jax_core.Primitive("wait")
-wait_p.multiple_results = True
+wait_barrier_p = jax_core.Primitive("wait")
+wait_barrier_p.multiple_results = True


-@wait_p.def_effectful_abstract_eval
-def _wait_abstract_eval(*avals, **params):
+@wait_barrier_p.def_effectful_abstract_eval
+def _wait_barrier_abstract_eval(*avals, **params):
   del avals, params  # Unused.
-  return (), {wait_effect}
+  return (), {_wait_effect}


-@lowering.register_lowering_rule(wait_p)
-def _wait_lowering_rule(
-    ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None,
+@lowering.register_lowering_rule(wait_barrier_p)
+def _wait_barrier_lowering_rule(
+    ctx: lowering.LoweringRuleContext,
+    barrier,
+    *flat_transforms,
+    transforms_treedef,
 ):
-  if barrier is not None:
-    barrier.wait()
-  else:
-    assert allow_groups is not None
-    ctx.launch_ctx.await_async_copy(allow_groups=allow_groups)
+  del ctx  # Unused.
+  transforms = transforms_treedef.unflatten(flat_transforms)
+  indexer = _extract_barrier_indexer(transforms)
+  if indexer is not None:
+    barrier = barrier.__getitem__(*indexer.indices)
+  barrier.wait()
   return ()


-def wait_smem_to_gmem(allow_groups: int) -> None:
-  """Waits until there are no more than the given number of SMEM->GMEM copies in flight."""
-  wait_p.bind(allow_groups=allow_groups)
-
-
 def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None:
   """Waits on the given barrier."""
-  wait_p.bind(barrier)
+  barrier, transforms = state_primitives.get_ref_and_transforms(
+      barrier, None, "wait_barrier"
+  )
+  flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms)
+  wait_barrier_p.bind(
+      barrier, *flat_transforms, transforms_treedef=transforms_treedef
+  )
+
+
+wait_smem_to_gmem_p = jax_core.Primitive("wait_smem_to_gmem")
+wait_smem_to_gmem_p.multiple_results = True
+
+
+@wait_smem_to_gmem_p.def_effectful_abstract_eval
+def _wait_smem_to_gmem_abstract_eval(*, allow_groups):
+  del allow_groups  # Unused.
+  return (), {_wait_effect}
+
+
+@lowering.register_lowering_rule(wait_smem_to_gmem_p)
+def _wait_smem_to_gmem_lowering_rule(
+    ctx: lowering.LoweringRuleContext, allow_groups
+):
+  ctx.launch_ctx.await_async_copy(allow_groups=allow_groups)
+  return ()
+
+
+def wait_smem_to_gmem(allow_groups: int) -> None:
+  """Waits until there are no more than the given number of SMEM->GMEM copies in flight."""
+  wait_smem_to_gmem_p.bind(allow_groups=allow_groups)


 class _WGMMAPipelineEffect(effects.Effect):
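The case analysis in the new `_extract_barrier_indexer` helper can be hard to read inside a diff, so here is a simplified, standalone sketch of the same dispatch. The `ToySlice`, `ToyIndexer`, and `extract_toy_indexer` names are hypothetical stand-ins for `indexing.Slice` and `indexing.NDIndexer`, used only for exposition: a single integer index yields an indexer, the trivial whole slice and an empty transform list yield None, and anything else is rejected.

from dataclasses import dataclass


@dataclass
class ToySlice:
  start: int
  size: int


@dataclass
class ToyIndexer:
  indices: tuple  # e.g. (2,) for `.at[2]`, (ToySlice(0, 4),) for `.at[0:4]`
  shape: tuple


def extract_toy_indexer(transforms):
  match transforms:
    case [ToyIndexer(indices=[idx]) as indexer]:
      if not isinstance(idx, ToySlice):
        return indexer  # A single integer index: the supported case.
      if (idx.start, idx.size) == (0, indexer.shape[0]):
        return None  # The trivial whole-barrier slice: nothing to extract.
      raise ValueError(f"Barrier can only be indexed with an integer, got {idx}")
    case [ToyIndexer()]:
      raise NotImplementedError("Barrier does not support multiple indices")
    case []:
      return None
    case _:
      raise ValueError("Barrier does not support arbitrary transforms")


print(extract_toy_indexer([ToyIndexer(indices=(2,), shape=(4,))]))  # an indexer
print(extract_toy_indexer([ToyIndexer(indices=(ToySlice(0, 4),), shape=(4,))]))  # None
print(extract_toy_indexer([]))  # None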
21 changes: 21 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -201,6 +201,27 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
     x = jnp.arange(256).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0)

+  @parameterized.product(indexer=[0, 1, 2, 3])
+  def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
+        in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),),
+        scratch_shapes=[
+            plgpu.SMEM((128,), jnp.float32),
+            plgpu.Barrier(num_arrivals=1, num_barriers=4),
+        ],
+    )
+    def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref):
+      plgpu.copy_gmem_to_smem(
+          x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer]
+      )
+      plgpu.wait_barrier(barrier_ref.at[indexer])
+      o_ref[...] = scratch_ref[...] + 1
+
+    x = jnp.arange(128).astype(jnp.float32)
+    np.testing.assert_array_equal(kernel(x), x + 1.0)
+
   def test_add_doubled_sum(self):
     @functools.partial(
         pl.pallas_call,