Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 12, 2024
1 parent d0eae05 commit 35543cf
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 7 deletions.
21 changes: 19 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,28 @@ def __array__(self, dtype=None, context=None, copy=None):
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)

def __dlpack__(self, *, stream: int | Any | None = None):
def __dlpack__(self, *, stream: int | Any | None = None,
max_version: tuple[int, int] | None = None,
dl_device: tuple[enum.Enum, int] | None = None,
copy: bool | None = None):
if len(self._arrays) != 1:
raise ValueError("__dlpack__ only supported for unsharded arrays.")

from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)

device_set = self.sharding.device_set
if len(device_set) > 1:
raise BufferError(
"to_dlpack can only pack a dlpack tensor from an array on a singular "
f"device, but an array with a Sharding over {len(device_set)} devices "
"was provided."
)
device = device_set.pop()
return to_dlpack(self, stream=stream,
max_version=max_version,
device=device,
dl_device=dl_device, # type: ignore
copy=copy)

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
Expand Down
63 changes: 58 additions & 5 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
from jax._src.api import device_put

DLPACK_VERSION = (0, 1)
MIN_DLPACK_VERSION = (0, 1)

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
Expand All @@ -48,9 +51,33 @@ class DLDeviceType(enum.IntEnum):
kDLCUDA = 2
kDLROCM = 10

def _to_dlpack(x: Array, stream: int | Any | None,
device: xla_client.Device | None = None,
dlpack_device: xla_client.Device | None = None,
copy: bool | None = None):
arr = None
if dlpack_device and dlpack_device != device:
if copy is not None and not copy:
raise ValueError(
f"Specified {dlpack_device=} which requires a copy since the source device "
f"is {repr(device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
arr = device_put(x, dlpack_device)
if arr is None:
arr = x.copy() if copy else x

return xla_client._xla.buffer_to_dlpack_managed_tensor(
arr.addressable_data(0), stream=stream
) # type: ignore

def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | Any | None = None):
stream: int | Any | None = None,
device: xla_client.Device | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
max_version: tuple[int, int] | None = None,
copy : bool | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Args:
Expand All @@ -73,14 +100,40 @@ def to_dlpack(x: Array, take_ownership: bool = False,
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
if take_ownership:
warnings.warn(
"take_ownership in to_dlpack is deprecated and it is a no-op."
)
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), stream=stream
) # type: ignore

dlpack_device = None
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
if dl_device_type:
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
backend = xla_bridge.get_backend(dl_device_platform)
dlpack_device = backend.device_from_local_hardware_id(local_hardware_id)
except TypeError:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
# recommends using BufferError.
raise BufferError(
"The device specification passed to to_dlpack contains an unsupported "
f"device type (DLDeviceType: {dl_device_type})")

if max_version is None or max_version[0] >= DLPACK_VERSION[0]:
return _to_dlpack(x, stream=stream, device=device, dlpack_device=dlpack_device, copy=copy)
elif max_version >= MIN_DLPACK_VERSION:
# Legacy path to be implemented when XLA adopts DLManagedTensorVersioned format
raise RuntimeError("This branch should be unreachable. "
"Please open a bug if you see this.")
else:
raise BufferError(
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
f"version ({max_version}) was requested."
)


def from_dlpack(external_array):
Expand Down

0 comments on commit 35543cf

Please sign in to comment.