Skip to content

Commit

Permalink
[array API] add device property & to_device method
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 23, 2024
1 parent 0c09e79 commit f8e5f0b
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
will be removed in a future release.
* Updated the repr of GPU devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
part of JAX's [Array API](https://data-apis.org/array-api) support.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def size(self):
def sharding(self):
return self._sharding

@property
def device(self):
  """Array API-compatible device attribute.

  Returns the single Device backing this array when it lives on one
  device; otherwise returns the array's Sharding.
  """
  self._check_if_deleted()
  sharding = self.sharding
  if isinstance(sharding, SingleDeviceSharding):
    # A SingleDeviceSharding has exactly one device in its device set.
    return next(iter(sharding.device_set))
  return sharding

@property
def weak_type(self):
  """Whether this array is weakly typed; mirrors ``self.aval.weak_type``."""
  return self.aval.weak_type
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/basearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Sequence

# TODO(jakevdp): fix import cycles and define these.
Device = Any
Shard = Any
Sharding = Any

Expand Down Expand Up @@ -112,6 +113,15 @@ def is_fully_replicated(self) -> bool:
def sharding(self) -> Sharding:
"""The sharding for the array."""

@property
@abc.abstractmethod
def device(self) -> Device | Sharding:
  """Array API-compatible device attribute.

  For single-device arrays, this returns a Device. For sharded arrays, this
  returns a Sharding.
  """


Array.__module__ = "jax"

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class Array(abc.ABC):
@property
def sharding(self) -> Sharding: ...
@property
def device(self) -> Device | Sharding: ...
@property
def addressable_shards(self) -> Sequence[Shard]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
Expand All @@ -216,6 +218,7 @@ class Array(abc.ABC):
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...
# NOTE(review): default `= None` added so the stub matches the runtime
# implementation, where `stream` is an optional keyword argument.
def to_device(self, device: Device | Sharding, *, stream: int | Any | None = None) -> Array: ...


StaticScalar = Union[
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,15 @@ def sharding(self):
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")

@property
def device(self):
  """Raise AttributeError: `device` is only defined on concrete arrays."""
  # This attribute is part of the jax.Array API, but only defined on concrete arrays.
  # Raising a ConcretizationTypeError would make sense, but for backward compatibility
  # we raise an AttributeError so that hasattr() and getattr() work as expected.
  # Pass only the message: AttributeError(self, msg) would set e.args to a
  # 2-tuple, making str(e) print the tracer repr instead of the message.
  raise AttributeError(
      f"The 'device' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

@property
def addressable_shards(self):
raise ConcretizationTypeError(self,
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/earray.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def sharding(self):
phys_sharding = self._data.sharding
return sharding_impls.logical_sharding(self.aval, phys_sharding)

@property
def device(self):
  """Array API device: the backing Device when unsharded, else the sharding."""
  data_sharding = self._data.sharding
  if not isinstance(data_sharding, sharding_impls.SingleDeviceSharding):
    # Sharded case: expose the logical sharding (see the `sharding` property).
    return self.sharding
  return self._data.device

# TODO(mattjj): not implemented below here, need more methods from ArrayImpl

def addressable_data(self, index: int) -> EArray:
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jax
from jax import lax
from jax.sharding import Sharding
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
Expand Down Expand Up @@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
"""
return lax_numpy.astype(arr, dtype, copy=copy, device=device)

def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
               stream: int | Any | None = None):
  """Array API ``to_device``: transfer ``arr`` to the given device or sharding.

  The ``stream`` argument from the Array API standard is not supported.
  """
  if stream is None:
    return api.device_put(arr, device)
  raise NotImplementedError("stream argument of array.to_device()")


def _nbytes(arr: ArrayLike) -> int:
"""Total bytes consumed by the elements of the array."""
Expand Down Expand Up @@ -694,6 +701,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"sum": reductions.sum,
"swapaxes": lax_numpy.swapaxes,
"take": lax_numpy.take,
"to_device": _to_device,
"trace": lax_numpy.trace,
"transpose": _transpose,
"var": reductions.var,
Expand Down
9 changes: 0 additions & 9 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
return jax.experimental.array_api


def _to_device(self, device: xe.Device | Sharding | None, *,
               stream: int | Any | None = None):
  """Array API ``to_device`` method: put ``self`` on the given device/sharding.

  The ``stream`` argument from the Array API standard is not supported.
  """
  if stream is None:
    return jax.device_put(self, device)
  raise NotImplementedError("stream argument of array.to_device()")


def add_array_object_methods():
  """Install Array API methods/attributes on concrete jax.Array objects."""
  # TODO(jakevdp): set on tracers as well?
  # Direct attribute assignment instead of setattr() with a constant string
  # (ruff/flake8-bugbear B010); behavior is identical.
  ArrayImpl.__array_namespace__ = _array_namespace
  ArrayImpl.to_device = _to_device
  ArrayImpl.device = property(lambda self: self.sharding)
24 changes: 24 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,30 @@ def test_gspmd_sharding_hash_eq(self):
self.assertEqual(x1, x2)
self.assertEqual(hash(x1), hash(x2))

def test_device_attr(self):
  """`.device` yields the Device for single-device arrays, else the Sharding."""
  arr = jnp.ones((2, 10))
  # Single-device case: `.device` is the one Device backing the array.
  self.assertEqual(arr.device, next(iter(arr.devices())))

  # Sharded case: `.device` is the array's Sharding.
  mesh = jtu.create_global_mesh((2,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  arr = jax.device_put(arr, sharding)
  self.assertEqual(arr.device, sharding)

def test_to_device(self):
  """`to_device` accepts both a concrete Device and a Sharding."""
  target_device = jax.devices()[-1]
  mesh = jtu.create_global_mesh((2,), ('x',))
  target_sharding = jax.sharding.NamedSharding(mesh, P('x'))

  arr = jnp.ones((2, 10))

  moved = arr.to_device(target_device)
  resharded = arr.to_device(target_sharding)

  self.assertEqual(moved.device, target_device)
  self.assertEqual(resharded.device, target_sharding)


class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected
Expand Down

0 comments on commit f8e5f0b

Please sign in to comment.