Exposed sequential iteration index via pl.program_id in Pallas Mosaic GPU

PiperOrigin-RevId: 680502214
superbobry authored and Google-ML-Automation committed Sep 30, 2024
1 parent 2cfbdb6 commit 38d2a57
Showing 2 changed files with 61 additions and 9 deletions.
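What the change enables, as a usage sketch (not code from this commit): when a grid axis is marked "sequential", pl.program_id for that axis now returns the current pipeline step inside the kernel body. The shapes and the 1-D sequential grid below are illustrative; plgpu is assumed to be the jax.experimental.pallas.mosaic_gpu alias used in the test file further down.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    out_shape=jax.ShapeDtypeStruct([128 * 4], jnp.int32),
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["sequential"],
        max_concurrent_steps=2,
    ),
    grid=(4,),
)
def write_step(o_ref):
  # pl.program_id(0) is the current sequential (pipeline) step, not a CUDA
  # block index, so each output block records which step wrote it.
  o_ref[...] = jnp.broadcast_to(pl.program_id(0), o_ref.shape)


# Block i is filled with step i: [0]*128 + [1]*128 + [2]*128 + [3]*128.
np.testing.assert_array_equal(
    write_step(), jnp.repeat(jnp.arange(4, dtype=jnp.int32), 128)
)
```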
30 changes: 21 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -102,6 +102,7 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
 class ModuleContext:
   name: str
   grid_mapping: pallas_core.GridMapping
+  program_ids: Sequence[ir.Value] | None
   approx_math: bool
   runtime_smem: ir.Value  # ir.MemRefType
   smem_used_bytes: int = 0
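The new optional field is the mechanism: the context is created once per module before any pipeline step exists (hence None), and the pipeline loop later swaps in concrete per-axis values. A minimal sketch of that life cycle, using a hypothetical stand-in class rather than the real ModuleContext:

```python
import dataclasses
from collections.abc import Sequence


@dataclasses.dataclass
class ContextSketch:  # hypothetical stand-in for ModuleContext
  name: str
  program_ids: Sequence[int] | None  # None until a pipeline step is known


# Built once per module, before the pipeline loop runs.
module_ctx = ContextSketch("kernel", None)

# Each pipeline step lowers the kernel body with its own copy of the context
# (mirroring the dataclasses.replace call in a later hunk of this diff).
for step in range(4):
  step_ctx = dataclasses.replace(module_ctx, program_ids=[step])
  assert step_ctx.program_ids == [step]
```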
@@ -266,7 +267,6 @@ def lower_jaxpr_to_module(
     raise NotImplementedError(
         "Only <=3D grids are supported in Mosaic GPU lowering."
     )
-  # Compute the number of steps along each sequential axis.
   if sequential_axes:
     # TODO(slebedev): Support multiple sequential axes.
     if len(sequential_axes) > 1:
@@ -346,11 +346,19 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):

     parallel_count = it.count()
     program_ids_template = [
-        _program_id(next(parallel_count)) if i not in sequential_axes else None
-        for i in range(len(grid_mapping.grid))
+        _program_id(next(parallel_count))
+        if axis not in sequential_axes
+        else None
+        for axis in range(len(grid_mapping.grid))
     ]
+
+    def make_program_ids(step: ir.Value):
+      assert ir.IndexType.isinstance(step.type)
+      step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
+      return [step if pid is None else pid for pid in program_ids_template]
+
     module_ctx = ModuleContext(
-        name_and_src_info.name, grid_mapping, approx_math, runtime_smem
+        name_and_src_info.name, grid_mapping, None, approx_math, runtime_smem
     )

     smem_scratch_it = iter(scratch_buffers_smem)
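A pure-Python analogue of the template pattern above, with made-up concrete values standing in for the MLIR ones: parallel axes are bound to their block index once, while the sequential axis stays None and is substituted with the loop step on every call (the real make_program_ids first casts the step from index to i32 via arith.index_cast). The same per-step list is what gmem_slice below feeds into the BlockSpec index_map.

```python
import itertools as it

grid = (4, 4)            # hypothetical 2-D grid
sequential_axes = (1,)   # axis 1 is pipelined

block_ids = it.count(7)  # stand-in for the per-axis CUDA block indices
program_ids_template = [
    next(block_ids) if axis not in sequential_axes else None
    for axis in range(len(grid))
]


def make_program_ids(step):
  # The real lowering casts `step` from an MLIR index value to i32 here.
  return [step if pid is None else pid for pid in program_ids_template]


index_map = lambda i, j: (i, j)  # a typical BlockSpec index_map
assert program_ids_template == [7, None]
assert make_program_ids(2) == [7, 2]
assert index_map(*make_program_ids(2)) == (7, 2)
```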
@@ -412,7 +420,7 @@ def gmem_slice(
         block_mapping: pallas_core.BlockMapping,
     ) -> Sequence[mgpu.DynamicSlice]:
       assert len(sequential_axes) <= 1
-      program_ids = [step if i is None else i for i in program_ids_template]
+      program_ids = make_program_ids(step)
       idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
       return tuple(
           mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
@@ -492,6 +500,7 @@ def store(
       fetch(idx, _as_index(slot), _as_index(slot))

     last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant]
+
     @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
     def _(step, carry):
       accs, last_store_offsets = carry
@@ -519,7 +528,10 @@ def _(step, carry):
       # but that's not necessarily true.
       args.extend(extra_barriers)
       new_accs = lower_jaxpr_to_mosaic_gpu(
-          module_ctx, launch_ctx, lowered_jaxpr, args
+          dataclasses.replace(module_ctx, program_ids=make_program_ids(step)),
+          launch_ctx,
+          lowered_jaxpr,
+          args,
       )

       # TODO(apaszke): Elide this if we're not going to perform any stores
@@ -668,9 +680,9 @@ def write_env(var: jax_core.Var, val):

 @register_lowering_rule(primitives.program_id_p)
 def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
-  # TODO(apaszke): Sequential axis should be handled specially!!
-  del ctx  # Unused.
-  return _program_id(axis)
+  if ctx.module_ctx.program_ids is None:
+    raise NotImplementedError("pl.program_id() is not supported in this context")
+  return ctx.module_ctx.program_ids[axis]


 def _program_id(axis: int) -> ir.Value:
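The user-visible consequence of the new rule above, sketched with hypothetical block shapes: inside the kernel body pl.program_id resolves from the per-step context, but a BlockSpec index_map is lowered with the initial context where program_ids is None, so it now fails with the explicit error instead of silently using a wrong index. This is exactly what the two new tests below exercise.

```python
from jax.experimental import pallas as pl

# Fine: the index_map only uses the grid indices it is handed.
spec_ok = pl.BlockSpec((16, 16), lambda i, j: (i, j))

# Constructing this spec is fine, but lowering a Mosaic GPU kernel that uses
# it now raises NotImplementedError("pl.program_id() is not supported in this
# context") when the index_map jaxpr is lowered.
spec_bad = pl.BlockSpec((128,), lambda *_: pl.program_id(0))
```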
40 changes: 40 additions & 0 deletions tests/pallas/mosaic_gpu_test.py
@@ -14,6 +14,7 @@

 import functools
 import math
+import traceback

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -116,6 +117,26 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x + 1.0)

+  def test_add_one_grid_pipelined_program_id(self):
+
+    @functools.partial(
+        pl.pallas_call,
+        out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
+        out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            max_concurrent_steps=2,
+        ),
+        grid=(4, 4),
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
+    )
+
   def test_add_one_with_async_copy_smem_to_gmem(self):
     @functools.partial(
         pl.pallas_call,
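A worked check of the expectation in test_add_one_grid_pipelined_program_id above (plain NumPy, no GPU needed): output block (i, j) covers columns j*16:(j+1)*16 and is filled with the sequential step j, so every row of the (16, 64) result reads [0]*16 + [1]*16 + [2]*16 + [3]*16.

```python
import numpy as np

expected = np.repeat(np.repeat(np.arange(4), 16)[None], 16, axis=0)
assert expected.shape == (16, 64)
assert (expected[:, :16] == 0).all() and (expected[:, 48:] == 3).all()
```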
@@ -309,6 +330,25 @@ def kernel(o_ref):
         jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
     )

+  def test_program_id_in_block_spec(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_specs=pl.BlockSpec((128,), lambda *_: pl.program_id(0)),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      del o_ref
+
+    # ``assertRaises`` has no way of asserting against the cause, so we
+    # have to use ``traceback.format_exception`` manually.
+    with self.assertRaises(Exception) as exc_info:
+      kernel()
+    self.assertIn(
+        "not supported in this context",
+        "".join(traceback.format_exception(exc_info.exception)),
+    )
+
   def test_num_programs(self):
     @functools.partial(
         pl.pallas_call,