Skip to content

Commit

Permalink
[Pallas/MGPU] Undo transforms on refs before giving them back to the …
Browse files Browse the repository at this point in the history
…users

This changes makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.

PiperOrigin-RevId: 680520185
  • Loading branch information
apaszke authored and Google-ML-Automation committed Sep 30, 2024
1 parent 38d2a57 commit 21fea5b
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 81 deletions.
32 changes: 22 additions & 10 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state.types import TransformedRef
import jax.numpy as jnp


Expand Down Expand Up @@ -496,7 +497,7 @@ def to_block_mapping(

mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
transformed_block_aval=block_aval, # There are no transforms by default
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=self.indexing_mode,
Expand All @@ -523,7 +524,7 @@ def __repr__(self):
class MemoryRefTransform(Protocol):
"""Transforms a memory reference on load or store."""

def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef:
def undo(self, ref: TransformedRef) -> TransformedRef:
raise NotImplementedError("Abstract evaluation not implemented.")


Expand All @@ -533,8 +534,10 @@ class BlockMapping:
See the `check_invariants` method for precise specification.
"""
# TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform.
# After all, it's just indexing out singleton dimensions.
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
transformed_block_aval: AbstractMemoryRef
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: NameAndSrcInfo
indexing_mode: IndexingMode
Expand All @@ -546,8 +549,8 @@ def check_invariants(self) -> None:
if not config.enable_checks.value: return

unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped)
assert unmapped_block_shape == self.block_aval.shape, (
self.block_shape, self.block_aval)
assert unmapped_block_shape == self.ref_aval.shape, (
self.block_shape, self.ref_aval.shape)
assert len(self.block_shape) == len(self.array_shape_dtype.shape), (
self.block_shape, self.array_shape_dtype
)
Expand All @@ -568,12 +571,21 @@ def replace(self, **kwargs):
return new_self

@property
def ref_aval(self) -> AbstractMemoryRef:
def block_aval(self) -> AbstractMemoryRef:
# If you hit this, make sure you take transforms into account and use either
# ref_aval or transformed_block_aval.
assert not self.transforms, "Lowering failed to handle transforms"
return self.transformed_block_aval

@property
def ref_aval(self) -> AbstractMemoryRef | TransformedRef:
"""Returns the abstract value of the Ref after transformations."""
block_aval = self.block_aval
for transform in self.transforms:
block_aval = transform(block_aval)
return block_aval
if not self.transforms:
return self.transformed_block_aval
ref = TransformedRef(self.transformed_block_aval, ())
for transform in reversed(self.transforms):
ref = transform.undo(ref)
return ref

def compute_start_indices_interpret(self, loop_idx, *args):
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
Expand Down
113 changes: 80 additions & 33 deletions jax/_src/pallas/mosaic_gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

"""Contains GPU-specific Pallas abstractions."""

import abc
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal, Protocol
from typing import Any, ClassVar, Literal

from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import tree_util
from jax._src.state.types import Transform
from jax._src.pallas import core as pallas_core
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
Expand Down Expand Up @@ -63,9 +65,15 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
return pallas_core.MemoryRef(shape, dtype, memory_space=self)


class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol):
class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC):
@abc.abstractmethod
def to_gpu_transform(self) -> mgpu.MemRefTransform:
...
pass

def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
return aval.update(
shape=self.to_gpu_transform().transform_shape(aval.shape)
)


@dataclasses.dataclass(frozen=True)
Expand All @@ -79,52 +87,86 @@ class TilingTransform(MemoryRefTransform):

tiling: tuple[int, ...]

def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
block_shape = block_aval.shape
old_tiled_dims = block_shape[-len(self.tiling) :]
num_tiles = tuple(
block_dim // tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
rem = (
block_dim % tiling_dim
for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling)
)
if any(rem):
raise ValueError(
f"Block shape {block_shape} is not divisible by tiling {self.tiling}"
)
new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling
return block_aval.update(
inner_aval=block_aval.inner_aval.update(shape=new_block_shape)
def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
return dataclasses.replace(
ref, transforms=(*ref.transforms, UntileRef(self.tiling))
)

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TileTransform(self.tiling)


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class UntileRef(Transform):

tiling: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
assert shape[-len(self.tiling) :] == self.tiling
shape = shape[: -len(self.tiling)] # Drop tiling
return shape[: -len(self.tiling)] + tuple(
block_dim * tiling_dim
for block_dim, tiling_dim in zip(shape[-len(self.tiling) :], self.tiling)
)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.tiling,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemoryRefTransform):
"""Transpose a tiled memref."""

permutation: tuple[int, ...]

def __call__(
self, block_aval: pallas_core.AbstractMemoryRef
) -> pallas_core.AbstractMemoryRef:
shape = block_aval.shape # pytype: disable=attribute-error
return block_aval.update(
inner_aval=block_aval.inner_aval.update(
shape=self.to_gpu_transform().transform_shape(shape)
)
def __post_init__(self):
if set(self.permutation) != set(range(len(self.permutation))):
raise ValueError(f"Permutation {self.permutation} is not a permutation.")

def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
inverse = [-1] * len(self.permutation)
for i, p in enumerate(self.permutation):
inverse[p] = i
return dataclasses.replace(
ref, transforms=(*ref.transforms, TransposeRef(tuple(inverse)))
)

def to_gpu_transform(self) -> mgpu.MemRefTransform:
return mgpu.TransposeTransform(self.permutation)


@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class TransposeRef(Transform):
permutation: tuple[int, ...]

def transform_shape(self, shape):
if shape is None:
return None
return tuple(shape[i] for i in self.permutation)

def transform_dtype(self, dtype):
return dtype

def tree_flatten(self):
return (), (self.permutation,)

@classmethod
def tree_unflatten(cls, metadata, arrays):
assert not arrays
return cls(*metadata)


@dataclasses.dataclass(frozen=True)
class GPUBlockMapping(pallas_core.BlockMapping):
swizzle: int | None = None
Expand Down Expand Up @@ -156,9 +198,14 @@ def to_block_mapping(
transforms = self.transforms
if not isinstance(transforms, tuple):
transforms = (transforms,)
block_inner_aval = bm.block_aval.inner_aval
for t in transforms:
block_inner_aval = t(block_inner_aval)
return GPUBlockMapping(
block_shape=bm.block_shape,
block_aval=bm.block_aval,
transformed_block_aval=bm.block_aval.update(
inner_aval=block_inner_aval
),
origin=bm.origin,
index_map_jaxpr=bm.index_map_jaxpr,
index_map_src_info=bm.index_map_src_info,
Expand Down
11 changes: 9 additions & 2 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def lower_jaxpr_to_module(

in_in_smem, out_in_smem = util.split_list(
[
bm.block_aval.memory_space in (None, gpu_core.SMEM)
bm.transformed_block_aval.memory_space in (None, gpu_core.SMEM)
for bm in block_mappings
],
[grid_mapping.num_inputs],
Expand All @@ -295,9 +295,13 @@ def lower_jaxpr_to_module(
in_block_mappings, out_block_mappings = util.split_list(
block_mappings, [grid_mapping.num_inputs]
)
# TODO(apaszke): We can shrink allocation if max_concurrent_steps is more than the actual number of steps.
# We allocate the fully transformed shapes here. All primitives have seen the
# inverse transformation stack and will understand how to handle it.
in_structs_smem = [
jax.ShapeDtypeStruct(
[max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype
[max_concurrent_steps, *bm.transformed_block_aval.shape],
bm.transformed_block_aval.dtype,
)
if in_smem
else None
Expand All @@ -317,6 +321,9 @@ def lower_jaxpr_to_module(
)
out_structs_gmem = [*grid_mapping.out_shapes]
# TODO(justinfu): Implement output Memref transforms
for bm in block_mappings[grid_mapping.num_inputs :]:
if bm.transforms:
raise NotImplementedError("Output transforms are not supported")
out_structs_smem = [
jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype)
if in_smem
Expand Down
53 changes: 39 additions & 14 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class _WGMMAPipelineEffect(effects.Effect):
wgmma_ref_p = jax_core.Primitive("wgmma_ref")
wgmma_ref_p.multiple_results = True

def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
def wgmma(acc, a, b, *, swizzle: int = 128):
"""Asynchronous warp group matmul.
The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
Expand All @@ -129,24 +129,49 @@ def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
acc: The accumulator register.
a: The left hand side operand.
b: The right hand side operand.
transpose: Whether to transpose b.
n_tile: The number of tiles to use.
swizzle: The swizzle pattern.
"""
if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}")

ma, ka, tma, tka = a.shape
kb, nb, tkb, tnb = b.shape
mc, nc = acc.shape

if rhs_transpose:
kb, nb, tkb, tnb = nb, kb, tnb, tkb

if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")

return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
# TODO(apaszke): Make swizzling another transform and read it from the refs.
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA inputs must be tiled references.")

m, n = acc.shape
m2, k = a.shape
k2, n2 = b.shape

if m != m2 or n != n2 or k != k2:
raise ValueError(
f"Incompatible shapes for matrix multiplication: lhs={a.shape},"
f" rhs={b.shape=}, acc={acc.shape}"
)

if (dtype := a.dtype) != b.dtype:
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
if not isinstance(a, pallas_core.TransformedRef):
raise ValueError("WGMMA lhs must be a tiled reference.")
if not isinstance(b, pallas_core.TransformedRef):
raise ValueError("WGMMA rhs must be a tiled reference.")

elems_128b = swizzle // dtype.itemsize
if a.transforms != (gpu_core.UntileRef((64, elems_128b)),):
raise ValueError(
f"WGMMA lhs must be tiled with 64x{elems_128b} tiles for element type"
f" {dtype}."
)
rhs_transpose_transform = gpu_core.TransposeRef((1, 0, 2, 3))
rhs_tiling = gpu_core.UntileRef((elems_128b, elems_128b))
if not (
rhs_transpose := (b.transforms == (rhs_transpose_transform, rhs_tiling))
) and not (b.transforms == (rhs_tiling,)):
raise ValueError(
f"WGMMA rhs must be tiled with {elems_128b}x{elems_128b} tiles for"
f" element type {dtype} (and optionally transposed)."
)

return wgmma_ref_p.bind(acc, a.ref, b.ref, swizzle=swizzle, rhs_transpose=rhs_transpose)


@wgmma_ref_p.def_effectful_abstract_eval
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,10 @@ def get_indexer_shape(self) -> tuple[int | Array, ...]:
# In NDIndexers, the int_indexer_shape is *always* at the front of the
# result.
return (*self.int_indexer_shape, *slice_shape)

def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
del shape # Unused
return self.get_indexer_shape()

def transform_dtype(self, dtype):
return dtype
Loading

0 comments on commit 21fea5b

Please sign in to comment.