diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 874f184461f4..72d9a861eacf 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -43,6 +43,6 @@ reuse_key as reuse_key, ) -from jax.experimental.key_reuse._common import ( +from jax.experimental.key_reuse._core import ( KeyReuseError as KeyReuseError, ) diff --git a/jax/experimental/key_reuse/_common.py b/jax/experimental/key_reuse/_common.py deleted file mode 100644 index e0d20d269f75..000000000000 --- a/jax/experimental/key_reuse/_common.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import NamedTuple -from jax import core -from jax.interpreters import batching, mlir -from jax._src import prng -import numpy as np - - -class Sink(NamedTuple): - idx: int - mask: bool | np.ndarray = True - - def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Sink({self.idx})" - else: - return f"Sink({self.idx}, mask={self.mask})" - - -class Source(NamedTuple): - idx: int - mask: bool | np.ndarray = True - - def __repr__(self): - if isinstance(self.mask, bool) and self.mask: - return f"Source({self.idx})" - else: - return f"Source({self.idx}, mask={self.mask})" - -class Forward(NamedTuple): - in_idx: int - out_idx: int - - -class KeyReuseSignature(NamedTuple): - sinks: list[Sink] - sources: list[Source] - forwards: list[Forward] = [] - - def check_signature(self, *args, jaxpr=None): - for sink in self.sinks: - if not isinstance(args[sink.idx], prng.PRNGKeyArray): - continue - if np.any(args[sink.idx]._consumed & sink.mask): - msg = f"Previously-consumed key at index {sink.idx} passed to function" - if jaxpr: - msg += f"\n{jaxpr=}" - raise KeyReuseError(msg) - - def update_consumption(self, args_in, args_out): - for sink in self.sinks: - arg = args_in[sink.idx] - if isinstance(arg, prng.PRNGKeyArray): - arg._consumed = arg._consumed | sink.mask - for arg in args_out: - if isinstance(arg, prng.PRNGKeyArray): - arg._consumed = True - for source in self.sources: - if isinstance(args_out[source.idx], prng.PRNGKeyArray): - args_out[source.idx]._consumed = ~np.asarray(source.mask) - for forward in self.forwards: - arg_in = args_in[forward.in_idx] - arg_out = args_out[forward.out_idx] - if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray): - arg_out._consumed = arg_in._consumed - - -class KeyReuseError(RuntimeError): - pass - -consume_p = core.Primitive("consume") -consume_p.def_impl(lambda x: x) -consume_p.def_abstract_eval(lambda x: x) -batching.defvectorized(consume_p) -mlir.register_lowering( - consume_p, - mlir.lower_fun(lambda x: x, multiple_results=False)) - -def consume(key): - """Consume the key and return a consumed copy.""" - return consume_p.bind(key) - - -assert_consumed_value_p = core.Primitive("assert_consumed_value") -assert_consumed_value_p.def_impl(lambda x, *, value: x) -assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x) -batching.defvectorized(assert_consumed_value_p) -mlir.register_lowering( - assert_consumed_value_p, - mlir.lower_fun(lambda x, *, value: x, multiple_results=False)) - -def assert_unconsumed(key): - """Assert that a key is unconsumed""" - assert_consumed_value_p.bind(key, value=False) - -def assert_consumed(key, value=True): - """Assert that a key is consumed""" - assert_consumed_value_p.bind(key, value=value) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index b015696e07ee..b67ce28aef9e 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -16,17 +16,17 @@ from collections import defaultdict from functools import partial, reduce, wraps -from typing import Any, Callable +from typing import Any, Callable, NamedTuple import jax from jax import lax from jax import tree_util +from jax.interpreters import batching, mlir from jax._src import api_util from jax._src import config from jax._src import core from jax._src import linear_util as lu from jax._src import pjit -from jax._src import pretty_printer as pp from jax._src import prng from jax._src import random from jax._src import util @@ -35,14 +35,102 @@ from jax._src.interpreters import partial_eval as pe from jax._src.util import weakref_lru_cache -from jax.experimental.key_reuse._common import ( - consume_p, assert_consumed_value_p, KeyReuseError, - Sink, Source, Forward, KeyReuseSignature -) from jax.experimental.shard_map import shard_map_p import numpy as np +class Sink(NamedTuple): + idx: int + mask: bool | np.ndarray = True + + def __repr__(self): + if isinstance(self.mask, bool) and self.mask: + return f"Sink({self.idx})" + else: + return f"Sink({self.idx}, mask={self.mask})" + + +class Source(NamedTuple): + idx: int + mask: bool | np.ndarray = True + + def __repr__(self): + if isinstance(self.mask, bool) and self.mask: + return f"Source({self.idx})" + else: + return f"Source({self.idx}, mask={self.mask})" + +class Forward(NamedTuple): + in_idx: int + out_idx: int + + +class KeyReuseSignature(NamedTuple): + sinks: list[Sink] + sources: list[Source] + forwards: list[Forward] = [] + + def check_signature(self, *args, funcname="function", context=None): + for sink in self.sinks: + if not isinstance(args[sink.idx], prng.PRNGKeyArray): + continue + if np.any(args[sink.idx]._consumed & sink.mask): + msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}" + if context: + msg += " {context}" + raise KeyReuseError(msg) + + def update_consumption(self, args_in, args_out): + for sink in self.sinks: + arg = args_in[sink.idx] + if isinstance(arg, prng.PRNGKeyArray): + arg._consumed = arg._consumed | sink.mask + for arg in args_out: + if isinstance(arg, prng.PRNGKeyArray): + arg._consumed = True + for source in self.sources: + if isinstance(args_out[source.idx], prng.PRNGKeyArray): + args_out[source.idx]._consumed = ~np.asarray(source.mask) + for forward in self.forwards: + arg_in = args_in[forward.in_idx] + arg_out = args_out[forward.out_idx] + if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray): + arg_out._consumed = arg_in._consumed + + +class KeyReuseError(RuntimeError): + pass + +consume_p = core.Primitive("consume") +consume_p.def_impl(lambda x: x) +consume_p.def_abstract_eval(lambda x: x) +batching.defvectorized(consume_p) +mlir.register_lowering( + consume_p, + mlir.lower_fun(lambda x: x, multiple_results=False)) + +def consume(key): + """Consume the key and return a consumed copy.""" + return consume_p.bind(key) + + +assert_consumed_value_p = core.Primitive("assert_consumed_value") +assert_consumed_value_p.def_impl(lambda x, *, value: x) +assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x) +batching.defvectorized(assert_consumed_value_p) +mlir.register_lowering( + assert_consumed_value_p, + mlir.lower_fun(lambda x, *, value: x, multiple_results=False)) + +def assert_unconsumed(key): + """Assert that a key is unconsumed""" + assert_consumed_value_p.bind(key, value=False) + +def assert_consumed(key, value=True): + """Assert that a key is consumed""" + assert_consumed_value_p.bind(key, value=value) + + def _check_consumed_value(eqn, consumed): """Extra check for use with assert_consumed_value_p""" expected = eqn.params['value'] @@ -341,17 +429,20 @@ def key_reuse_impl_rule(prim, original_rule): def key_reuse_impl(*args, **kwargs): if config.enable_key_reuse_checks.value: if prim == pjit.pjit_p: + funcname = "jit-compiled function" jaxpr = kwargs['jaxpr'].jaxpr signature = get_jaxpr_type_signature(jaxpr) elif prim in key_reuse_signatures: - jaxpr = prim + funcname = str(prim) + jaxpr = None signature = key_reuse_signatures[prim] elif prim in key_reuse_signatures_dynamic: + funcname = str(prim) jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr signature = get_jaxpr_type_signature(jaxpr) else: raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}") - signature.check_signature(*args, jaxpr=jaxpr) + signature.check_signature(*args, funcname=funcname) result = original_rule(*args, **kwargs) signature.update_consumption(args, result if prim.multiple_results else [result]) return result diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index dac015ad42c8..781a2f44be61 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -22,7 +22,7 @@ import jax.numpy as jnp from jax._src import prng from jax._src import test_util as jtu -from jax.experimental.key_reuse._common import ( +from jax.experimental.key_reuse._core import ( assert_consumed, assert_unconsumed, consume, consume_p) from jax.experimental.key_reuse import _core, KeyReuseError @@ -587,28 +587,29 @@ def f_good(x, key): class KeyReuseEager(jtu.JaxTestCase): - jit_msg = "Previously-consumed key at index 0 passed to function" - bits_msg = "In random_bits, key values a are already consumed." + jit_msg = "Previously-consumed key passed to jit-compiled function at index 0" + eager_bits_msg = "Previously-consumed key passed to random_bits at index 0" + traced_bits_msg = "In random_bits, key values a are already consumed." def test_simple_reuse_nojit(self): key = jax.random.key(0) _ = jax.random.bits(key) with jax.disable_jit(): - with self.assertRaisesRegex(KeyReuseError, self.jit_msg): + with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg): _ = jax.random.bits(key) def test_simple_key_reuse_jit(self): key = jax.random.key(0) _ = jax.random.bits(key) with self.assertRaisesRegex(KeyReuseError, self.jit_msg): - _ = jax.random.bits(key) + _ = jax.jit(jax.random.bits)(key) def test_key_reuse_within_jit(self): @jax.jit def f(): key = jax.random.key(0) return jax.random.bits(key) + jax.random.bits(key) - with self.assertRaisesRegex(KeyReuseError, self.bits_msg): + with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg): f()