From 5116ed37b8f9f7473998c68b689577b995a951b6 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 18 Oct 2024 05:04:20 -0700
Subject: [PATCH] [Pallas:MGPU] Fix the implementation of WGMMA with transposed
 RHS

It's not enough that we have the physical transpose between the order of
tiled dimensions, we also need the user to explicitly transpose the logical
dimensions. This fixes a shape error that was previously hidden because the
RHS was square.

PiperOrigin-RevId: 687261078
---
 jax/_src/pallas/mosaic_gpu/core.py       | 13 +++++++++++++
 jax/_src/pallas/mosaic_gpu/primitives.py |  9 ++++++---
 jax/experimental/pallas/mosaic_gpu.py    |  1 +
 tests/pallas/mosaic_gpu_test.py          | 15 ++++++++++-----
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index 978ce8dd9051..c44a27e56c29 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -225,6 +225,19 @@ def tree_unflatten(cls, metadata, arrays):
     return cls(*metadata)
 
 
+def transpose_ref(
+    ref: pallas_core.TransformedRef | pallas_core.AbstractMemoryRef,
+    permutation: tuple[int, ...],
+) -> pallas_core.TransformedRef:
+  if not isinstance(ref, pallas_core.TransformedRef):
+    if not isinstance(ref, pallas_core.AbstractMemoryRef):
+      raise TypeError("ref must be a reference")
+    ref = pallas_core.TransformedRef(ref, transforms=())
+  return pallas_core.TransformedRef(
+      ref.ref, (*ref.transforms, TransposeRef(permutation)),
+  )
+
+
 @dataclasses.dataclass(frozen=True)
 class SwizzleTransform(MemoryRefTransform):
   swizzle: int
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index 19545b11f97a..def7f84944ba 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -464,9 +464,12 @@ def _wgmma_lowering(
   match b_transforms:
     case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
       rhs_transpose = False
-      # TODO(apaszke): Actually what we really want to test here is that we're
-      # only doing transposes within the tiles!
-    case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.TransposeRef((1, 0, 2, 3)), gpu_core.UntileRef(rhs_tiling)):
+    case (
+        gpu_core.UnswizzleRef(rhs_swizzle),
+        gpu_core.TransposeRef((1, 0, 2, 3)),  # Only transpose between tiles
+        gpu_core.UntileRef(rhs_tiling),
+        gpu_core.TransposeRef((1, 0)),  # Transpose the two logical dims
+    ):
       rhs_transpose = True
     case _:
       raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 3777f683cf28..fbb3a3857c68 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -24,6 +24,7 @@
 from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
+from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
 from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC  # noqa: F401
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 6ab4bf8df145..de66929ff203 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -639,15 +639,20 @@ def test_wgmma(self, dtype):
     swizzle = 128
     elems_128b = swizzle // jnp.dtype(dtype).itemsize
     def kernel(a_ref, b_ref, o_ref):
+      if rhs_transpose:
+        b_ref = plgpu.transpose_ref(b_ref, (1, 0))
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref, b_ref)
         return acc_ref[...]
 
-      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
 
     key1, key2 = jax.random.split(jax.random.key(42), 2)
     a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
-    b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)
+    b_shape = (128, 192)
+    if rhs_transpose:
+      b_shape = b_shape[::-1]
+    b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)
 
     rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
     if rhs_transpose:
@@ -664,13 +669,13 @@ def scope(acc_ref):
             ),
         ),
         plgpu.GPUBlockSpec(
-            (128, 128),
+            b_shape,
             lambda *i: i,
             transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
         ),
     ],
-    out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
-    out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
+    out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
+    out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
     grid=(1, 1),
 )(a, b)
 np.testing.assert_allclose(