Skip to content

Commit

Permalink
Merge pull request #22304 from gnecula:pallas_io_alias_error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650226122
  • Loading branch information
jax authors committed Jul 8, 2024
2 parents df6080f + f960c28 commit d7bc1ac
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
24 changes: 23 additions & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,8 @@ def pallas_call(
The default value for `out_specs` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them.
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
Expand Down Expand Up @@ -1086,6 +1087,27 @@ def wrapped(*args):
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
f"input index {i_idx} outside the range "
f"[0, {len(flat_in_avals)})")
if o_idx not in range(len(flat_out_avals)):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
f"output index {o_idx} outside the range "
f"[0, {len(flat_out_avals)})")
in_aval = flat_in_avals[i_idx]
out_aval = flat_out_avals[o_idx]
if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to input{tree_util.keystr(in_paths[i_idx])} with "
f"abstract value {in_aval} "
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")

out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name,
Expand Down
38 changes: 38 additions & 0 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,44 @@ def test_pallas_call_block_shape_ndim_mismatch(self):
"array shape"):
f(a)

def test_pallas_call_input_output_aliases_errors(self):
x = np.arange(8 * 128, dtype=np.int32).reshape((8, 128))

with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '2:0' with input index 2 "
"outside the range .*"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={2: 0})(x, x)

with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:1' with output index 1 "
"outside the range .*"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={1: 1})(x, x)

y = np.concatenate([x, x], axis=0)
with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:0' referring to "
"input\\[1\\] with abstract value .*int32\\[16,128\\].* "
"output\\[0\\] with a different abstract value .*int32\\[8,128\\]"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={1: 0})(x, y)

with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:0' referring to "
"input\\[1\\] with abstract value .*int32\\[8,128\\].* "
"output\\[0\\] with a different abstract value .*float32\\[8,128\\]"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
input_output_aliases={1: 0})(x, x)


class ApiErrorInterpreterTest(ApiErrorTest):
INTERPRET = True
Expand Down

0 comments on commit d7bc1ac

Please sign in to comment.