[Mosaic GPU] Add support for launching multiple warpgroups using core_map

PiperOrigin-RevId: 686876014
apaszke authored and Google-ML-Automation committed Oct 17, 2024
1 parent 3bdc57d commit ef361f0
Showing 4 changed files with 111 additions and 3 deletions.
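For context, here is a minimal usage sketch of the feature added in this commit. It mirrors the test added in tests/pallas/mosaic_gpu_test.py below; the import aliases and the axis name "wg" are chosen for illustration.

from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax
import jax.numpy as jnp

# A mesh with two Mosaic threads, i.e. two warpgroups (2 * 128 CUDA threads).
mesh = plgpu.GPUMesh(num_threads=2, axis_names=("wg",))

@jax.jit
def f():
  @pl.run_state
  def inner(y_ref):
    @pl.core_map(mesh)  # run the body once per warpgroup
    def kernel():
      wg_idx = jax.lax.axis_index("wg")  # 0 or 1: index of the executing warpgroup
      y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))
  return inner(jnp.zeros((2, 128), jnp.int32))

# f() returns a (2, 128) array whose row i is filled with i: each warpgroup
# wrote its own row.
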
64 changes: 64 additions & 0 deletions jax/_src/pallas/mosaic_gpu/core.py
@@ -15,6 +15,7 @@
"""Contains GPU-specific Pallas abstractions."""

import abc
import collections
from collections.abc import Sequence
import dataclasses
import enum
@@ -24,6 +25,7 @@
from jax._src import dtypes
from jax._src import tree_util
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.state.types import Transform
import jax.experimental.mosaic.gpu as mgpu
import jax.numpy as jnp
@@ -379,3 +381,65 @@ def _as_accum(ref) -> WGMMAAbstractAccumulatorRef:
def _ref_raise_to_shaped(ref_aval, weak_type):
  return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type))
jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped


_WARPGROUP_AXIS_NAME = object()

@dataclasses.dataclass(frozen=True, kw_only=True)
class GPUMesh:
  grid: tuple[int, ...] = ()
  cluster: tuple[int, ...] = ()
  # Those are NOT CUDA threads. On Hopper they correspond to warpgroups.
  num_threads: int | None = None
  axis_names: tuple[str, ...] = ()

  def __post_init__(self):
    if len(self.axis_names) != len(self.grid) + (self.num_threads is not None):
      raise ValueError("Need as many axis names as grid dimensions + warp groups")
    if self.num_threads is not None and self.num_threads > 2048 // 128:
      raise ValueError(
          "Requested too many CUDA threads per block. Each Mosaic thread"
          " corresponds to 128 CUDA threads."
      )

  @property
  def shape(self):
    if self.num_threads is not None:
      pairs = zip(self.axis_names, (*self.grid, self.num_threads))
    else:
      pairs = (*zip(self.axis_names, self.grid), (_WARPGROUP_AXIS_NAME, 1))
    return collections.OrderedDict(pairs)
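# Illustration (not part of the diff; a sketch based on the definitions above):
#   GPUMesh(num_threads=2, axis_names=("wg",)).shape
#     == OrderedDict([("wg", 2)])
#   GPUMesh(grid=(4,), axis_names=("x",)).shape
#     == OrderedDict([("x", 4), (_WARPGROUP_AXIS_NAME, 1)])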


def _gpu_mesh_discharge_rule(
    in_avals,
    out_avals,
    *args,
    mesh,
    jaxpr,
):
  del out_avals
  assert isinstance(mesh, GPUMesh)
  if mesh.grid or mesh.cluster:
    raise NotImplementedError
  if mesh.num_threads is None:
    raise NotImplementedError
  threads_axis_name, num_threads = list(mesh.shape.items())[0]
  def body(*args):
    # Due to aliasing, args contains aliased inputs and outputs so we remove
    # outputs.
    in_refs = args[:len(in_avals)]
    jax_core.eval_jaxpr(jaxpr, in_refs)
  assert len(jaxpr.outvars) == 0
  any_spec = pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
  # Discharge the core_map into a pallas_call with a single named grid axis over
  # the warpgroups; the Mosaic GPU lowering folds that axis into the CUDA block
  # size. Inputs are aliased to outputs so the body can update them in place.
  out = pallas_call.pallas_call(
      body,
      out_shape=in_avals,
      in_specs=[any_spec] * len(in_avals),
      out_specs=[any_spec] * len(in_avals),
      input_output_aliases={i: i for i in range(len(in_avals))},
      grid=((threads_axis_name, num_threads),),
  )(*args)
  return out, ()

pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule
29 changes: 26 additions & 3 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -22,7 +22,7 @@
import functools
import itertools as it
import math
from typing import Any, Protocol, cast
from typing import Any, Hashable, Protocol, cast

import jax
from jax import lax
@@ -346,7 +346,6 @@ def lower_jaxpr_to_module(
  block_mappings = grid_mapping.block_mappings
  _check_block_mappings(block_mappings, name_and_src_info)

  block = (128, 1, 1)
  params = compiler_params.get("mosaic_gpu", {})
  approx_math = params.get("approx_math", False)
  max_concurrent_steps = params.get("max_concurrent_steps", 1)
@@ -368,7 +367,13 @@
f" {max_concurrent_steps=}, {delay_release=}"
)

  grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes]
  block = (128, 1, 1)
  grid = grid_mapping.grid
  if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
    block = (128 * grid_mapping.grid[-1], 1, 1)
    grid = grid[:-1]
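  # Worked example (not part of the diff): for a core_map over
  # GPUMesh(num_threads=2, axis_names=("wg",)), the discharge rule in core.py
  # produces a pallas_call with grid=(("wg", 2),), so grid_mapping.grid == (2,)
  # and grid_names is non-empty. block then becomes (128 * 2, 1, 1) = (256, 1, 1),
  # i.e. two warpgroups of 128 CUDA threads, and grid is trimmed to (), which is
  # padded to (1, 1, 1) below.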

  grid = [d for i, d in enumerate(grid) if i not in sequential_axes]
  if len(grid) < 3:
    grid += (1,) * (3 - len(grid))
  else:
@@ -1064,6 +1069,24 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
)


@register_lowering_rule(lax.axis_index_p)
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
  grid_names = ctx.module_ctx.grid_mapping.grid_names
  if grid_names and axis_name in grid_names:
    if axis_name == grid_names[-1]:
      return mgpu.warpgroup_idx(sync=False)
    else:
      raise NotImplementedError  # The code below is untested
      idx = grid_names.index(axis_name)
      return arith_dialect.index_cast(
          ir.IntegerType.get_signless(32),
          gpu_dialect.block_id(gpu_dialect.Dimension(idx)),
      )
  raise ValueError(
      "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels"
  )
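# Illustration (not part of the diff): inside a kernel launched through
# pl.core_map(GPUMesh(num_threads=2, axis_names=("wg",))), the warpgroup axis is
# the last (and only) grid axis, so lax.axis_index("wg") takes the first branch
# above and lowers to mgpu.warpgroup_idx(sync=False), i.e. 0 or 1 depending on
# the executing warpgroup.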


@register_lowering_rule(primitives.debug_print_p)
def _debug_print_lowering_rule(
    ctx: LoweringRuleContext,
1 change: 1 addition & 0 deletions jax/experimental/pallas/mosaic_gpu.py
@@ -26,6 +26,7 @@
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
20 changes: 20 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -806,5 +806,25 @@ def body(step, _):
    np.testing.assert_array_equal(kernel_fn(x), x + 1.0)


class CoreMapTest(PallasTest):

  def test_multiple_wg(self):
    mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",))

    @jax.jit
    def f():
      @pl.run_state
      def inner(y_ref):
        @pl.core_map(mesh)
        def kernel():
          wg_idx = jax.lax.axis_index("y")
          y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,))
      y_init = jnp.zeros((2, 128), np.int32)
      return inner(y_init)
    np.testing.assert_array_equal(
        f(), np.repeat(np.arange(2), 128).reshape(2, 128)
    )


if __name__ == "__main__":
absltest.main()
