Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[key reuse] improve module docs #20071

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions jax/experimental/key_reuse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,26 @@

This module contains **experimental** functionality for detecting re-use of random
keys within JAX programs. It is under active development and the APIs here are likely
to change.
to change. The usage below requires JAX version 0.4.26 or newer.

Key reuse checking can be enabled on `jit`-compiled functions using the
:func:`jax.enable_key_reuse_checks` configuration::
Key reuse checking can be enabled using the `jax_enable_key_reuse_checks` configuration::

>>> import jax
>>> @jax.jit
... def f(key):
... return jax.random.uniform(key) + jax.random.normal(key)
...
>>> jax.config.update('jax_enable_key_reuse_checks', True)
>>> key = jax.random.key(0)
>>> with jax.enable_key_reuse_checks():
... f(key) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> jax.random.normal(key)
Array(-0.20584226, dtype=float32)
>>> jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
KeyReuseError: In random_bits, key values a are already consumed.
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

This flag can also be set globally if you wish to enagle key reuse checks in
every JIT-compiled function.
This flag can also be controlled locally using the :func:`jax.enable_key_reuse_checks`
context manager::

>>> with jax.enable_key_reuse_checks(False):
... print(jax.random.normal(key))
-0.20584226
"""
from jax._src.prng import (
reuse_key as reuse_key,
Expand Down