From f8e5f0b5452676f2479c4157256b2a0b0c31c345 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas <jakevdp@google.com>
Date: Tue, 23 Jul 2024 09:48:51 -0700
Subject: [PATCH] [array API] add device property & to_device method

---
 CHANGELOG.md                                 |  2 ++
 jax/_src/array.py                            |  7 +++++++
 jax/_src/basearray.py                        | 10 ++++++++++
 jax/_src/basearray.pyi                       |  3 +++
 jax/_src/core.py                             |  9 +++++++++
 jax/_src/earray.py                           |  6 ++++++
 jax/_src/numpy/array_methods.py              |  8 ++++++++
 jax/experimental/array_api/_array_methods.py |  9 ---------
 tests/array_test.py                          | 24 ++++++++++++++++++++
 9 files changed, 69 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 15c5633c62b1..c0c33282ebdd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a
     list will be removed in a future release.
   * Updated the repr of gpu devices to be more consistent with TPUs/CPUs. For
     example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
+  * Added the `device` property and `to_device` method to {class}`jax.Array`, as
+    part of JAX's [Array API](https://data-apis.org/array-api) support.
 * Deprecations
   * Removed a number of previously-deprecated internal APIs related to
     polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 796d313ae8cb..c75e1cd667a5 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -254,6 +254,13 @@ def size(self):
   def sharding(self):
     return self._sharding

+  @property
+  def device(self):
+    self._check_if_deleted()
+    if isinstance(self.sharding, SingleDeviceSharding):
+      return list(self.sharding.device_set)[0]
+    return self.sharding
+
   @property
   def weak_type(self):
     return self.aval.weak_type
diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py
index 5809b9649f26..1cf8fa70d33c 100644
--- a/jax/_src/basearray.py
+++ b/jax/_src/basearray.py
@@ -22,6 +22,7 @@
 from collections.abc import Sequence

 # TODO(jakevdp): fix import cycles and define these.
+Device = Any
 Shard = Any
 Sharding = Any

@@ -112,6 +113,15 @@ def is_fully_replicated(self) -> bool:
   def sharding(self) -> Sharding:
     """The sharding for the array."""

+  @property
+  @abc.abstractmethod
+  def device(self) -> Device | Sharding:
+    """Array API-compatible device attribute.
+
+    For single-device arrays, this returns a Device. For sharded arrays, this
+    returns a Sharding.
+    """
+


 Array.__module__ = "jax"
diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi
index fbdd4894843e..8cc29f8c33c9 100644
--- a/jax/_src/basearray.pyi
+++ b/jax/_src/basearray.pyi
@@ -204,6 +204,8 @@ class Array(abc.ABC):
   @property
   def sharding(self) -> Sharding: ...
   @property
+  def device(self) -> Device | Sharding: ...
+  @property
   def addressable_shards(self) -> Sequence[Shard]: ...
   @property
   def global_shards(self) -> Sequence[Shard]: ...
@@ -216,6 +218,7 @@ class Array(abc.ABC):
   @property
   def traceback(self) -> Traceback: ...
   def unsafe_buffer_pointer(self) -> int: ...
+  def to_device(self, device: Device | Sharding, *, stream: int | Any | None = None) -> Array: ...


 StaticScalar = Union[
diff --git a/jax/_src/core.py b/jax/_src/core.py
index fea0ac6915e5..5ed1e1871cb2 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -738,6 +738,15 @@ def sharding(self):
         f"The 'sharding' attribute is not available on {self._error_repr()}."
         f"{self._origin_msg()}")

+  @property
+  def device(self):
+    # This attribute is part of the jax.Array API, but only defined on concrete arrays.
+    # Raising a ConcretizationTypeError would make sense, but for backward compatibility
+    # we raise an AttributeError so that hasattr() and getattr() work as expected.
+    raise AttributeError(self,
+        f"The 'device' attribute is not available on {self._error_repr()}."
+        f"{self._origin_msg()}")
+
   @property
   def addressable_shards(self):
     raise ConcretizationTypeError(self,
diff --git a/jax/_src/earray.py b/jax/_src/earray.py
index f4b5e232bc33..36c8dc80c8ca 100644
--- a/jax/_src/earray.py
+++ b/jax/_src/earray.py
@@ -83,6 +83,12 @@ def sharding(self):
     phys_sharding = self._data.sharding
     return sharding_impls.logical_sharding(self.aval, phys_sharding)

+  @property
+  def device(self):
+    if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
+      return self._data.device
+    return self.sharding
+
   # TODO(mattjj): not implemented below here, need more methods from ArrayImpl

   def addressable_data(self, index: int) -> EArray:
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index 1d27c4b3aa28..515f245d11d3 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -32,6 +32,7 @@
 import jax
 from jax import lax
 from jax.sharding import Sharding
+from jax._src import api
 from jax._src import core
 from jax._src import dtypes
 from jax._src.api_util import _ensure_index_tuple
@@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
   """
   return lax_numpy.astype(arr, dtype, copy=copy, device=device)

+def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
+               stream: int | Any | None = None):
+  if stream is not None:
+    raise NotImplementedError("stream argument of array.to_device()")
+  return api.device_put(arr, device)
+

 def _nbytes(arr: ArrayLike) -> int:
   """Total bytes consumed by the elements of the array."""
@@ -694,6 +701,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
     "sum": reductions.sum,
     "swapaxes": lax_numpy.swapaxes,
     "take": lax_numpy.take,
+    "to_device": _to_device,
     "trace": lax_numpy.trace,
     "transpose": _transpose,
     "var": reductions.var,
diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py
index 2b071db573a8..4a1ed496311e 100644
--- a/jax/experimental/array_api/_array_methods.py
+++ b/jax/experimental/array_api/_array_methods.py
@@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
   return jax.experimental.array_api


-def _to_device(self, device: xe.Device | Sharding | None, *,
-               stream: int | Any | None = None):
-  if stream is not None:
-    raise NotImplementedError("stream argument of array.to_device()")
-  return jax.device_put(self, device)
-
-
 def add_array_object_methods():
   # TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace) - setattr(ArrayImpl, "to_device", _to_device) - setattr(ArrayImpl, "device", property(lambda self: self.sharding)) diff --git a/tests/array_test.py b/tests/array_test.py index a8c119cfa82e..4ddba606b6e2 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1276,6 +1276,30 @@ def test_gspmd_sharding_hash_eq(self): self.assertEqual(x1, x2) self.assertEqual(hash(x1), hash(x2)) + def test_device_attr(self): + # For single-device arrays, x.device returns the device + x = jnp.ones((2, 10)) + self.assertEqual(x.device, list(x.devices())[0]) + + # For sharded arrays, x.device returns the sharding + mesh = jtu.create_global_mesh((2,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, P('x')) + x = jax.device_put(x, sharding) + self.assertEqual(x.device, sharding) + + def test_to_device(self): + device = jax.devices()[-1] + mesh = jtu.create_global_mesh((2,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, P('x')) + + x = jnp.ones((2, 10)) + + x_device = x.to_device(device) + x_sharding = x.to_device(sharding) + + self.assertEqual(x_device.device, device) + self.assertEqual(x_sharding.device, sharding) + class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are automatically sharded as expected