From f30ebd8586a077af3c71733ca3177d219787363f Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 26 Jul 2024 06:56:26 -0700 Subject: [PATCH] document vmap peculiarity of experimental RNG implementations --- jax/random.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/jax/random.py b/jax/random.py index 1cda2a1a91df..5ced19dbbede 100644 --- a/jax/random.py +++ b/jax/random.py @@ -147,6 +147,16 @@ tested. The name emphasizes "unsafe" because key derivation quality and generation quality are not well understood. + Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually + under ``jax.vmap``. When vmapping a random function over a batch + of keys, its output values can differ from its true map over the + same keys. Instead, under ``vmap``, the entire batch of output + random numbers is generated from only the first key in the input + key batch. For example, if ``keys`` is a vector of 8 keys, then + ``jax.vmap(jax.random.normal)(keys)`` equals + ``jax.random.normal(keys[0], shape=(8,))``. This peculiarity + reflects a workaround to XLA RBG's limited batching support. + Reasons to use an alternative to the default RNG include that: 1. It may be slow to compile for TPUs. @@ -164,7 +174,8 @@ flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``. The XLA flag can be set using an the ``XLA_FLAGS`` environment -variable, e.g. as ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. +variable, e.g. as +``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. For more about ``jax_threefry_partitionable``, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers @@ -181,7 +192,7 @@ efficiently shardable (w/ pjit) ✅ ✅ ✅ identical across shardings ✅ ✅ ✅ ✅ identical across CPU/GPU/TPU ✅ ✅ - identical across JAX/XLA versions ✅ ✅ + exact ``jax.vmap`` over keys ✅ ✅ ================================= ======== ========= === ========== ===== ============ (*): with ``jax_threefry_partitionable=1`` set