From 66a3f87a24016594794c2ee289826baed5e979a4 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 19 Aug 2024 04:28:06 -0700 Subject: [PATCH] Rollback for: Implement initial vmap over pallas_call w/ ragged inputs (via jumbles) It can cause issues in x32 when trying to get the aval for array dimension sizes that are larger than i32. Reverts 24394a1b03f01138219013f4773104b834e498b7 PiperOrigin-RevId: 664742891 --- jax/_src/core.py | 24 +-- jax/_src/interpreters/batching.py | 6 +- jax/_src/pallas/core.py | 53 ++---- jax/_src/pallas/mosaic/lowering.py | 28 +-- jax/_src/pallas/pallas_call.py | 275 +++-------------------------- tests/pallas/BUILD | 23 --- tests/pallas/pallas_jumble_test.py | 201 --------------------- 7 files changed, 50 insertions(+), 560 deletions(-) delete mode 100644 tests/pallas/pallas_jumble_test.py diff --git a/jax/_src/core.py b/jax/_src/core.py index 61ed81cdeea9..ebf29cf0b253 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1954,7 +1954,6 @@ def __init__(self, aval, data): assert data.shape == pad_shape self._aval = aval self._data = data - shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1965,38 +1964,21 @@ def __repr__(self) -> str: dtypestr = _short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - data = self.data + slices = tuple(slice(int(d._data)) if type(d) is DArray and + type(d.dtype) is bint else slice(None) for d in self.shape) + data = self._data[slices] return f'{dtypestr}[{shapestr}] with value: {data}' - def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") - def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False - def __len__(self): return self.shape[0] - @property - def data(self): - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return self._data - - slices = tuple( - slice(int(d._data)) - if type(d) is DArray and type(d.dtype) is bint - else slice(None) - for d in self.shape - ) - data = self._data[slices] - return data - - pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 27cde6d31d35..fbcd2c4a7a30 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,7 +88,6 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval - def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -252,10 +251,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - # TODO(mvoz): This is a terrible place to fall into if you pass - # a non jumble type in, make it clearer what went wrong. - assert False, f'Unexpected type in ELT? {type(x)}' - + assert False to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 09e02ea5c3a1..0ef208f755e5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -112,10 +112,7 @@ class AbstractMemoryRef(state.AbstractRef): def __init__(self, inner_aval: jax_core.AbstractValue, memory_space: Any): - - assert isinstance( - inner_aval, jax_core.ShapedArray - ), f"Illegal ref, got {type(inner_aval)}" + assert isinstance(inner_aval, jax_core.ShapedArray) self.inner_aval = inner_aval self.memory_space = memory_space @@ -170,7 +167,9 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple(self.grid) + valid_grid = tuple( + s for i, s in enumerate(self.grid) if i not in self.mapped_dims + ) try: size = valid_grid[axis] except IndexError as e: @@ -339,10 +338,7 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( - self.block_shape, - self.index_map_jaxpr.out_avals, - ) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -426,8 +422,6 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int - get_grid_indices: Callable | None = None - local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -448,8 +442,8 @@ def check_invariants(self) -> None: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" - assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" + assert index_map_arg.shape == () + assert index_map_arg.dtype == jnp.int32 assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -460,11 +454,8 @@ def check_invariants(self) -> None: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple( - bm.index_map_jaxpr.in_avals - ), ( + assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( self.index_map_avals, - "|", bm.index_map_jaxpr.in_avals, ) @@ -556,17 +547,6 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) - -def _max_shape_from_aval(array_aval: jax_core.ShapedArray): - array_aval_shape = list(array_aval.shape) - for i, s in enumerate(array_aval.shape): - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - array_aval_shape[i] = aval.dtype.bound - - return tuple(array_aval_shape) - - def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -595,15 +575,8 @@ def _convert_block_spec_to_block_mapping( f"array shape {array_aval.shape}.") unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) - if isinstance(array_aval, jax_core.DShapedArray): - # Get the "max" shape for the ragged array. - block_array_aval = jax_core.ShapedArray( - block_array_aval.shape, - block_array_aval.dtype, - block_array_aval.weak_type, - ) - block_aval = AbstractMemoryRef(block_array_aval, block_spec.memory_space) + block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), + block_spec.memory_space) if not jax_core.is_constant_shape(block_aval.shape): raise ValueError( @@ -636,12 +609,12 @@ def _convert_block_spec_to_block_mapping( f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}.") + if consts: raise ValueError( f"Index map function {index_map_src_info} for " f"{origin} must not capture constants: {consts}") - array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( block_shape=mapped_block_shape, @@ -649,9 +622,7 @@ def _convert_block_spec_to_block_mapping( index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), index_map_src_info=index_map_src_info, indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct( - array_aval_shape, array_aval.dtype - ), + array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), origin=origin, ) mapping.check_invariants() diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index aee894ee1b7e..86ce2f0b1b81 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -298,7 +298,6 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims - # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -346,19 +345,9 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - - if grid_mapping.get_grid_indices is None: - - def _get_grid_indices(indices, maybe_include_mapped_dims: bool): - if maybe_include_mapped_dims: - return indices - return tuple( - idx for i, idx in enumerate(indices) if i not in self.mapped_dims - ) - - self.get_grid_indices = _get_grid_indices - else: - self.get_grid_indices = grid_mapping.get_grid_indices + def _get_grid_indices(indices): + return indices + self.get_grid_indices = _get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -606,9 +595,7 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices( - grid_indices, maybe_include_mapped_dims=True - ) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -676,9 +663,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices( - grid_indices, maybe_include_mapped_dims=False - ) + grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) + if i not in mosaic_grid_mapping.mapped_dims) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -2378,7 +2365,6 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index bb1683e38b1c..4f3c9918f664 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -228,12 +228,6 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. - for carry_element in carry: - aval = carry_element.aval - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype) - carry_element.aval = aval - carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -253,16 +247,11 @@ def cond(carry): return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - - if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) - else: - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) - + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -279,14 +268,8 @@ def body(carry): len(blocks), len(scratch_values), ) - for s in scalars: - aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - s.aval = aval.update(dtype=jnp.int32) - - blocks = jax_core.eval_jaxpr( - discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch - ) + blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, + *blocks, *scratch) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -407,55 +390,19 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule - -def _batch_block_mapping( - grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping, - for_ragged: bool, -) -> BlockMapping: +def _batch_block_mapping(grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping) -> BlockMapping: def _block_map_function(new_idx, *args): - if for_ragged: - drop_last_args = args[:-1] - else: - drop_last_args = args - - indices = jax_core.eval_jaxpr( - block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *drop_last_args, - ) + indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *args) if dim is not batching.not_mapped: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - indices.insert(stacked_axis, new_idx) - else: - indices.insert(dim, new_idx) + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] - - if for_ragged: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - _, _, ragged_axis_length = _ragged_axis_parts(dim) - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, lengths_aval] - else: - i32_aval_memref = pallas_core.AbstractMemoryRef( - jax_core.ShapedArray(([axis_size]), jnp.int32), - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, i32_aval_memref] - with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -464,27 +411,12 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - new_block_shape = shape - stacked_axis = dim.stacked_axis - new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.mapped - ) - else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) - - array_shape = block_mapping.array_shape_dtype.shape - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - array_shape = tuple_insert(array_shape, stacked_axis, axis_size) - else: - array_shape = tuple_insert(array_shape, dim, axis_size) - + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) new_array_shape_dtype = jax.ShapeDtypeStruct( - array_shape, block_mapping.array_shape_dtype.dtype - ) + tuple_insert(block_mapping.array_shape_dtype.shape, + dim, + axis_size), + block_mapping.array_shape_dtype.dtype) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -615,16 +547,6 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return result, (0,) * len(result) -def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: - stacked_axis = dim.stacked_axis - ragged_axes = dim.ragged_axes - if len(ragged_axes) != 1: - raise ValueError("Multiple ragged axes not yet implemented.") - ragged_axis_dim = ragged_axes[0][0] - ragged_axis_length = ragged_axes[0][1] - return stacked_axis, ragged_axis_dim, ragged_axis_length - - def _pallas_call_batching_rule( args, dims, @@ -645,26 +567,8 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) - all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] - if len(all_ragged_axes) > 1: - raise ValueError("Multiple ragged dimensions not yet implemented.") - - if all_ragged_axes: - stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( - all_ragged_axes[0] - ) - else: - stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None - - def get_size(i, x, d): - if not isinstance(d, batching.RaggedAxis): - return x.shape[d] - return x.aval.shape[i] - (axis_size,) = { - get_size(i=i, x=x, d=d) - for i, (x, d) in enumerate(zip(args, dims)) - if d is not batching.not_mapped + x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? @@ -766,27 +670,12 @@ def get_size(i, x, d): num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands - lengths_aval = None - if ragged_axis_length is not None: - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = pallas_core.AbstractMemoryRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial( - _batch_block_mapping, - grid_mapping, - axis_size, - for_ragged=lengths_aval is not None, - ), + partial(_batch_block_mapping, grid_mapping, axis_size), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -796,23 +685,15 @@ def get_size(i, x, d): grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args - - if lengths_aval: - batched_index_map_args = batched_index_map_args + (lengths_aval,) - num_index_operands += 1 - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) - batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=tuple(batched_index_map_avals), + index_map_avals=batched_index_map_avals, index_map_tree=batched_index_map_tree, - num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) - if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -821,103 +702,6 @@ def get_size(i, x, d): ) else: batched_cost_estimate = None - - if lengths_aval: - batched_grid_mapping = batched_grid_mapping.replace( - get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, - local_grid_env=lambda loop_idx, grid: tuple( - pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) - ), - ) - - # Note - on zero filling counterfactuals - # A debug util to produce a counterfactual version of the when - # gating, where for all values that don't pass the @when check, - # we write 0s. This is useful for debugging, as certain lowering paths - # like mosaic will write the last data as passthrough, leading to - # potentially confusing results. - debug_zero_fill_counterfactual = debug - - first_block_mapping = batched_grid_mapping.block_mappings[0] - for block_mapping in batched_grid_mapping.block_mappings: - # This invariant may already be checked elsewhere, but lets reaffirm it - assert block_mapping.block_shape == first_block_mapping.block_shape, ( - f"block_mapping.block_shape: {block_mapping.block_shape}, " - f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" - ) - assert ( - block_mapping.array_shape_dtype - == first_block_mapping.array_shape_dtype - ), ( - f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," - " first_block_mapping.array_shape_dtype:" - f" {first_block_mapping.array_shape_dtype}" - ) - - mapped_dim_idxs = [ - i - for i, d in enumerate(first_block_mapping.block_shape) - if d is pallas_core.mapped - ] - assert len(mapped_dim_idxs) == 1 - mapped_dim_idx = mapped_dim_idxs[0] - if stacked_axis != mapped_dim_idx: - raise ValueError( - f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" - ) - - assert ragged_axis_dim is not None, "Invariant violation" - # This is the blockspec size of the dimension - val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] - - def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = jax.experimental.pallas.program_id(stacked_axis) - i_idx = ( - jax.experimental.pallas.program_id(ragged_axis_dim) - * val_at_ragged_dim - ) - b_len = lengths_ref[b_idx] - - # TODO(mvoz): Unimplemented primitive in pallas - # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) - # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - - @jax.experimental.pallas.when(i_idx < b_len) - def f(): - # Important! This allows us to trace the inner kernel with the correct - # grid to preserve user program_id semantics. Ex: program_id(0) will - # always be analogous to program_id(1) in the outer kernel. - with pallas_core.tracing_grid_env(grid_mapping.grid, ()): - jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) - - if debug_zero_fill_counterfactual: - - @jax.experimental.pallas.when(i_idx >= b_len) - def g(): - for arg_ref in args: - arg_ref[...] = jnp.zeros_like(arg_ref) - - kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] - flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( - list(kernel_avals) - ) - # Important! This allows us to trace the outer kernel with the correct grid - # to enable accessing the batch program_id. - with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): - kernel_src_info: pallas_core.SrcInfoStr = "" - - jaxpr = _trace_kernel_to_jaxpr( - when_wrapped_kernel, - kernel_src_info, - batched_grid_mapping, - tuple(flat_kernel_avals), - kernel_in_tree, - interpret=interpret, - ) - - assert ragged_axis_length is not None - args = (ragged_axis_length, *args) - out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -1313,14 +1097,12 @@ def pallas_call( out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore for x in flat_out_shapes] - @jax.jit def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) - flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) for v in flat_out_shapes) @@ -1390,18 +1172,15 @@ def wrapped(*args): return wrapped -def in_path_to_input_origin( - in_path: tree_util.KeyPath, arg_names: tuple[str, ...] | None -) -> pallas_core.OriginStr: +def in_path_to_input_origin(in_path: tree_util.KeyPath, + arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( - arg_names - ): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 5559a0552f9f..c0cf61387cbb 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -62,29 +62,6 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( - name = "pallas_jumble_test", - srcs = [ - "pallas_jumble_test.py", - ], - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], - shard_count = { - "tpu": 1, - }, - deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), -) - jax_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py deleted file mode 100644 index ee176a0363aa..000000000000 --- a/tests/pallas/pallas_jumble_test.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" - -from absl.testing import absltest -import jax -from jax import lax -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.interpreters import batching -from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr -from jax.experimental import pallas as pl -import jax.numpy as jnp -import numpy as np - - -# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. -# pylint: disable=no-value-for-parameter - -config.parse_flags_with_absl() - - -intx = dtypes.canonicalize_dtype(jnp.int64) -floatx = dtypes.canonicalize_dtype(jnp.float64) - - -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches( - ["cuda"] - ) and not jtu.is_cuda_compute_capability_at_least("8.0"): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - - super().setUp() - _trace_kernel_to_jaxpr.cache_clear() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class PallasCallRaggedVmapTest(PallasBaseTest): - - def test_vmap_jumble_over_sin_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - # See note - on zero filling counterfactuals - debug=True, - )(x) - - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - ragged_total = 0 - for dim in ragged_shape: - ragged_total += row_count * dim * 128 - # See note - on zero filling counterfactuals - self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) - - def test_vmap_jumble_over_sin_kernel_grid_remapping(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - self.skipTest("Checkify NYI") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([(128 * x) - 1 for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex( - ValueError, - "Ragged input shape must be evenly divisble by the grid" # noqa: W605 - " size at the ragged dimension 2", - ): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - -class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): - INTERPRET = True - - -if __name__ == "__main__": - absltest.main()