Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array API] add device property & to_device method #22597

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Remember to align the itemized text with the first line of an item within a list
will be removed in a future release.
* Updated the repr of GPU devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Added the `device` property and `to_device` method to {class}`jax.Array`, as
part of JAX's [Array API](https://data-apis.org/array-api) support.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def size(self):
def sharding(self):
return self._sharding

@property
def device(self):
  """Array API-style ``device`` attribute.

  Calls ``self._check_if_deleted()`` first. If ``self.sharding`` is a
  ``SingleDeviceSharding``, returns the first element of its ``device_set``;
  otherwise returns ``self.sharding`` itself.
  """
  self._check_if_deleted()
  if isinstance(self.sharding, SingleDeviceSharding):
    # NOTE(review): presumably a single-device sharding's device_set holds
    # exactly one device, so the first element is "the" device -- confirm.
    return list(self.sharding.device_set)[0]
  return self.sharding

@property
def weak_type(self):
return self.aval.weak_type
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/basearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Sequence

# TODO(jakevdp): fix import cycles and define these.
Device = Any
Shard = Any
Sharding = Any

Expand Down Expand Up @@ -112,6 +113,15 @@ def is_fully_replicated(self) -> bool:
def sharding(self) -> Sharding:
"""The sharding for the array."""

@property
@abc.abstractmethod
def device(self) -> Device | Sharding:
  """The ``device`` attribute, compatible with the Array API.

  Returns:
    A Device when the array is a single-device array, and a Sharding when
    the array is sharded.
  """


Array.__module__ = "jax"

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ class Array(abc.ABC):
@property
def sharding(self) -> Sharding: ...
@property
def device(self) -> Device | Sharding: ...
@property
def addressable_shards(self) -> Sequence[Shard]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
Expand All @@ -216,6 +218,7 @@ class Array(abc.ABC):
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...
# `stream` is given a default so that calls omitting it type-check.
def to_device(self, device: Device | Sharding, *, stream: int | Any | None = ...) -> Array: ...


StaticScalar = Union[
Expand Down
9 changes: 9 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,15 @@ def sharding(self):
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")

@property
def device(self):
  """Always raises: ``device`` is not available on this object.

  Raises:
    AttributeError: unconditionally, with a message built from
      ``self._error_repr()`` and ``self._origin_msg()``.
  """
  # This attribute is part of the jax.Array API, but only defined on concrete arrays.
  # Raising a ConcretizationTypeError would make sense, but for backward compatibility
  # we raise an AttributeError so that hasattr() and getattr() work as expected.
  # Only the message is passed: AttributeError stores every positional argument
  # in `args`, so passing `self` as well would make str(exc) render as a tuple
  # instead of the message.
  raise AttributeError(
      f"The 'device' attribute is not available on {self._error_repr()}."
      f"{self._origin_msg()}")

@property
def addressable_shards(self):
raise ConcretizationTypeError(self,
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/earray.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def sharding(self):
phys_sharding = self._data.sharding
return sharding_impls.logical_sharding(self.aval, phys_sharding)

@property
def device(self):
  """Array API-style ``device`` attribute.

  If the sharding of the wrapped ``self._data`` is a ``SingleDeviceSharding``,
  returns ``self._data.device``; otherwise returns ``self.sharding``.
  """
  if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
    return self._data.device
  return self.sharding

# TODO(mattjj): not implemented below here, need more methods from ArrayImpl

def addressable_data(self, index: int) -> EArray:
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import jax
from jax import lax
from jax.sharding import Sharding
from jax._src import api
from jax._src import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
Expand Down Expand Up @@ -67,6 +68,12 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev
"""
return lax_numpy.astype(arr, dtype, copy=copy, device=device)

def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *,
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
return api.device_put(arr, device)


def _nbytes(arr: ArrayLike) -> int:
"""Total bytes consumed by the elements of the array."""
Expand Down Expand Up @@ -694,6 +701,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
"sum": reductions.sum,
"swapaxes": lax_numpy.swapaxes,
"take": lax_numpy.take,
"to_device": _to_device,
"trace": lax_numpy.trace,
"transpose": _transpose,
"var": reductions.var,
Expand Down
9 changes: 0 additions & 9 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,6 @@ def _array_namespace(self, /, *, api_version: None | str = None):
return jax.experimental.array_api


def _to_device(self, device: xe.Device | Sharding | None, *,
stream: int | Any | None = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
return jax.device_put(self, device)


def add_array_object_methods():
# TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
setattr(ArrayImpl, "to_device", _to_device)
setattr(ArrayImpl, "device", property(lambda self: self.sharding))
24 changes: 24 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,30 @@ def test_gspmd_sharding_hash_eq(self):
self.assertEqual(x1, x2)
self.assertEqual(hash(x1), hash(x2))

def test_device_attr(self):
  """`.device` is a device for single-device arrays, a sharding otherwise."""
  # Single-device case: the attribute equals the first entry of x.devices().
  unsharded = jnp.ones((2, 10))
  self.assertEqual(unsharded.device, list(unsharded.devices())[0])

  # Sharded case: the attribute equals the sharding the array was put on.
  named = jax.sharding.NamedSharding(
      jtu.create_global_mesh((2,), ('x',)), P('x'))
  sharded = jax.device_put(unsharded, named)
  self.assertEqual(sharded.device, named)

def test_to_device(self):
  """`to_device` accepts either a device or a sharding as its target."""
  target_device = jax.devices()[-1]
  target_sharding = jax.sharding.NamedSharding(
      jtu.create_global_mesh((2,), ('x',)), P('x'))

  source = jnp.ones((2, 10))
  on_device = source.to_device(target_device)
  on_sharding = source.to_device(target_sharding)

  # Each result's `.device` should equal the target it was moved to.
  self.assertEqual(on_device.device, target_device)
  self.assertEqual(on_sharding.device, target_sharding)


class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected
Expand Down