Skip to content

Commit

Permalink
Add support for max_version, dl_device, copy kwargs in __dlpack__
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 26, 2024
1 parent d7e5dde commit 4a85f4a
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 27 deletions.
22 changes: 18 additions & 4 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)

def __dlpack__(self, *, stream: int | Any | None = None):
if len(self._arrays) != 1:
raise BufferError("__dlpack__ only supported for unsharded arrays.")
def __dlpack__(self, *, stream: int | Any | None = None,
max_version: tuple[int, int] | None = None,
dl_device: tuple[enum.Enum, int] | None = None,
copy: bool | None = None):
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
return to_dlpack(self, stream=stream)

device_set = self.sharding.device_set
if len(device_set) > 1:
raise BufferError(
"to_dlpack can only pack a dlpack tensor from an array on a singular "
f"device, but an array with a Sharding over {len(device_set)} devices "
"was provided."
)
device, = device_set
return to_dlpack(self, stream=stream,
max_version=max_version,
src_device=device,
dl_device=dl_device, # type: ignore
copy=copy)

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
Expand Down
125 changes: 112 additions & 13 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lax.lax import _array_copy
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.typing import Array
from jax._src.api import device_put

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
Expand All @@ -48,9 +52,34 @@ class DLDeviceType(enum.IntEnum):
kDLCUDA = 2
kDLROCM = 10

def _to_dlpack(x: Array, stream: int | Any | None,
src_device: xla_client.Device | None = None,
device: xla_client.Device | None = None,
copy: bool | None = None):

if src_device is None:
src_device, = x.devices()
if device and (src_device is None or device != src_device):
if copy is not None and not copy:
raise ValueError(
f"Specified {device=} which requires a copy since the source device "
f"is {repr(src_device)}, however copy=False. Set copy=True or "
"copy=None to perform the requested operation."
)
else:
arr = device_put(x, device)
else:
arr = _array_copy(x) if copy else x
return xla_client._xla.buffer_to_dlpack_managed_tensor(
arr.addressable_data(0), stream=stream
)

def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | Any | None = None):
stream: int | Any | None = None,
src_device: xla_client.Device | None = None,
dl_device: tuple[DLDeviceType, int] | None = None,
max_version: tuple[int, int] | None = None,
copy : bool | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Args:
Expand All @@ -60,27 +89,97 @@ def to_dlpack(x: Array, take_ownership: bool = False,
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to ``__dlpack__``
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
src_device: either a CPU or GPU :class:`~jax.Device`.
dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
format e.g. as produced by ``__dlpack_device__``.
max_version: the maximum DLPack version that the consumer (i.e. caller of
``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
This function is not guaranteed to return a capsule of version
``max_version``.
copy: a boolean indicating whether or not to copy the input. If
``copy=True`` then the function must always copy. When
``copy=False`` then the function must never copy, and must raise an error
when a copy is deemed necessary. If ``copy=None`` then the function must
avoid a copy if possible but may copy if needed.
Returns:
A dlpack PyCapsule object.
A DLPack PyCapsule object.
Note:
While JAX arrays are always immutable, dlpack buffers cannot be marked as
immutable, and it is possible for processes external to JAX to mutate them
in-place. If a dlpack buffer derived from a JAX array is mutated, it may
lead to undefined behavior when using the associated JAX array.
While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
cannot be marked as immutable, and it is possible for processes external
to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
is mutated, it may lead to undefined behavior when using the associated JAX
array. When JAX eventually supports ``DLManagedTensorVersioned``
(DLPack 1.0), it will be possible to specify that a buffer is read-only.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
if take_ownership:
warnings.warn(
"take_ownership in to_dlpack is deprecated and it is a no-op."
)
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), stream=stream
) # type: ignore

device = None
dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
if dl_device_type:
try:
dl_device_platform = {
DLDeviceType.kDLCPU: "cpu",
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
backend = xla_bridge.get_backend(dl_device_platform)
device = backend.device_from_local_hardware_id(local_hardware_id)
except TypeError:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
# recommends using BufferError.
raise BufferError(
"The device specification passed to to_dlpack contains an unsupported "
f"device type (DLDeviceType: {dl_device_type})")

# TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
# supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
# current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0)
if max_version is None:
# Backwards compatible default
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
else:
if max_version >= DLPACK_VERSION:
# Latest
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
if max_version[0] == DLPACK_VERSION[0]:
# ABI compatible
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
elif max_version >= MIN_DLPACK_VERSION:
# Oldest supported
return _to_dlpack(
x, stream=stream,
src_device=src_device,
device=device,
copy=copy
)
else:
raise BufferError(
f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
f"version ({max_version}) was requested."
)


def from_dlpack(external_array):
Expand Down Expand Up @@ -110,12 +209,12 @@ def from_dlpack(external_array):
DLDeviceType.kDLCUDA: "cuda",
DLDeviceType.kDLROCM: "rocm",
}[dl_device_type]
except TypeError:
except TypeError as err:
# https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
# TypeError.
raise TypeError(
raise BufferError(
"Array passed to from_dlpack is on unsupported device type "
f"(DLDeviceType: {dl_device_type}, array: {external_array}")
f"(DLDeviceType: {dl_device_type}, array: {external_array}") from err

backend = xla_bridge.get_backend(device_platform)
device = backend.device_from_local_hardware_id(device_id)
Expand Down
45 changes: 35 additions & 10 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,49 @@ def setUp(self):
shape=all_shapes,
dtype=dlpack_dtypes,
gpu=[False, True],
copy=[False, True, None]
)
def testJaxRoundTrip(self, shape, dtype, gpu):
def testJaxRoundTrip(self, shape, dtype, gpu, copy):
if xb.using_pjrt_c_api():
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
if gpu and not jtu.test_device_matches(["gpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU/TPU")

def _check_copy(x, y, expect_copy):
is_copy = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
assert is_copy == expect_copy

# Check if the source device is preserved
x = jax.device_put(np, jax.devices("cpu")[0])
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)
y = jax.device_put(x, device)
dlpack = jax.dlpack.to_dlpack(y, copy=copy)
z = jax.dlpack.from_dlpack(dlpack)

self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)
self.assertRaisesRegex(RuntimeError,
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))
"DLPack tensor may be consumed at most once",
lambda: jax.dlpack.from_dlpack(dlpack))

if shape in nonempty_array_shapes:
_check_copy(y, z, bool(copy))

# Check if the destination device can be specified
dl_device = y.__dlpack_device__()
make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
if gpu and copy == False:
self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
return

z = jax.dlpack.from_dlpack(make_dlpack())
self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)

if shape in nonempty_array_shapes:
_check_copy(x, z, bool(copy) or gpu)

@jtu.sample_product(
shape=all_shapes,
Expand Down

0 comments on commit 4a85f4a

Please sign in to comment.