Merge pull request #22684 from froystig:rngdoc
PiperOrigin-RevId: 656600958
jax authors committed Jul 27, 2024
2 parents 40d569b + f30ebd8 commit dab15d6
Showing 1 changed file with 58 additions and 29 deletions.
87 changes: 58 additions & 29 deletions jax/random.py
@@ -117,31 +117,70 @@
==========================
JAX provides several PRNG implementations. A specific one can be
selected with the optional `impl` keyword argument to
`jax.random.key`. When no `impl` option is passed to the `key`
selected with the optional ``impl`` keyword argument to
``jax.random.key``. When no ``impl`` option is passed to the ``key``
constructor, the implementation is determined by the global
`jax_default_prng_impl` configuration flag.
``jax_default_prng_impl`` configuration flag. The string names of
available implementations are:
- **default**, `"threefry2x32"`:
`A counter-based PRNG built around the Threefry hash function <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See
`TF doc <https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator>`_.
- ``"threefry2x32"`` (**default**):
A counter-based PRNG based on a variant of the Threefry hash function,
as described in `this paper by Salmon et al., 2011
<http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_.
- `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation.
- `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for
splitting (using an untested made up algorithm) and generating.
- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop
`XLA's Random Bit Generator (RBG) algorithm
<https://openxla.org/xla/operation_semantics#rngbitgenerator>`_.
The random streams generated by these experimental implementations haven't
been subject to any empirical randomness testing (e.g. Big Crush). The
random bits generated may change between JAX versions.
- ``"rbg"`` uses XLA RBG for random number generation, whereas for
key derivation (as in ``jax.random.split`` and
``jax.random.fold_in``) it uses the same method as
``"threefry2x32"``.
The possible reasons not to use the default RNG are:
- ``"unsafe_rbg"`` uses XLA RBG for both generation and key
derivation.
1. it may be slow to compile (specifically for Google Cloud TPUs)
2. it's slower to execute on TPUs
3. it doesn't support efficient automatic sharding / partitioning
Random numbers generated by these experimental schemes have not
been subject to empirical randomness testing (e.g. BigCrush).
Here is a short summary:
Key derivation in ``"unsafe_rbg"`` has also not been empirically
tested. The name emphasizes "unsafe" because key derivation
quality and generation quality are not well understood.
Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
under ``jax.vmap``. When vmapping a random function over a batch
of keys, its output values can differ from its true map over the
same keys. Instead, under ``vmap``, the entire batch of output
random numbers is generated from only the first key in the input
key batch. For example, if ``keys`` is a vector of 8 keys, then
``jax.vmap(jax.random.normal)(keys)`` equals
``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
reflects a workaround to XLA RBG's limited batching support.
Reasons to use an alternative to the default RNG include that:
1. It may be slow to compile for TPUs.
2. It is relatively slower to execute on TPUs.
**Automatic partitioning:**
In order for ``jax.jit`` to efficiently auto-partition functions that
generate sharded random number arrays (or key arrays), all PRNG
implementations require extra flags:
- For ``"threefry2x32"``, and ``"rbg"`` key derivation, set
``jax_threefry_partitionable=True``.
- For ``"unsafe_rbg"``, and ``"rbg"`` random generation, set the XLA
flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
The XLA flag can be set using the ``XLA_FLAGS`` environment
variable, e.g. as
``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
For more about ``jax_threefry_partitionable``, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
**Summary:**
.. table::
:widths: auto
@@ -153,22 +192,12 @@
efficiently shardable (w/ pjit) ✅ ✅ ✅
identical across shardings ✅ ✅ ✅ ✅
identical across CPU/GPU/TPU ✅ ✅
identical across JAX/XLA versions ✅ ✅
exact ``jax.vmap`` over keys ✅ ✅
================================= ======== ========= === ========== ===== ============
(*): with ``jax_threefry_partitionable=1`` set
(**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set
The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less
robust/studied hash function for random value generation (but not for
`jax.random.split` or `jax.random.fold_in`), "unsafe_rbg" additionally uses less
robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore
less safe in the sense that the quality of random streams it generates from
different keys is less well understood.
For more about `jax_threefry_partitionable`, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
"""

# Note: import <name> as <name> is required for names to be exported.

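As a rough illustration of the ``impl`` keyword described in the docstring above, the sketch below builds keys for two of the named implementations and derives new keys from one of them. The seed values, the variable names and the ``fold_in`` data value ``7`` are arbitrary choices made for this example, not part of the commit.

```python
import jax

# No `impl` given: the implementation comes from the
# jax_default_prng_impl configuration flag.
key_default = jax.random.key(0)

# Request an implementation explicitly by its string name.
key_threefry = jax.random.key(0, impl="threefry2x32")
key_rbg = jax.random.key(0, impl="rbg")  # experimental

# Key derivation uses the same API whichever implementation backs the key.
sub1, sub2 = jax.random.split(key_rbg)
folded = jax.random.fold_in(key_rbg, 7)

# Random generation from keys of either implementation.
x = jax.random.normal(key_threefry, shape=(4,))
y = jax.random.normal(sub1, shape=(4,))
```

Passing ``impl`` per key keeps the choice local to one call site, whereas the ``jax_default_prng_impl`` flag changes the default for every key constructed without ``impl``.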

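The "exact ``jax.vmap`` over keys" row of the summary table can be checked for the default implementation with a sketch like the one below. It assumes ``"threefry2x32"`` keys and compares a vmapped call against an explicit per-key loop. The batch size of 8 mirrors the docstring's example; everything else is illustrative. Per the docstring, the same comparison is not expected to hold for ``"rbg"`` or ``"unsafe_rbg"`` keys.

```python
import jax
import jax.numpy as jnp

# Flag named in the docstring; it affects how threefry generation is
# partitioned under jit, and is set here only to show the call.
jax.config.update("jax_threefry_partitionable", True)

# A vector of 8 keys using the default-named implementation.
keys = jax.random.split(jax.random.key(0, impl="threefry2x32"), 8)

# Batched generation: one scalar sample per key.
batched = jax.vmap(jax.random.normal)(keys)

# The "true map": call the same function on each key separately.
looped = jnp.stack([jax.random.normal(keys[i]) for i in range(8)])

# For threefry2x32, vmap over keys should match the per-key loop.
matches = bool(jnp.allclose(batched, looped))
```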