From acd49d2c624e85a6b5ede99bbaeb893aad736237 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 16 Jan 2024 08:19:15 -0800 Subject: [PATCH] jax.random: improve error for batched keys --- jax/_src/random.py | 102 ++++++++++++++------------- jax/experimental/array_api/skips.txt | 2 +- tests/random_test.py | 2 +- 3 files changed, 55 insertions(+), 51 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index bf16cd80b335..469efd84de51 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -69,12 +69,14 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]: +def _check_prng_key(name: str, key: KeyArrayLike, allow_batched: bool = False) -> tuple[KeyArray, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): - return key, False + wrapped_key = key + wrapped = False elif _arraylike(key): # Call random_wrap here to surface errors for invalid keys. wrapped_key = prng.random_wrap(key, impl=default_prng_impl()) + wrapped = True if config.legacy_prng_key.value == 'error': raise ValueError( 'Legacy uint32 key array passed as key to jax.random function. ' @@ -91,10 +93,15 @@ def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]: 'Raw arrays as random keys to jax.random functions are deprecated. ' 'Assuming valid threefry2x32 key for now.', FutureWarning) - return wrapped_key, True else: raise TypeError(f'unexpected PRNG key type {type(key)}') + if (not allow_batched) and wrapped_key.ndim: + raise ValueError(f"{name} accepts a single key, but was given a key array of " + f"shape {np.shape(key)} != (). 
Use jax.vmap for batching.") + + return wrapped_key, wrapped + def _return_prng_keys(was_wrapped, key): # TODO(frostig): remove once we always enable_custom_prng @@ -245,10 +252,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: A new PRNG key that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values. """ - key, wrapped = _check_prng_key(key) - if np.ndim(key): - raise TypeError("fold_in accepts a single key, but was given a key array of" - f"shape {np.shape(key)} != (). Use jax.vmap for batching.") + key, wrapped = _check_prng_key("fold_in", key) if np.ndim(data): raise TypeError("fold_in accepts a scalar, but was given an array of" f"shape {np.shape(data)} != (). Use jax.vmap for batching.") @@ -262,7 +266,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: # to always enable_custom_prng assert jnp.issubdtype(key.dtype, dtypes.prng_key) if key.ndim: - raise TypeError("split accepts a single key, but was given a key array of" + raise TypeError("split accepts a single key, but was given a key array of " f"shape {key.shape} != (). Use jax.vmap for batching.") shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) @@ -278,7 +282,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: Returns: An array-like object of `num` new PRNG keys. 
""" - typed_key, wrapped = _check_prng_key(key) + typed_key, wrapped = _check_prng_key("split", key) return _return_prng_keys(wrapped, _split(typed_key, num)) @@ -288,7 +292,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl: return keys_dtype._impl def key_impl(keys: KeyArrayLike) -> Hashable: - typed_keys, _ = _check_prng_key(keys) + typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return PRNGSpec(_key_impl(typed_keys)) @@ -298,7 +302,7 @@ def _key_data(keys: KeyArray) -> Array: def key_data(keys: KeyArrayLike) -> Array: """Recover the bits of key data underlying a PRNG key array.""" - keys, _ = _check_prng_key(keys) + keys, _ = _check_prng_key("key_data", keys, allow_batched=True) return _key_data(keys) @@ -350,7 +354,7 @@ def bits(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("bits", key) if dtype is None: dtype = dtypes.canonicalize_dtype(jnp.uint) else: @@ -383,7 +387,7 @@ def uniform(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): @@ -452,7 +456,7 @@ def randint(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("randint", key) dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) @@ -535,7 +539,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array: msg = ("jax.random.shuffle is deprecated and will be removed in a future release. 
" "Use jax.random.permutation with independent=True.") warnings.warn(msg, FutureWarning) - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("shuffle", key) return _shuffle(key, x, axis) # type: ignore @@ -556,7 +560,7 @@ def permutation(key: KeyArrayLike, Returns: A shuffled version of x or array range """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) if not np.ndim(x): @@ -630,7 +634,7 @@ def choice(key: KeyArrayLike, Returns: An array of shape `shape` containing samples from `a`. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("choice", key) if not isinstance(shape, Sequence): raise TypeError("shape argument of jax.random.choice must be a sequence, " f"got {shape}") @@ -697,7 +701,7 @@ def normal(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("normal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " @@ -764,7 +768,7 @@ def multivariate_normal(key: KeyArrayLike, ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("multivariate_normal", key) dtypes.check_user_dtype_supported(dtype) mean, cov = promote_dtypes_inexact(mean, cov) if method not in {'svd', 'eigh', 'cholesky'}: @@ -843,7 +847,7 @@ def truncated_normal(key: KeyArrayLike, ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. Returns values in the open interval ``(lower, upper)``. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("truncated_normal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " @@ -901,7 +905,7 @@ def bernoulli(key: KeyArrayLike, A random array with boolean dtype and shape given by ``shape`` if ``shape`` is not None, or else ``p.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("bernoulli", key) dtype = dtypes.canonicalize_dtype(lax.dtype(p)) if shape is not None: shape = core.as_named_shape(shape) @@ -952,7 +956,7 @@ def beta(key: KeyArrayLike, A random array with the specified dtype and shape given by ``shape`` if ``shape`` is not None, or else by broadcasting ``a`` and ``b``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("beta", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `beta` must be a float " @@ -1005,7 +1009,7 @@ def cauchy(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("cauchy", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `cauchy` must be a float " @@ -1057,7 +1061,7 @@ def dirichlet(key: KeyArrayLike, ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else ``alpha.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("dirichlet", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `dirichlet` must be a float " @@ -1116,7 +1120,7 @@ def exponential(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("exponential", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `exponential` must be a float " @@ -1297,7 +1301,7 @@ def gamma(key: KeyArrayLike, loggamma : sample gamma values in log-space, which can provide improved accuracy for small values of ``a``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("gamma", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " @@ -1339,7 +1343,7 @@ def loggamma(key: KeyArrayLike, See Also: gamma : standard gamma sampler. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("loggamma", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " @@ -1475,7 +1479,7 @@ def poisson(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape is not None, or else by ``lam.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("poisson", key) dtypes.check_user_dtype_supported(dtype) # TODO(frostig): generalize underlying poisson implementation and # remove this check @@ -1515,7 +1519,7 @@ def gumbel(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("gumbel", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gumbel` must be a float " @@ -1550,7 +1554,7 @@ def categorical(key: KeyArrayLike, A random array with int dtype and shape given by ``shape`` if ``shape`` is not None, or else ``np.delete(logits.shape, axis)``. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("categorical", key) check_arraylike("categorical", logits) logits_arr = jnp.asarray(logits) @@ -1593,7 +1597,7 @@ def laplace(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("laplace", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `laplace` must be a float " @@ -1630,7 +1634,7 @@ def logistic(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("logistic", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `logistic` must be a float " @@ -1673,7 +1677,7 @@ def pareto(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``b.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("pareto", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `pareto` must be a float " @@ -1722,7 +1726,7 @@ def t(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("t", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `t` must be a float " @@ -1775,7 +1779,7 @@ def chisquare(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("chisquare", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `chisquare` must be a float " @@ -1833,7 +1837,7 @@ def f(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("f", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `f` must be a float " @@ -1885,7 +1889,7 @@ def rademacher(key: KeyArrayLike, a 50% change of being 1 or -1. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("rademacher", key) dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) @@ -1921,7 +1925,7 @@ def maxwell(key: KeyArrayLike, """ # Generate samples using: # sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1) - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("maxwell", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `maxwell` must be a float " @@ -1964,7 +1968,7 @@ def double_sided_maxwell(key: KeyArrayLike, A jnp.array of samples. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("double_sided_maxwell", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float" @@ -2016,7 +2020,7 @@ def weibull_min(key: KeyArrayLike, A jnp.array of samples. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("weibull_min", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `weibull_min` must be a float " @@ -2055,7 +2059,7 @@ def orthogonal( Returns: A random array of shape `(*shape, n, n)` and specified dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("orthogonal", key) dtypes.check_user_dtype_supported(dtype) _check_shape("orthogonal", shape) n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") @@ -2090,7 +2094,7 @@ def generalized_normal( Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("generalized_normal", key) dtypes.check_user_dtype_supported(dtype) _check_shape("generalized_normal", shape) keys = split(key) @@ -2120,7 +2124,7 @@ def ball( Returns: A random array of shape `(*shape, d)` and specified dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("ball", key) dtypes.check_user_dtype_supported(dtype) _check_shape("ball", shape) d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()") @@ -2158,7 +2162,7 @@ def rayleigh(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``scale.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("rayleigh", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `rayleigh` must be a float " @@ -2212,7 +2216,7 @@ def wald(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``mean.shape``. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("wald", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `wald` must be a float " @@ -2268,7 +2272,7 @@ def geometric(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``p.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("geometric", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.integer): raise ValueError("dtype argument to `geometric` must be an int " @@ -2330,7 +2334,7 @@ def triangular(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("triangular", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `triangular` must be a float " @@ -2384,7 +2388,7 @@ def lognormal(key: KeyArrayLike, Returns: A random array with the specified dtype and with shape given by ``shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("lognormal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, " @@ -2597,7 +2601,7 @@ def binomial( A random array with the specified dtype and with shape given by ``np.broadcast(n, p).shape``. 
""" - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("binomial", key) check_arraylike("binomial", n, p) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 9a7557dd0369..c417e700f8d9 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -12,4 +12,4 @@ array_api_tests/test_array_object.py::test_setitem array_api_tests/test_creation_functions.py::test_asarray_arrays # fft test suite is buggy as of 83f0bcdc -array_api_tests/test_fft.py +# array_api_tests/test_fft.py diff --git a/tests/random_test.py b/tests/random_test.py index 3a7d95000579..3dd62b67c7a7 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -263,7 +263,7 @@ def random_bits(key, width, shape): # # Doing so doesn't work in width 64 at present due to # normalization in random.bits. - key, _ = jax_random._check_prng_key(key) + key, _ = jax_random._check_prng_key('random_bits', key) return jax_random._random_bits(key, width, shape) key = make_key(1701)