[pallas:mosaic_gpu] Dereferencing the accumulator now supports slicing
PiperOrigin-RevId: 683235013
cperivol authored and Google-ML-Automation committed Oct 7, 2024
1 parent e8cea0d commit 9ac6723
Showing 3 changed files with 53 additions and 3 deletions.
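
In user-facing terms, a WGMMA accumulator ref can now be read back in slices instead of only as a whole. A minimal kernel fragment distilled from the test added in this commit (a_ref, b_ref and o_ref come from the enclosing kernel; see the full test in mosaic_gpu_test.py below):

    def scope(acc_ref):
      plgpu.wgmma(acc_ref, a_ref, b_ref)
      # Before this commit, any non-trivial index on acc_ref raised
      # NotImplementedError; only a full dereference was supported.
      return acc_ref[:, :64], acc_ref[:, 64:]

    o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))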
10 changes: 7 additions & 3 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -369,10 +369,14 @@ def at_least_vspace(self):
     return _as_accum(super().at_least_vspace())
 
   def _getitem(self, tracer, idx):
-    if not _is_trivial_index(idx):
-      raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).")
     from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref  # pytype: disable=import-error
-    return wgmma_accumulator_deref(tracer)
+    arr = wgmma_accumulator_deref(tracer)
+
+    if not _is_trivial_index(idx):
+      arr = arr[idx]
+
+    return arr
 
 
 def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
   return WGMMAAbstractAccumulatorRef(
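Once _getitem returns the dereferenced array, a non-trivial index is handled by ordinary jnp indexing, which traces to the slice primitive; that is why the next file adds a lax.slice_p lowering rule. A standalone sketch in plain JAX, outside any Pallas kernel (the printed jaxpr is abbreviated):

    import jax
    import jax.numpy as jnp

    # Unit-stride basic indexing on a traced array surfaces as lax.slice_p.
    print(jax.make_jaxpr(lambda x: x[:, :64])(jnp.zeros((64, 128), jnp.float32)))
    # { lambda ; a:f32[64,128]. let b:f32[64,64] = slice[...] a in (b,) }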
10 changes: 10 additions & 0 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -815,6 +815,16 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
   )
 
 
+@register_lowering_rule(lax.slice_p)
+def _slice_lowering_rule(
+    ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
+):
+  if strides is not None:
+    raise NotImplementedError("Strides are not supported.")
+
+  return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))]
+
+
 @register_lowering_rule(lax.select_n_p)
 def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
   if len(cases) != 2:
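The rule body turns lax.slice_p's start_indices/limit_indices parameters into one Python slice per dimension. A NumPy illustration of that same rewrite (not the Mosaic GPU lowering itself):

    import numpy as np

    x = np.arange(64 * 128).reshape(64, 128)
    start_indices, limit_indices = (0, 64), (64, 128)
    # Same construction as in _slice_lowering_rule: slice(start, limit) per dim.
    idx = tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))
    assert (x[idx] == x[0:64, 64:128]).all()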
36 changes: 36 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -576,6 +576,42 @@ def scope(acc_ref):
         res, a @ (b.T if rhs_transpose else b), rtol=1e-3
     )
 
+  def test_wgmma_sliced(self):
+    swizzle = 128
+    elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize
+    def kernel(a_ref, b_ref, o_ref):
+      def scope(acc_ref):
+        plgpu.wgmma(acc_ref, a_ref, b_ref)
+        return acc_ref[:, :64], acc_ref[:, 64:]
+
+      o_ref[:, :64], o_ref[:, 64:] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
+
+    key1, key2 = jax.random.split(jax.random.key(42), 2)
+    a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
+    b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16)
+    res = pl.pallas_call(
+        kernel,
+        in_specs=[
+            plgpu.GPUBlockSpec(
+                (64, 128),
+                lambda i, j: (i, j),
+                transforms=plgpu.TilingTransform((64, elems_128b)),
+                swizzle=128,
+            ),
+            plgpu.GPUBlockSpec(
+                (128, 128),
+                lambda *i: i,
+                transforms=plgpu.TilingTransform((elems_128b, elems_128b)),
+                swizzle=128,
+            ),
+        ],
+        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
+        grid=(1, 1),
+    )(a, b)
+    np.testing.assert_allclose(res, a @ b, rtol=1e-3)
+
+
   def test_input_output_aliases(self):
     # Note that we're writing to the input pointer, which should alias b_ptr.
     def kernel(a_ref, b_ref):