Skip to content

Commit

Permalink
[Pallas:MGPU] Fix the implementation of WGMMA with transposed RHS
Browse files Browse the repository at this point in the history
It's not enough to have the physical transpose between the order of the
tiled dimensions; the user also needs to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.

PiperOrigin-RevId: 687261078
  • Loading branch information
apaszke authored and Google-ML-Automation committed Oct 18, 2024
1 parent ade480f commit 5116ed3
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
13 changes: 13 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def tree_unflatten(cls, metadata, arrays):
return cls(*metadata)


def transpose_ref(
    ref: pallas_core.TransformedRef | pallas_core.AbstractMemoryRef,
    permutation: tuple[int, ...],
) -> pallas_core.TransformedRef:
  """Returns ``ref`` with a ``TransposeRef(permutation)`` transform appended.

  Args:
    ref: Either an already-transformed ref or a plain memory ref; a plain
      memory ref is treated as having an empty transform stack.
    permutation: Dimension permutation recorded in the appended transform.

  Raises:
    TypeError: If ``ref`` is neither kind of reference.
  """
  # Normalize to (underlying ref, existing transform stack) before appending.
  if isinstance(ref, pallas_core.TransformedRef):
    base, existing = ref.ref, ref.transforms
  elif isinstance(ref, pallas_core.AbstractMemoryRef):
    base, existing = ref, ()
  else:
    raise TypeError("ref must be a reference")
  return pallas_core.TransformedRef(
      base, (*existing, TransposeRef(permutation)),
  )


@dataclasses.dataclass(frozen=True)
class SwizzleTransform(MemoryRefTransform):
swizzle: int
Expand Down
9 changes: 6 additions & 3 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,12 @@ def _wgmma_lowering(
match b_transforms:
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
rhs_transpose = False
# TODO(apaszke): Actually what we really want to test here is that we're
# only doing transposes within the tiles!
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.TransposeRef((1, 0, 2, 3)), gpu_core.UntileRef(rhs_tiling)):
case (
gpu_core.UnswizzleRef(rhs_swizzle),
gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles
gpu_core.UntileRef(rhs_tiling),
gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims
):
rhs_transpose = True
case _:
raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/pallas/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
Expand Down
15 changes: 10 additions & 5 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,15 +639,20 @@ def test_wgmma(self, dtype):
swizzle = 128
elems_128b = swizzle // jnp.dtype(dtype).itemsize
def kernel(a_ref, b_ref, o_ref):
if rhs_transpose:
b_ref = plgpu.transpose_ref(b_ref, (1, 0))
def scope(acc_ref):
plgpu.wgmma(acc_ref, a_ref, b_ref)
return acc_ref[...]

o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32))
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))

key1, key2 = jax.random.split(jax.random.key(42), 2)
a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype)
b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype)
b_shape = (128, 192)
if rhs_transpose:
b_shape = b_shape[::-1]
b = jax.random.uniform(key2, shape=b_shape, dtype=dtype)

rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),)
if rhs_transpose:
Expand All @@ -664,13 +669,13 @@ def scope(acc_ref):
),
),
plgpu.GPUBlockSpec(
(128, 128),
b_shape,
lambda *i: i,
transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)),
),
],
out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i),
out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32),
out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i),
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
grid=(1, 1),
)(a, b)
np.testing.assert_allclose(
Expand Down

0 comments on commit 5116ed3

Please sign in to comment.