Skip to content

Commit

Permalink
Finalize deprecation of jax.random.shuffle
Browse files Browse the repository at this point in the history
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
  • Loading branch information
Jake VanderPlas authored and jax authors committed Jul 27, 2024
1 parent dab15d6 commit a17c8d9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 22 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.

## jaxlib 0.4.31

Expand Down
18 changes: 0 additions & 18 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,24 +516,6 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
return lax.add(minval, lax.convert_element_type(random_offset, dtype))


def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
"""Shuffle the elements of an array uniformly at random along an axis.
Args:
key: a PRNG key used as the random key.
x: the array to be shuffled.
axis: optional, an int axis along which to shuffle (default 0).
Returns:
A shuffled version of x.
"""
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
"Use jax.random.permutation with independent=True.")
warnings.warn(msg, FutureWarning)
key, _ = _check_prng_key("shuffle", key)
return _shuffle(key, x, axis)


def permutation(key: KeyArrayLike,
x: int | ArrayLike,
axis: int = 0,
Expand Down
7 changes: 3 additions & 4 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@
randint as randint,
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
shuffle as _deprecated_shuffle,
split as split,
t as t,
triangular as triangular,
Expand All @@ -254,16 +253,16 @@
)

_deprecations = {
# Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66
# Finalized Jul 26 2024; remove after Nov 2024.
"shuffle": (
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
_deprecated_shuffle,
None,
)
}

import typing
if typing.TYPE_CHECKING:
shuffle = _deprecated_shuffle
pass
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
Expand Down

0 comments on commit a17c8d9

Please sign in to comment.