From 6ddd488df09d3840cdf1af45eb0b7dc373749943 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 26 Jul 2024 06:40:33 -0700 Subject: [PATCH 1/2] improve RNG doc around implementation configuration --- jax/random.py | 74 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/jax/random.py b/jax/random.py index f951fce406a8..1cda2a1a91df 100644 --- a/jax/random.py +++ b/jax/random.py @@ -117,31 +117,59 @@ ========================== JAX provides several PRNG implementations. A specific one can be -selected with the optional `impl` keyword argument to -`jax.random.key`. When no `impl` option is passed to the `key` +selected with the optional ``impl`` keyword argument to +``jax.random.key``. When no ``impl`` option is passed to the ``key`` constructor, the implementation is determined by the global -`jax_default_prng_impl` configuration flag. +``jax_default_prng_impl`` configuration flag. The string names of +available implementations are: -- **default**, `"threefry2x32"`: - `A counter-based PRNG built around the Threefry hash function `_. -- *experimental* A PRNG that thinly wraps the XLA Random Bit Generator (RBG) algorithm. See - `TF doc `_. +- ``"threefry2x32"`` (**default**): + A counter-based PRNG based on a variant of the Threefry hash function, + as described in `this paper by Salmon et al., 2011 + `_. - - `"rbg"` uses ThreeFry for splitting, and XLA RBG for data generation. - - `"unsafe_rbg"` exists only for demonstration purposes, using RBG both for - splitting (using an untested made up algorithm) and generating. +- ``"rbg"`` and ``"unsafe_rbg"`` (**experimental**): PRNGs built atop + `XLA's Random Bit Generator (RBG) algorithm + `_. - The random streams generated by these experimental implementations haven't - been subject to any empirical randomness testing (e.g. Big Crush). The - random bits generated may change between JAX versions. + - ``"rbg"`` uses XLA RBG for random number generation, whereas for + key derivation (as in ``jax.random.split`` and + ``jax.random.fold_in``) it uses the same method as + ``"threefry2x32"``. -The possible reasons not use the default RNG are: + - ``"unsafe_rbg"`` uses XLA RBG for both generation as well as key + derivation. -1. it may be slow to compile (specifically for Google Cloud TPUs) -2. it's slower to execute on TPUs -3. it doesn't support efficient automatic sharding / partitioning + Random numbers generated by these experimental schemes have not + been subject to empirical randomness testing (e.g. BigCrush). -Here is a short summary: + Key derivation in ``"unsafe_rbg"`` has also not been empirically + tested. The name emphasizes "unsafe" because key derivation + quality and generation quality are not well understood. + +Reasons to use an alternative to the default RNG include that: + +1. It may be slow to compile for TPUs. +2. It is relatively slower to execute on TPUs. + +**Automatic partitioning:** + +In order for ``jax.jit`` to efficiently auto-partition functions that +generate sharded random number arrays (or key arrays), all PRNG +implementations require extra flags: + +- For ``"threefry2x32"``, and ``"rbg"`` key derivation, set + ``jax_threefry_partitionable=True``. +- For ``"unsafe_rbg"``, and ``"rbg"`` random generation", set the XLA + flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``. + +The XLA flag can be set using an the ``XLA_FLAGS`` environment +variable, e.g. as ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. + +For more about ``jax_threefry_partitionable``, see +https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers + +**Summary:** .. table:: :widths: auto @@ -159,16 +187,6 @@ (*): with ``jax_threefry_partitionable=1`` set (**): with ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1`` set - -The difference between "rbg" and "unsafe_rbg" is that while "rbg" uses a less -robust/studied hash function for random value generation (but not for -`jax.random.split` or `jax.random.fold_in`), "unsafe_rbg" additionally uses less -robust hash functions for `jax.random.split` and `jax.random.fold_in`. Therefore -less safe in the sense that the quality of random streams it generates from -different keys is less well understood. - -For more about `jax_threefry_partitionable`, see -https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers """ # Note: import as is required for names to be exported. From f30ebd8586a077af3c71733ca3177d219787363f Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 26 Jul 2024 06:56:26 -0700 Subject: [PATCH 2/2] document vmap peculiarity of experimental RNG implementations --- jax/random.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/jax/random.py b/jax/random.py index 1cda2a1a91df..5ced19dbbede 100644 --- a/jax/random.py +++ b/jax/random.py @@ -147,6 +147,16 @@ tested. The name emphasizes "unsafe" because key derivation quality and generation quality are not well understood. + Additionally, both ``"rbg"`` and ``"unsafe_rbg"`` behave unusually + under ``jax.vmap``. When vmapping a random function over a batch + of keys, its output values can differ from its true map over the + same keys. Instead, under ``vmap``, the entire batch of output + random numbers is generated from only the first key in the input + key batch. For example, if ``keys`` is a vector of 8 keys, then + ``jax.vmap(jax.random.normal)(keys)`` equals + ``jax.random.normal(keys[0], shape=(8,))``. This peculiarity + reflects a workaround to XLA RBG's limited batching support. + Reasons to use an alternative to the default RNG include that: 1. It may be slow to compile for TPUs. @@ -164,7 +174,8 @@ flag ``--xla_tpu_spmd_rng_bit_generator_unsafe=1``. The XLA flag can be set using an the ``XLA_FLAGS`` environment -variable, e.g. as ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. +variable, e.g. as +``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. For more about ``jax_threefry_partitionable``, see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers @@ -181,7 +192,7 @@ efficiently shardable (w/ pjit) ✅ ✅ ✅ identical across shardings ✅ ✅ ✅ ✅ identical across CPU/GPU/TPU ✅ ✅ - identical across JAX/XLA versions ✅ ✅ + exact ``jax.vmap`` over keys ✅ ✅ ================================= ======== ========= === ========== ===== ============ (*): with ``jax_threefry_partitionable=1`` set