From 9ac67235614abad20309bada7b81ab45d5bcefdd Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Mon, 7 Oct 2024 10:29:50 -0700
Subject: [PATCH] [pallas:mosaic_gpu] Dereferencing the accumulator now supports slicing

PiperOrigin-RevId: 683235013
---
 jax/_src/pallas/mosaic_gpu/core.py     | 10 ++++---
 jax/_src/pallas/mosaic_gpu/lowering.py | 10 +++++++
 tests/pallas/mosaic_gpu_test.py        | 36 ++++++++++++++++++++++++++
 3 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 9156811312f4..786f34ce5988 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -369,10 +369,14 @@ def at_least_vspace(self):
     return _as_accum(super().at_least_vspace())
 
   def _getitem(self, tracer, idx):
-    if not _is_trivial_index(idx):
-      raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).")
     from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref  # pytype: disable=import-error
-    return wgmma_accumulator_deref(tracer)
+    arr = wgmma_accumulator_deref(tracer)
+
+    if not _is_trivial_index(idx):
+      arr = arr[idx]
+
+    return arr
+
 
 def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
   return WGMMAAbstractAccumulatorRef(
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 20ba3793181d..e721d5a841b4 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -815,6 +815,16 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
   )
 
 
+@register_lowering_rule(lax.slice_p)
+def _slice_lowering_rule(
+    ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
+):
+  if strides is not None:
+    raise NotImplementedError("Strides are not supported.")
+
+  return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))]
+
+
 @register_lowering_rule(lax.select_n_p)
 def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
   if len(cases) != 2:
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 78d80b7c990a..4fa459e2895f 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -576,6 +576,42 @@ def scope(acc_ref):
         res, a @ (b.T if rhs_transpose else b), rtol=1e-3
     )
 
+  def test_wgmma_sliced(self):
+    swizzle = 128
+    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize
+    def kernel(a_ref, b_ref, o_ref):
+      def scope(acc_ref):
+        plgpu.wgmma(acc_ref, a_ref, b_ref)
+        return acc_ref[:, :64], acc_ref[:, 64:]
+
+      o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
+    b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16)
+    res = pl.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.GPUBlockSpec(
+                (64, 128),
+                lambda i, j: (i, j),
+                transforms=plgpu.TilingTransform((64, elems_128b)),
+                swizzle=128,
+            ),
+            plgpu.GPUBlockSpec(
+                (128, 128),
+                lambda *i: i,
+                transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
+                swizzle=128,
+            ),
+        ],
+        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
+        grid=(1, 1),
+    )(a, b)
+    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
+
+
   def test_input_output_aliases(self):
     # Note that we're writing to the input pointer, which should alias b_ptr.
     def kernel(a_ref, b_ref):