From 905c83c781f5c6d6a3fa2a315709772012e8c97e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 3 Oct 2024 06:47:20 -0700 Subject: [PATCH] [pallas:mosaic_gpu] Support indexing barriers A barrier must be indexed via `.at` and not directly. I wish we could emit an instructive error for the latter case, but I couldn't find a good place to put it. PiperOrigin-RevId: 681857034 --- jax/_src/pallas/mosaic_gpu/primitives.py | 121 ++++++++++++++++++----- tests/pallas/mosaic_gpu_test.py | 21 ++++ 2 files changed, 116 insertions(+), 26 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index fb04c3f4d827..1a9995ee1d91 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -57,14 +57,14 @@ def _copy_smem_to_gmem_lowering( src = lowering._handle_indexing( src, src_transforms_treedef.unflatten(flat_src_transforms) ) - copy_params = parse_copy_params( + copy_params = _extract_copy_params( dst_transforms_treedef.unflatten(flat_dst_transforms) ) ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) return () -def parse_copy_params(transforms): +def _extract_copy_params(transforms): if not transforms: return {} if any( @@ -132,17 +132,28 @@ def _copy_gmem_to_smem_lowering( *flat_transforms, src_transforms_treedef, dst_transforms_treedef, + barrier_transforms_treedef, ): - flat_src_transforms, flat_dst_transforms = util.split_list( - flat_transforms, - [src_transforms_treedef.num_leaves], + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_transforms, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) ) - copy_params = parse_copy_params( + copy_params = _extract_copy_params( src_transforms_treedef.unflatten(flat_src_transforms) ) dst = lowering._handle_indexing( dst, dst_transforms_treedef.unflatten(flat_dst_transforms) ) + barrier_indexer = _extract_barrier_indexer( + barrier_transforms_treedef.unflatten(flat_barrier_transforms) + ) + if barrier_indexer is not None: + barrier = barrier.__getitem__(*barrier_indexer.indices) ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, barrier=barrier, **copy_params ) @@ -171,55 +182,113 @@ def copy_gmem_to_smem( flat_dst_transforms, dst_transforms_treedef = tree_util.tree_flatten( dst_transforms ) + barrier, barrier_transforms = state_primitives.get_ref_and_transforms( + barrier, None, "copy_gmem_to_smem" + ) + flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( + barrier_transforms + ) copy_gmem_to_smem_p.bind( src, dst, barrier, *flat_src_transforms, *flat_dst_transforms, + *flat_barrier_transforms, src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, + barrier_transforms_treedef=barrier_transforms_treedef, ) return None +def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: + if not transforms: + return None + match transforms: + case [indexing.NDIndexer(indices=[idx]) as indexer]: + if not isinstance(idx, indexing.Slice): + return indexer + if indexing.Slice.from_slice(slice(None), *indexer.shape) == idx: + # Special-case: the whole slice. + return None + else: + raise ValueError( + f"Barrier can only be indexed with an integer, got {idx}" + ) + case [indexing.NDIndexer()]: + raise NotImplementedError("Barrier does not support multiple indices") + case []: + return None + case _: + raise ValueError("Barrier does not support arbirary transforms") + + class WaitEffect(jax_core.Effect): ... -wait_effect = WaitEffect() +_wait_effect = WaitEffect() -wait_p = jax_core.Primitive("wait") -wait_p.multiple_results = True +wait_barrier_p = jax_core.Primitive("wait") +wait_barrier_p.multiple_results = True -@wait_p.def_effectful_abstract_eval -def _wait_abstract_eval(*avals, **params): +@wait_barrier_p.def_effectful_abstract_eval +def _wait_barrier_abstract_eval(*avals, **params): del avals, params # Unused. - return (), {wait_effect} + return (), {_wait_effect} -@lowering.register_lowering_rule(wait_p) -def _wait_lowering_rule( - ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None, +@lowering.register_lowering_rule(wait_barrier_p) +def _wait_barrier_lowering_rule( + ctx: lowering.LoweringRuleContext, + barrier, + *flat_transforms, + transforms_treedef, ): - if barrier is not None: - barrier.wait() - else: - assert allow_groups is not None - ctx.launch_ctx.await_async_copy(allow_groups=allow_groups) + del ctx # Unused. + transforms = transforms_treedef.unflatten(flat_transforms) + indexer = _extract_barrier_indexer(transforms) + if indexer is not None: + barrier = barrier.__getitem__(*indexer.indices) + barrier.wait() return () -def wait_smem_to_gmem(allow_groups: int) -> None: - """Waits until there are no more than the given number of SMEM->GMEM copies in flight.""" - wait_p.bind(allow_groups=allow_groups) - - def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None: """Waits on the given barrier.""" - wait_p.bind(barrier) + barrier, transforms = state_primitives.get_ref_and_transforms( + barrier, None, "wait_barrier" + ) + flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) + wait_barrier_p.bind( + barrier, *flat_transforms, transforms_treedef=transforms_treedef + ) + + +wait_smem_to_gmem_p = jax_core.Primitive("wait_smem_to_gmem") +wait_smem_to_gmem_p.multiple_results = True + + +@wait_smem_to_gmem_p.def_effectful_abstract_eval +def _wait_smem_to_gmem_abstract_eval(*, allow_groups): + del allow_groups # Unused. + return (), {_wait_effect} + + +@lowering.register_lowering_rule(wait_smem_to_gmem_p) +def _wait_smem_to_gmem_lowering_rule( + ctx: lowering.LoweringRuleContext, allow_groups +): + ctx.launch_ctx.await_async_copy(allow_groups=allow_groups) + return () + + +def wait_smem_to_gmem(allow_groups: int) -> None: + """Waits until there are no more than the given number of SMEM->GMEM copies in flight.""" + wait_smem_to_gmem_p.bind(allow_groups=allow_groups) class _WGMMAPipelineEffect(effects.Effect): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 353c01a4e337..1eb62b9f9e55 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -201,6 +201,27 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.product(indexer=[0, 1, 2, 3]) + def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1, num_barriers=4), + ], + ) + def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): + plgpu.copy_gmem_to_smem( + x_ref_gmem, scratch_ref, barrier=barrier_ref.at[indexer] + ) + plgpu.wait_barrier(barrier_ref.at[indexer]) + o_ref[...] = scratch_ref[...] + 1 + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + def test_add_doubled_sum(self): @functools.partial( pl.pallas_call,