From 38d2a573fcc975ca778da1052f115d02174805e1 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 30 Sep 2024 03:35:12 -0700
Subject: [PATCH] Exposed sequential iteration index via `pl.program_id` in
 Pallas Mosaic GPU

PiperOrigin-RevId: 680502214
---
 jax/_src/pallas/mosaic_gpu/lowering.py | 30 +++++++++++++------
 tests/pallas/mosaic_gpu_test.py        | 40 ++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 9 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 0d0ac41d11e3..619d436656be 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -102,6 +102,7 @@ def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int:
 class ModuleContext:
   name: str
   grid_mapping: pallas_core.GridMapping
+  program_ids: Sequence[ir.Value] | None
   approx_math: bool
   runtime_smem: ir.Value  # ir.MemRefType
   smem_used_bytes: int = 0
@@ -266,7 +267,6 @@ def lower_jaxpr_to_module(
     raise NotImplementedError(
         "Only <=3D grids are supported in Mosaic GPU lowering."
     )
-  # Compute the number of steps along each sequential axis.
   if sequential_axes:
     # TODO(slebedev): Support multiple sequential axes.
     if len(sequential_axes) > 1:
@@ -346,11 +346,19 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
 
     parallel_count = it.count()
     program_ids_template = [
-        _program_id(next(parallel_count)) if i not in sequential_axes else None
-        for i in range(len(grid_mapping.grid))
+        _program_id(next(parallel_count))
+        if axis not in sequential_axes
+        else None
+        for axis in range(len(grid_mapping.grid))
     ]
+
+    def make_program_ids(step: ir.Value):
+      assert ir.IndexType.isinstance(step.type)
+      step = arith_dialect.index_cast(ir.IntegerType.get_signless(32), step)
+      return [step if pid is None else pid for pid in program_ids_template]
+
     module_ctx = ModuleContext(
-        name_and_src_info.name, grid_mapping, approx_math, runtime_smem
+        name_and_src_info.name, grid_mapping, None, approx_math, runtime_smem
     )
 
     smem_scratch_it = iter(scratch_buffers_smem)
@@ -412,7 +420,7 @@ def gmem_slice(
       block_mapping: pallas_core.BlockMapping,
   ) -> Sequence[mgpu.DynamicSlice]:
     assert len(sequential_axes) <= 1
-    program_ids = [step if i is None else i for i in program_ids_template]
+    program_ids = make_program_ids(step)
     idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping)
     return tuple(
         mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape)
     )
@@ -492,6 +500,7 @@ def store(
       fetch(idx, _as_index(slot), _as_index(slot))
     last_store_offsets = [
        None if inv else _as_index(-1) for inv in out_sequential_invariant]
+
     @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets))
     def _(step, carry):
       accs, last_store_offsets = carry
@@ -519,7 +528,10 @@ def _(step, carry):
       # but that's not necessarily true.
       args.extend(extra_barriers)
       new_accs = lower_jaxpr_to_mosaic_gpu(
-          module_ctx, launch_ctx, lowered_jaxpr, args
+          dataclasses.replace(module_ctx, program_ids=make_program_ids(step)),
+          launch_ctx,
+          lowered_jaxpr,
+          args,
       )
 
       # TODO(apaszke): Elide this if we're not going to perform any stores
@@ -668,9 +680,9 @@ def write_env(var: jax_core.Var, val):
 
 
 @register_lowering_rule(primitives.program_id_p)
 def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
-  # TODO(apaszke): Sequential axis should be handled specially!!
-  del ctx  # Unused.
-  return _program_id(axis)
+  if ctx.module_ctx.program_ids is None:
+    raise NotImplementedError("pl.program_id() is not supported in this context")
+  return ctx.module_ctx.program_ids[axis]
 
 
 def _program_id(axis: int) -> ir.Value:
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index b35658ed4845..77ddd5fb2628 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -14,6 +14,7 @@
 
 import functools
 import math
+import traceback
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -116,6 +117,26 @@ def kernel(x_ref, o_ref):
     x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x + 1.0)
 
+  def test_add_one_grid_pipelined_program_id(self):
+
+    @functools.partial(
+        pl.pallas_call,
+        out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
+        out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            max_concurrent_steps=2,
+        ),
+        grid=(4, 4),
+    )
+    def kernel(o_ref):
+      o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)
+
+    np.testing.assert_array_equal(
+        kernel(),
+        jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
+    )
+
   def test_add_one_with_async_copy_smem_to_gmem(self):
     @functools.partial(
         pl.pallas_call,
@@ -309,6 +330,25 @@ def kernel(o_ref):
         jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32),
     )
 
+  def test_program_id_in_block_spec(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_specs=pl.BlockSpec((128,), lambda *_: pl.program_id(0)),
+        out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32),
+        grid=2,
+    )
+    def kernel(o_ref):
+      del o_ref
+
+    # ``assertRaises`` has no way of asserting against the cause, so we
+    # have to use ``traceback.format_exception`` manually.
+    with self.assertRaises(Exception) as exc_info:
+      kernel()
+    self.assertIn(
+        "not supported in this context",
+        "".join(traceback.format_exception(exc_info.exception)),
+    )
+
   def test_num_programs(self):
     @functools.partial(
         pl.pallas_call,
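
For reference, below is a minimal, self-contained sketch of how the newly
exposed sequential iteration index can be read via `pl.program_id` from a
kernel, adapted directly from the `test_add_one_grid_pipelined_program_id`
test in this patch. The import paths for `pl` and `plgpu` are assumptions
based on the aliases the test file uses, not part of the patch itself.

    import functools

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed import path

    # Grid axis 0 is parallel, axis 1 is sequential; with this patch,
    # pl.program_id(1) inside the kernel returns the current sequential step.
    @functools.partial(
        pl.pallas_call,
        out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32),
        compiler_params=plgpu.GPUCompilerParams(
            dimension_semantics=["parallel", "sequential"],
            max_concurrent_steps=2,
        ),
        grid=(4, 4),
    )
    def kernel(o_ref):
      # Each (16, 16) output block is filled with its sequential step index.
      o_ref[...] = jnp.broadcast_to(pl.program_id(1), o_ref.shape)

    # Column blocks j = 0..3 come out as all-0, all-1, all-2, all-3.
    np.testing.assert_array_equal(
        kernel(),
        jnp.repeat(jnp.repeat(jnp.arange(4), 16)[None], 16, axis=0),
    )

Note that `pl.program_id` remains unsupported where no program ids are in
scope, e.g. inside a `BlockSpec` index map, which is exactly what the new
`test_program_id_in_block_spec` test asserts via the lowering rule's
`NotImplementedError`.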