From 391a3083d48220da5f25a3a0b426863ab65f57f5 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 8 Apr 2024 13:46:31 -0700 Subject: [PATCH] Finalize the deprecation of the arr.device() method The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context. PiperOrigin-RevId: 622940937 --- CHANGELOG.md | 2 ++ jax/_src/array.py | 17 ----------------- jax/experimental/array_api/_array_methods.py | 8 ++++---- tests/random_test.py | 4 ---- 4 files changed, 6 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b287de3d0f4..b8acd30c627b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list * {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and `a_max` are deprecated in favor of `x` (positonal only), `min`, and `max` ({jax-issue}`20550`). + * The `device()` method of JAX arrays has been removed, after being deprecated + since JAX v0.4.21. Use `arr.devices()` instead. ## jaxlib 0.4.27 diff --git a/jax/_src/array.py b/jax/_src/array.py index 21f115a8043f..1c24142981ca 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -30,7 +30,6 @@ from jax._src import basearray from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors @@ -50,7 +49,6 @@ from jax._src.typing import ArrayLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method -deprecations.register(__name__, "device-method") Shape = tuple[int, ...] Device = xc.Device @@ -471,21 +469,6 @@ def on_device_size_in_bytes(self): per_shard_size = arr.on_device_size_in_bytes() # type: ignore return per_shard_size * len(self.sharding.device_set) - # TODO(yashkatariya): Remove this method when everyone is using devices(). - def device(self) -> Device: - if deprecations.is_accelerated(__name__, "device-method"): - raise NotImplementedError("arr.device() is deprecated. Use arr.devices() instead.") - else: - warnings.warn("arr.device() is deprecated. Use arr.devices() instead.", - DeprecationWarning, stacklevel=2) - self._check_if_deleted() - device_set = self.sharding.device_set - if len(device_set) == 1: - single_device, = device_set - return single_device - raise ValueError('Length of devices is greater than 1. ' - 'Please use `.devices()`.') - def devices(self) -> set[Device]: self._check_if_deleted() return self.sharding.device_set diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py index 5a73b8a2fe1a..ca5bca356258 100644 --- a/jax/experimental/array_api/_array_methods.py +++ b/jax/experimental/array_api/_array_methods.py @@ -19,6 +19,7 @@ import jax from jax._src.array import ArrayImpl from jax.experimental.array_api._version import __array_api_version__ +from jax.sharding import Sharding from jax._src.lib import xla_extension as xe @@ -30,16 +31,15 @@ def _array_namespace(self, /, *, api_version: None | str = None): return jax.experimental.array_api -def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *, +def _to_device(self, device: xe.Device | Sharding | None, *, stream: int | Any | None = None): if stream is not None: raise NotImplementedError("stream argument of array.to_device()") - # The type of device is defined by Array.device. In JAX, this is a callable that - # returns a device, so we must handle this case to satisfy the API spec. - return jax.device_put(self, device() if callable(device) else device) + return jax.device_put(self, device) def add_array_object_methods(): # TODO(jakevdp): set on tracers as well? setattr(ArrayImpl, "__array_namespace__", _array_namespace) setattr(ArrayImpl, "to_device", _to_device) + setattr(ArrayImpl, "device", property(lambda self: self.sharding)) diff --git a/tests/random_test.py b/tests/random_test.py index 446834644a59..02c731c11340 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -33,7 +33,6 @@ from jax import random from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax import vmap @@ -1019,9 +1018,6 @@ def test_array_impl_attributes(self): self.assertEqual(key.is_fully_addressable, key._base_array.is_fully_addressable) self.assertEqual(key.is_fully_replicated, key._base_array.is_fully_replicated) - if not deprecations.is_accelerated('jax._src.array', 'device-method'): - with jtu.ignore_warning(category=DeprecationWarning, message="arr.device"): - self.assertEqual(key.device(), key._base_array.device()) self.assertEqual(key.devices(), key._base_array.devices()) self.assertEqual(key.on_device_size_in_bytes(), key._base_array.on_device_size_in_bytes())