From a17c8d945b20aceb6e4f1a32dc27bee8d79ec900 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 27 Jul 2024 11:21:02 -0700 Subject: [PATCH] Finalize deprecation of jax.random.shuffle This has been raising a DeprecationWarning for longer than anyone can remember. PiperOrigin-RevId: 656765001 --- CHANGELOG.md | 2 ++ jax/_src/random.py | 18 ------------------ jax/random.py | 7 +++---- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52d3ab5b853f..a94d1e90e9a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list or `enable_xla=False` is now deprecated and this support will be removed in a future version. Native serialization has been the default since JAX 0.4.16 (September 2023). + * The previously-deprecated function `jax.random.shuffle` has been removed; + instead use `jax.random.permutation` with `independent=True`. ## jaxlib 0.4.31 diff --git a/jax/_src/random.py b/jax/_src/random.py index ac7d475dd382..113bcc450100 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -516,24 +516,6 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: return lax.add(minval, lax.convert_element_type(random_offset, dtype)) -def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array: - """Shuffle the elements of an array uniformly at random along an axis. - - Args: - key: a PRNG key used as the random key. - x: the array to be shuffled. - axis: optional, an int axis along which to shuffle (default 0). - - Returns: - A shuffled version of x. - """ - msg = ("jax.random.shuffle is deprecated and will be removed in a future release. " - "Use jax.random.permutation with independent=True.") - warnings.warn(msg, FutureWarning) - key, _ = _check_prng_key("shuffle", key) - return _shuffle(key, x, axis) - - def permutation(key: KeyArrayLike, x: int | ArrayLike, axis: int = 0, diff --git a/jax/random.py b/jax/random.py index 5ced19dbbede..5c2eaf81f2bc 100644 --- a/jax/random.py +++ b/jax/random.py @@ -242,7 +242,6 @@ randint as randint, random_gamma_p as random_gamma_p, rayleigh as rayleigh, - shuffle as _deprecated_shuffle, split as split, t as t, triangular as triangular, @@ -254,16 +253,16 @@ ) _deprecations = { - # Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66 + # Finalized Jul 26 2024; remove after Nov 2024. "shuffle": ( "jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.", - _deprecated_shuffle, + None, ) } import typing if typing.TYPE_CHECKING: - shuffle = _deprecated_shuffle + pass else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations)