Pallas Mosaic GPU: add transpose_ref and use it for a transposed WGMMA RHS.

Summary of the changes in this patch:

* jax/_src/pallas/mosaic_gpu/core.py: add transpose_ref(ref, permutation).
  A bare AbstractMemoryRef is first wrapped in a TransformedRef with no
  transforms (anything else that is not a TransformedRef raises TypeError);
  the result is a new TransformedRef over the same underlying ref with
  TransposeRef(permutation) appended to the existing transforms.
* jax/_src/pallas/mosaic_gpu/primitives.py: in _wgmma_lowering the
  transposed-RHS case now matches exactly the transform chain
  UnswizzleRef, TransposeRef((1, 0, 2, 3)), UntileRef, TransposeRef((1, 0)),
  and the old TODO about transposes within tiles is removed.
* jax/experimental/pallas/mosaic_gpu.py: re-export transpose_ref.
* tests/pallas/mosaic_gpu_test.py: test_wgmma applies
  plgpu.transpose_ref(b_ref, (1, 0)) when rhs_transpose is set, uses a
  (128, 192) RHS (reversed when rhs_transpose is set) and a (64, 192)
  accumulator / output.

NOTE(review): the line structure of this patch was reconstructed from a
whitespace-collapsed copy; hunk line counts were re-verified, but leading
indentation inside hunks is inferred -- confirm against the target tree
before applying (or apply with whitespace-insensitive matching).

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 978ce8dd9051..c44a27e56c29 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -225,6 +225,19 @@ def tree_unflatten(cls, metadata, arrays):
     return cls(*metadata)
 
 
+def transpose_ref(
+    ref: pallas_core.TransformedRef | pallas_core.AbstractMemoryRef,
+    permutation: tuple[int, ...],
+) -> pallas_core.TransformedRef:
+  if not isinstance(ref, pallas_core.TransformedRef):
+    if not isinstance(ref, pallas_core.AbstractMemoryRef):
+      raise TypeError("ref must be a reference")
+    ref = pallas_core.TransformedRef(ref, transforms=())
+  return pallas_core.TransformedRef(
+      ref.ref, (*ref.transforms, TransposeRef(permutation)),
+  )
+
+
 @dataclasses.dataclass(frozen=True)
 class SwizzleTransform(MemoryRefTransform):
   swizzle: int
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 19545b11f97a..def7f84944ba 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -464,9 +464,12 @@ def _wgmma_lowering(
   match b_transforms:
     case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
       rhs_transpose = False
-    # TODO(apaszke): Actually what we really want to test here is that we're
-    # only doing transposes within the tiles!
-    case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.TransposeRef((1, 0, 2, 3)), gpu_core.UntileRef(rhs_tiling)):
+    case (
+        gpu_core.UnswizzleRef(rhs_swizzle),
+        gpu_core.TransposeRef((1, 0, 2, 3)),  # Only transpose between tiles
+        gpu_core.UntileRef(rhs_tiling),
+        gpu_core.TransposeRef((1, 0)),  # Transpose the two logical dims
+    ):
       rhs_transpose = True
     case _:
       raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 3777f683cf28..fbb3a3857c68 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -24,6 +24,7 @@
 from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
+from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
 from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC  # noqa: F401
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 6ab4bf8df145..de66929ff203 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -639,15 +639,20 @@ def test_wgmma(self, dtype):
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
     def kernel(a_ref, b_ref, o_ref):
+      if rhs_transpose:
+        b_ref = plgpu.transpose_ref(b_ref, (1, 0))
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref, b_ref)
         return acc_ref[...]
 
-      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
 
     key1, key2 = jax.random.split(jax.random.key(42), 2)
     a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
-    b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)
+    b_shape = (128, 192)
+    if rhs_transpose:
+      b_shape = b_shape[::-1]
+    b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)
 
     rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
     if rhs_transpose:
@@ -664,13 +669,13 @@ def scope(acc_ref):
                 ),
             ),
             plgpu.GPUBlockSpec(
-                (128, 128),
+                b_shape,
                 lambda *i: i,
                 transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
             ),
         ],
-        out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
-        out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
+        out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
+        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
         grid=(1, 1),
     )(a, b)
     np.testing.assert_allclose(