diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index a82e50911e00..297a78751a32 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -15,6 +15,7 @@
 """Contains GPU-specific Pallas abstractions."""
 
 import abc
+import collections
 from collections.abc import Sequence
 import dataclasses
 import enum
@@ -24,6 +25,7 @@
 from jax._src import dtypes
 from jax._src import tree_util
 from jax._src.pallas import core as pallas_core
+from jax._src.pallas import pallas_call
 from jax._src.state.types import Transform
 import jax.experimental.mosaic.gpu as mgpu
 import jax.numpy as jnp
@@ -379,3 +381,65 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
 def _ref_raise_to_shaped(ref_aval, weak_type):
   return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
 jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped
+
+
+_WARPGROUP_AXIS_NAME = object()
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class GPUMesh:
+  grid: tuple[int, ...] = ()
+  cluster: tuple[int, ...] = ()
+  # These are NOT CUDA threads. On Hopper they correspond to warpgroups.
+  num_threads: int | None = None
+  axis_names: tuple[str, ...] = ()
+
+  def __post_init__(self):
+    if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
+      raise ValueError("Need as many axis names as grid dimensions + warp groups")
+    if self.num_threads is not None and self.num_threads > 2048 // 128:
+      raise ValueError(
+          "Requested too many CUDA threads per block. Each Mosaic thread"
+          " corresponds to 128 CUDA threads."
+      )
+
+  @property
+  def shape(self):
+    if self.num_threads is not None:
+      pairs = zip(self.axis_names, (*self.grid, self.num_threads))
+    else:
+      pairs = (*zip(self.axis_names, self.grid), (_WARPGROUP_AXIS_NAME, 1))
+    return collections.OrderedDict(pairs)
+
+
+def _gpu_mesh_discharge_rule(
+    in_avals,
+    out_avals,
+    *args,
+    mesh,
+    jaxpr,
+):
+  del out_avals
+  assert isinstance(mesh, GPUMesh)
+  if mesh.grid or mesh.cluster:
+    raise NotImplementedError
+  if mesh.num_threads is None:
+    raise NotImplementedError
+  threads_axis_name, num_threads = list(mesh.shape.items())[0]
+  def body(*args):
+    # Due to aliasing, args contains aliased inputs and outputs so we remove
+    # outputs.
+    in_refs = args[:len(in_avals)]
+    jax_core.eval_jaxpr(jaxpr, in_refs)
+  assert len(jaxpr.outvars) == 0
+  any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
+  out = pallas_call.pallas_call(
+      body,
+      out_shape=in_avals,
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(in_avals),
+      input_output_aliases={i: i for i in range(len(in_avals))},
+      grid=((threads_axis_name, num_threads),),
+  )(*args)
+  return out, ()
+
+pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 3a5c26e63feb..87761fbdbcad 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -22,7 +22,7 @@
 import functools
 import itertools as it
 import math
-from typing import Any, Protocol, cast
+from typing import Any, Hashable, Protocol, cast
 
 import jax
 from jax import lax
@@ -346,7 +346,6 @@ def lower_jaxpr_to_module(
   block_mappings = grid_mapping.block_mappings
   _check_block_mappings(block_mappings, name_and_src_info)
 
-  block = (128, 1, 1)
   params = compiler_params.get("mosaic_gpu", {})
   approx_math = params.get("approx_math", False)
   max_concurrent_steps = params.get("max_concurrent_steps", 1)
@@ -368,7 +367,13 @@ def lower_jaxpr_to_module(
         f" {max_concurrent_steps=}, {delay_release=}"
     )
 
-  grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
+  block = (128, 1, 1)
+  grid = grid_mapping.grid
+  if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
+    block = (128 * grid_mapping.grid[-1], 1, 1)
+    grid = grid[:-1]
+
+  grid = [d for i, d in enumerate(grid) if i not in sequential_axes]
   if len(grid) < 3:
     grid += (1,) * (3 - len(grid))
   else:
@@ -1064,6 +1069,24 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
   )
 
 
+@register_lowering_rule(lax.axis_index_p)
+def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
+  grid_names = ctx.module_ctx.grid_mapping.grid_names
+  if grid_names and axis_name in grid_names:
+    if axis_name == grid_names[-1]:
+      return mgpu.warpgroup_idx(sync=False)
+    else:
+      raise NotImplementedError  # The code below is untested
+      idx = grid_names.index(axis_name)
+      return arith_dialect.index_cast(
+          ir.IntegerType.get_signless(32),
+          gpu_dialect.block_id(gpu_dialect.Dimension(idx)),
+      )
+  raise ValueError(
+      "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels"
+  )
+
+
 @register_lowering_rule(primitives.debug_print_p)
 def _debug_print_lowering_rule(
     ctx: LoweringRuleContext,
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 80eaea7534c1..a200588e71f4 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -26,6 +26,7 @@
 from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC  # noqa: F401
+from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
 from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
 from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 2a930a24449c..450c15741db5 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -806,5 +806,25 @@ def body(step, _):
     np.testing.assert_array_equal(kernel_fn(x), x + 1.0)
 
 
+class CoreMapTest(PallasTest):
+
+  def test_multiple_wg(self):
+    mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",))
+
+    @jax.jit
+    def f():
+      @pl.run_state
+      def inner(y_ref):
+        @pl.core_map(mesh)
+        def kernel():
+          wg_idx = jax.lax.axis_index("y")
+          y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))
+      y_init = jnp.zeros((2, 128), np.int32)
+      return inner(y_init)
+    np.testing.assert_array_equal(
+        f(), np.repeat(np.arange(2), 128).reshape(2, 128)
+    )
+
+
 if __name__ == "__main__":
   absltest.main()
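
For reference, below is a minimal usage sketch of the API introduced by this change. It is not part of the diff, and the names fill_rows, inner, _kernel, and x_ref are made up for illustration. plgpu.GPUMesh describes the launch configuration (num_threads counts Mosaic threads, i.e. warpgroups on Hopper, capped at 2048 // 128 = 16), pl.core_map runs the decorated body once per warpgroup, and jax.lax.axis_index recovers the warpgroup index inside the body. The discharge rule above still rejects a non-empty grid or cluster.

# Hypothetical usage sketch (not part of the change); closely mirrors the new test.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Four Mosaic threads, i.e. four warpgroups on Hopper; the axis is named "wg".
mesh = plgpu.GPUMesh(num_threads=4, axis_names=("wg",))

@jax.jit
def fill_rows():
  @pl.run_state
  def inner(x_ref):
    @pl.core_map(mesh)  # The body below runs once per warpgroup.
    def _kernel():
      wg = jax.lax.axis_index("wg")  # Warpgroup index: 0, 1, 2, 3.
      # Each warpgroup writes its own row of the (4, 128) ref.
      x_ref[wg] = jnp.broadcast_to(wg + 1, (128,))
  # run_state threads the initial value through and returns the final state.
  return inner(jnp.zeros((4, 128), jnp.int32))

# Row i should hold i + 1.
np.testing.assert_array_equal(
    fill_rows(), np.repeat(np.arange(1, 5), 128).reshape(4, 128)
)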