Skip to content

Commit

Permalink
document vmap peculiarity of experimental RNG implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Jul 26, 2024
1 parent 6ddd488 commit f30ebd8
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@
tested. The name emphasizes "unsafe" because key derivation
quality and generation quality are not well understood.
Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually
under ``jax.vmap``. When vmapping a random function over a batch
of keys, its output values can differ from its true map over the
same keys. Instead, under ``vmap``, the entire batch of output
random numbers is generated from only the first key in the input
key batch. For example, if ``keys`` is a vector of 8 keys, then
``jax.vmap(jax.random.normal)(keys)`` equals
``jax.random.normal(keys[0], shape=(8,))``. This peculiarity
reflects a workaround to XLA RBG's limited batching support.
Reasons to use an alternative to the default RNG include that:
1. It may be slow to compile for TPUs.
Expand All @@ -164,7 +174,8 @@
flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
The XLA flag can be set using an the ``XLA_FLAGS`` environment
variable, e.g. as ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
variable, e.g. as
``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``.
For more about ``jax_threefry_partitionable``, see
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
Expand All @@ -181,7 +192,7 @@
efficiently shardable (w/ pjit) ✅ ✅ ✅
identical across shardings ✅ ✅ ✅ ✅
identical across CPU/GPU/TPU ✅ ✅
identical across JAX/XLA versions ✅ ✅
exact ``jax.vmap`` over keys ✅ ✅
================================= ======== ========= === ========== ===== ============
(*): with ``jax_threefry_partitionable=1`` set
Expand Down

0 comments on commit f30ebd8

Please sign in to comment.