[Pallas:MGPU] Add support for debug_print of arrays that use the WGMMA layout

PiperOrigin-RevId: 686885229
apaszke authored and Google-ML-Automation committed Oct 17, 2024
1 parent ef361f0 commit f72376a
Showing 2 changed files with 57 additions and 12 deletions.
49 changes: 37 additions & 12 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -70,16 +70,40 @@ def can_broadcast_to(self, shape) -> bool:
    """
    return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))

+  def thread_idxs(self, shape):
+    assert shape == self.shape
+    raise NotImplementedError
+

@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
  """[m, n] matrix, where m % 64 == 0 == n % 8."""

+  def thread_idxs(self, shape):
+    index = ir.IndexType.get()
+    assert shape[0] % 64 == 0 and shape[1] % 8 == 0
+    tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx())
+    tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index))
+    warp_idx = arith.divui(tid_wg, c(32, index))
+    tid_warp = arith.remui(tid_wg, c(32, index))
+    col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index))
+    row_base = arith.addi(
+        arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index))
+    )
+    for row_group in range(0, shape[0], 64):
+      for col_group in range(0, shape[1], 8):
+        for row_subgroup in range(0, 16, 8):
+          row = arith.addi(row_base, c(row_group + row_subgroup, index))
+          yield row, arith.addi(col_base, c(col_group, index))
+

@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """[m] matrix, where m % 64 == 0."""

+  def thread_idxs(self, shape):
+    raise NotImplementedError
+

@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
@@ -110,9 +134,10 @@ def from_memref_type(cls, memref_ty: ir.Type):
        shape=tuple(memref_type.shape), vec_size=min(8 // bw, max_vec_size)
    )

-  def thread_vec_idxs(self):
+  def thread_idxs(self, shape):
+    assert shape == self.shape
    index = ir.IndexType.get()
-    for v in self.linear_thread_vec_idxs():
+    for v in self.linear_thread_idxs():
      res = []
      for dim in reversed(self.shape):
        dim = c(dim, index)
@@ -121,7 +146,7 @@ def thread_vec_idxs(self):
      res.reverse()
      yield res

-  def linear_thread_vec_idxs(self):
+  def linear_thread_idxs(self):
    """The indexes to be used for vector load/store WGStridedFragLayout.

    Yields:
@@ -214,16 +239,17 @@ def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None):
      raise TypeError(ref.type)

    ref_ty = ir.MemRefType(ref.type)
+    shape = tuple(ref_ty.shape)
    layout = WGStridedFragLayout.from_memref_type(ref_ty)
    vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
    try:
      # Flattening the reference potentially produces simpler PTX but
      # if the ref is not already 1D and has strided dimensions
      # flattening won't work.
      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-      vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()]
+      vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()]
    except NotImplementedError:
-      vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_vec_idxs()]
+      vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
    return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)

  @classmethod
@@ -821,12 +847,11 @@ def select(self, on_true, on_false):

  def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
    """Call a function for each value and index."""
-    if not isinstance(self.layout, WGStridedFragLayout):
-      raise NotImplementedError(self.layout)
    index = ir.IndexType.get()
-    for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
+    for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
      assert len(idx) == len(self.shape), (idx, self.shape)
-      for i in range(self.layout.vec_size):
+      [elems] = ir.VectorType(reg.type).shape
+      for i in range(elems):
        i = c(i, index)
        fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))

@@ -867,12 +892,12 @@ def _store_untiled_wg_strided(self, ref: ir.Value):
      # if the ref is not already 1D and has strided dimensions
      # flattening won't work. We use a different variable for ref in
      # case `NotImplementedError` is thrown by
-      # .linear_thread_vec_idxs().
+      # .linear_thread_idxs().
      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-      idxs = ([i] for i in self.layout.linear_thread_vec_idxs())
+      idxs = ([i] for i in self.layout.linear_thread_idxs())
    except NotImplementedError:
      ref_ = ref
-      idxs = self.layout.thread_vec_idxs()
+      idxs = self.layout.thread_idxs()
    ref_shape = tuple(ref_ty.shape)
    if ref_shape != self.shape:
      raise ValueError((ref_shape, self.shape))
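The index arithmetic in the new WGMMAFragLayout.thread_idxs can be sanity-checked outside of MLIR. The following is a minimal plain-Python sketch of that mapping (the helper name wgmma_thread_idxs and the 2-element vector width are assumptions made for illustration; the width matches the per-register vector length that the updated foreach reads back from each register's type). It checks that the 128 threads of a warpgroup together touch every element of a (128, 64) array exactly once, two contiguous columns at a time.

import numpy as np

WARPGROUP_SIZE = 128  # 4 warps of 32 threads

def wgmma_thread_idxs(tid, shape):
  # Plain-Python mirror of WGMMAFragLayout.thread_idxs for a single thread:
  # yields the starting (row, col) of each 2-wide register vector it owns.
  assert shape[0] % 64 == 0 and shape[1] % 8 == 0
  tid_wg = tid % WARPGROUP_SIZE
  warp_idx = tid_wg // 32
  tid_warp = tid_wg % 32
  col_base = (tid_warp % 4) * 2
  row_base = tid_warp // 4 + warp_idx * 16
  for row_group in range(0, shape[0], 64):
    for col_group in range(0, shape[1], 8):
      for row_subgroup in range(0, 16, 8):
        yield row_base + row_group + row_subgroup, col_base + col_group

shape = (128, 64)
counts = np.zeros(shape, dtype=np.int32)
for tid in range(WARPGROUP_SIZE):
  for row, col in wgmma_thread_idxs(tid, shape):
    counts[row, col:col + 2] += 1  # each yielded index covers a 2-element vector
assert (counts == 1).all()

For example, thread 37 (warp 1, lane 5) starts its vectors at rows 17, 25, 81 and 89, in columns 2, 10, ..., 58.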
20 changes: 20 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -14,6 +14,7 @@

import functools
import math
+import re
import traceback

from absl.testing import absltest
@@ -346,6 +347,25 @@ def kernel(x_ref, o_ref):

    self.assertEqual(output(), "It works!\n")

+  def test_print_wgmma_tiled_layout(self):
+    shape = (128, 64)
+    size = math.prod(shape)
+    def kernel(x_ref, o_ref):
+      pl.debug_print("{}", x_ref[...])
+    spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)))
+    x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
+    f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)
+
+    with jtu.capture_stdout() as get_output:
+      jax.block_until_ready(f(x))
+
+    output = get_output()
+    results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
+    self.assertLen(results, size)
+    for i, j, v in results:
+      i, j, v = map(int, (i, j, v))
+      self.assertEqual(v, i * shape[1] + j)
+
  def test_print_scalar(self):
    @functools.partial(
        pl.pallas_call,
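The test above recovers indices and values from the printed text with a regular expression. As a small post-processing sketch in the same spirit (the helper name parse_debug_print is hypothetical, and the "[i, j]/[M, N]: value" line format is assumed to match the regex used in the test), the captured stdout can be rebuilt into an array and compared against the kernel input:

import re
import numpy as np

def parse_debug_print(output: str, shape: tuple[int, int]) -> np.ndarray:
  # Rebuild an array from lines of the form "[i, j]/[M, N]: value".
  out = np.full(shape, np.nan)
  for i, j, v in re.findall(r"\[(\d+), (\d+)\]/\[\d+, \d+\]: (\d+)", output):
    out[int(i), int(j)] = float(v)
  assert not np.isnan(out).any(), "some elements were never printed"
  return out

# Usage, following the test above:
#   with jtu.capture_stdout() as get_output:
#     jax.block_until_ready(f(x))
#   np.testing.assert_array_equal(parse_debug_print(get_output(), shape), x)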
