diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5d68024a9042..ac66342f5928 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -70,16 +70,40 @@ def can_broadcast_to(self, shape) -> bool: """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + def thread_idxs(self, shape): + assert shape == self.shape + raise NotImplementedError + @dataclasses.dataclass(frozen=True) class WGMMAFragLayout: """[m, n] matrix, where m % 64 == 0 == n % 8.""" + def thread_idxs(self, shape): + index = ir.IndexType.get() + assert shape[0] % 64 == 0 and shape[1] % 8 == 0 + tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) + tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) + warp_idx = arith.divui(tid_wg, c(32, index)) + tid_warp = arith.remui(tid_wg, c(32, index)) + col_base = arith.muli(arith.remui(tid_warp, c(4, index)), c(2, index)) + row_base = arith.addi( + arith.divui(tid_warp, c(4, index)), arith.muli(warp_idx, c(16, index)) + ) + for row_group in range(0, shape[0], 64): + for col_group in range(0, shape[1], 8): + for row_subgroup in range(0, 16, 8): + row = arith.addi(row_base, c(row_group + row_subgroup, index)) + yield row, arith.addi(col_base, c(col_group, index)) + @dataclasses.dataclass(frozen=True) class WGMMARowFragLayout: """[m] matrix, where m % 64 == 0.""" + def thread_idxs(self, shape): + raise NotImplementedError + @dataclasses.dataclass(frozen=True) class WGStridedFragLayout: @@ -110,9 +134,10 @@ def from_memref_type(cls, memref_ty: ir.Type): shape=tuple(memref_type.shape), vec_size=min(8 // bw, max_vec_size) ) - def thread_vec_idxs(self): + def thread_idxs(self, shape): + assert shape == self.shape index = ir.IndexType.get() - for v in self.linear_thread_vec_idxs(): + for v in self.linear_thread_idxs(): res = [] for dim in reversed(self.shape): dim = c(dim, index) @@ -121,7 +146,7 
@@ def thread_vec_idxs(self): res.reverse() yield res - def linear_thread_vec_idxs(self): + def linear_thread_idxs(self): """The indexes to be used for vector load/store WGStridedFragLayout. Yields: @@ -214,6 +239,7 @@ def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): raise TypeError(ref.type) ref_ty = ir.MemRefType(ref.type) + shape = tuple(ref_ty.shape) layout = WGStridedFragLayout.from_memref_type(ref_ty) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) try: @@ -221,9 +247,9 @@ def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None): # if the ref is not already 1D and has strided dimensions # flattening won't work. ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] + vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()] except NotImplementedError: - vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_vec_idxs()] + vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) @classmethod @@ -821,12 +847,11 @@ def select(self, on_true, on_false): def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): """Call a function for each value and index.""" - if not isinstance(self.layout, WGStridedFragLayout): - raise NotImplementedError(self.layout) index = ir.IndexType.get() - for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): + for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True): assert len(idx) == len(self.shape), (idx, self.shape) - for i in range(self.layout.vec_size): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): i = c(i, index) fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i))) @@ -867,12 +892,12 @@ def _store_untiled_wg_strided(self, ref: 
ir.Value): # if the ref is not already 1D and has strided dimensions # flattening won't work. We use a different variable for ref in # case `NotImplementedError` is thrown by - # .linear_thread_vec_idxs(). + # .linear_thread_idxs(). ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - idxs = ([i] for i in self.layout.linear_thread_vec_idxs()) + idxs = ([i] for i in self.layout.linear_thread_idxs()) except NotImplementedError: ref_ = ref - idxs = self.layout.thread_vec_idxs() + idxs = self.layout.thread_idxs(self.shape) ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 450c15741db5..032809210de1 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -14,6 +14,7 @@ import functools import math +import re import traceback from absl.testing import absltest @@ -346,6 +347,25 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") + def test_print_wgmma_tiled_layout(self): + shape = (128, 64) + size = math.prod(shape) + def kernel(x_ref, o_ref): + pl.debug_print("{}", x_ref[...]) + spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) + f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + + with jtu.capture_stdout() as get_output: + jax.block_until_ready(f(x)) + + output = get_output() + results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output) + self.assertLen(results, size) + for i, j, v in results: + i, j, v = map(int, (i, j, v)) + self.assertEqual(v, i * shape[1] + j) + def test_print_scalar(self): @functools.partial( pl.pallas_call,