diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index 6b330975a109..874f184461f4 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -18,25 +18,26 @@ This module contains **experimental** functionality for detecting re-use of random keys within JAX programs. It is under active development and the APIs here are likely -to change. +to change. The usage below requires JAX version 0.4.26 or newer. -Key reuse checking can be enabled on `jit`-compiled functions using the -:func:`jax.enable_key_reuse_checks` configuration:: +Key reuse checking can be enabled using the `jax_enable_key_reuse_checks` configuration:: >>> import jax - >>> @jax.jit - ... def f(key): - ... return jax.random.uniform(key) + jax.random.normal(key) - ... + >>> jax.config.update('jax_enable_key_reuse_checks', True) >>> key = jax.random.key(0) - >>> with jax.enable_key_reuse_checks(): - ... f(key) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> jax.random.normal(key) + Array(-0.20584226, dtype=float32) + >>> jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - KeyReuseError: In random_bits, key values a are already consumed. + KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 -This flag can also be set globally if you wish to enagle key reuse checks in -every JIT-compiled function. +This flag can also be controlled locally using the :func:`jax.enable_key_reuse_checks` +context manager:: + + >>> with jax.enable_key_reuse_checks(False): + ... print(jax.random.normal(key)) + -0.20584226 """ from jax._src.prng import ( reuse_key as reuse_key,