diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py
index 83dc893a9515..72503fe18a2c 100644
--- a/jax/_src/dlpack.py
+++ b/jax/_src/dlpack.py
@@ -18,13 +18,14 @@ from typing import Any
 import warnings
 
+from jax._src.api import device_put
 from jax import numpy as jnp
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src.typing import Array
-
+from jax._src.sharding import Sharding
 
 # A set of dtypes that dlpack supports.
 # Note: Make sure to use a "type", not a dtype instance, when looking up this set
@@ -82,16 +83,111 @@ def to_dlpack(x: Array, take_ownership: bool = False,
       x.addressable_data(0), stream=stream
   )  # type: ignore
 
+def _place_array(_arr, device, dlpack_device, copy):
+  if device and dlpack_device != device:
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      return device_put(_arr, device)
+  if copy:
+    return jnp.array(_arr, copy=True)
+  return _arr
+
+def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
+                        copy: bool | None = None):
+  preferred_platform = getattr(device, "platform", None)
+  if device and preferred_platform == "gpu":
+    preferred_platform = "cuda" if "cuda" in device.client.platform_version else "rocm"
+
+  cpu_backend = xla_bridge.get_backend("cpu")
+  gpu_backend = None
+
+  if preferred_platform in {"cuda", "rocm"}:
+    try:
+      gpu_backend = xla_bridge.get_backend(preferred_platform)
+    except RuntimeError:
+      raise TypeError(
+        f"A {str.upper(preferred_platform)} device was specified, however no "
+        f"{str.upper(preferred_platform)} backend was found."
+      )
 
-def from_dlpack(external_array):
+  if preferred_platform is None:
+    try:
+      gpu_backend = xla_bridge.get_backend("cuda")
+    except RuntimeError:
+      pass
+    # Try ROCm if CUDA backend not found
+    if gpu_backend is None:
+      try:
+        gpu_backend = xla_bridge.get_backend("rocm")
+      except RuntimeError:
+        pass
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))  # type: ignore
+  dlpack_device, = _arr.devices()
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def _from_dlpack(external_array, device: xla_client.Device | None = None,
+                 copy: bool | None = None):
+  dl_device_type, device_id = external_array.__dlpack_device__()
+  try:
+    dl_device_platform = {
+        DLDeviceType.kDLCPU: "cpu",
+        DLDeviceType.kDLCUDA: "cuda",
+        DLDeviceType.kDLROCM: "rocm",
+    }[dl_device_type]
+  except TypeError:
+    # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
+    # TypeError.
+    raise TypeError(
+        "Array passed to from_dlpack is on unsupported device type "
+        f"(DLDeviceType: {dl_device_type}, array: {external_array}")
+
+  backend = xla_bridge.get_backend(dl_device_platform)
+  dlpack_device = backend.device_from_local_hardware_id(device_id)
+  try:
+    stream = dlpack_device.get_stream_for_external_ready_events()
+  except xla_client.XlaRuntimeError as err:  # type: ignore
+    if "UNIMPLEMENTED" in str(err):
+      stream = None
+    else:
+      raise
+  dlpack = external_array.__dlpack__(stream=stream)
+
+  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, dlpack_device, stream))
+  return _place_array(_arr, device, dlpack_device, copy)
+
+def from_dlpack(external_array,
+                device: xla_client.Device | Sharding | None = None,
+                copy: bool | None = None):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.
 
-  The returned :class:`~jax.Array` shares memory with ``external_array``.
+  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
+  device transfer or copy is requested.
 
   Args:
-    external_array: an array object that has __dlpack__ and __dlpack_device__
+    external_array: An array object that has __dlpack__ and __dlpack_device__
       methods, or a DLPack tensor on either CPU or GPU (legacy API).
+    device: The (optional) :py:class:`Device`, representing the device on which
+      the returned array should be placed. If given, then the result is committed
+      to that device. If unspecified, the resulting array will be unpacked onto the
+      same device it originated from. Setting ``device`` to a device different from
+      the source of ``external_array`` will require a copy, meaning ``copy`` must be
+      set to either ``True`` or ``None``.
+
+    copy: An (optional) boolean, controlling whether a copy is performed.
+      If ``copy=True`` then a copy is always performed, even if unpacked onto the
+      same device. If ``copy=False`` then a copy is never performed, and an error
+      is raised if one would be required. When ``copy=None`` then a copy may be
+      performed if needed for a device transfer.
+
   Returns:
     A jax.Array
 
@@ -102,49 +198,16 @@ def from_dlpack(external_array):
   is later modified in-place, it may lead to undefined behavior when using
   the associated JAX array.
   """
+  if isinstance(device, Sharding):
+    device_set = device.device_set
+    if len(device_set) > 1:
+      raise ValueError(
+        "from_dlpack can only unpack a dlpack tensor onto a single device, but "
+        f"a Sharding with {len(device_set)} devices was provided."
+      )
+    device, = device_set
   if hasattr(external_array, "__dlpack__"):
-    dl_device_type, device_id = external_array.__dlpack_device__()
-    try:
-      device_platform = {
-          DLDeviceType.kDLCPU: "cpu",
-          DLDeviceType.kDLCUDA: "cuda",
-          DLDeviceType.kDLROCM: "rocm",
-      }[dl_device_type]
-    except TypeError:
-      # https://dmlc.github.io/dlpack/latest/python_spec.html recommends using
-      # TypeError.
-      raise TypeError(
-          "Array passed to from_dlpack is on unsupported device type "
-          f"(DLDeviceType: {dl_device_type}, array: {external_array}")
-
-    backend = xla_bridge.get_backend(device_platform)
-    device = backend.device_from_local_hardware_id(device_id)
-    try:
-      stream = device.get_stream_for_external_ready_events()
-    except xla_client.XlaRuntimeError as err:  # type: ignore
-      if "UNIMPLEMENTED" in str(err):
-        stream = None
-      else:
-        raise
-    dlpack = external_array.__dlpack__(stream=stream)
-
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, device, stream))
-  else:
-    # Legacy path
-    dlpack = external_array
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("cuda")
-    except RuntimeError:
-      gpu_backend = None
-
-    # Try ROCm if CUDA backend not found
-    if gpu_backend is None:
-      try:
-        gpu_backend = xla_bridge.get_backend("rocm")
-      except RuntimeError:
-        gpu_backend = None
+    return _from_dlpack(external_array, device, copy)
 
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
+  # Legacy path
+  return _legacy_from_dlpack(external_array, device, copy)
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 826251c82945..f8974c398e8c 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -2442,9 +2442,10 @@ def fromiter(*args, **kwargs):
   is later modified in-place, it may lead to undefined behavior when using
   the associated JAX array.
   """)
-def from_dlpack(x: Any) -> Array:
+def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
+                copy: bool | None = None) -> Array:
   from jax.dlpack import from_dlpack  # pylint: disable=g-import-not-at-top
-  return from_dlpack(x)
+  return from_dlpack(x, device=device, copy=copy)
 
 @util.implements(np.fromfunction)
 def fromfunction(function: Callable[..., Array], shape: Any,
diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py
index 0fcde42d58bb..2fd9be97ba27 100644
--- a/jax/experimental/array_api/_creation_functions.py
+++ b/jax/experimental/array_api/_creation_functions.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import jax
 import jax.numpy as jnp
-
+from jax._src.lib import xla_client as xc
+from jax._src.sharding import Sharding
 
 def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
   return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
@@ -31,8 +34,8 @@ def empty_like(x, /, *, dtype=None, device=None):
 def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
   return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
 
-def from_dlpack(x, /):
-  return jnp.from_dlpack(x)
+def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
+  return jnp.from_dlpack(x, device=device, copy=copy)
 
 def full(shape, fill_value, *, dtype=None, device=None):
   return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 9ed5f39b393e..bfd652eb3be0 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -353,7 +353,8 @@ def fmax(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmin(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def fmod(x: ArrayLike, y: ArrayLike, /) -> Array: ...
 def frexp(x: ArrayLike, /) -> tuple[Array, Array]: ...
-def from_dlpack(x: Any) -> Array: ...
+def from_dlpack(x: Any, /, *, device: _Device | None = None,
+                copy: builtins.bool | None = None) -> Array: ...
 def frombuffer(buffer: Union[bytes, Any], dtype: DTypeLike = ...,
                count: int = ..., offset: int = ...) -> Array: ...
 def fromfile(*args, **kwargs): ...
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index 006dea8acd4d..3e6eb0ec1ed4 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -174,12 +174,23 @@ def testTensorFlowToJaxInt64(self):
   @jtu.sample_product(
     shape=all_shapes,
     dtype=numpy_dtypes,
+    gpu=[False, True] if jtu.test_device_matches(["gpu"]) else [False],
+    copy=[False, True],
   )
-  def testNumpyToJax(self, shape, dtype):
+  def testNumpyToJax(self, shape, dtype, gpu, copy):
     rng = jtu.rand_default(self.rng())
     x_np = rng(shape, dtype)
-    x_jax = jnp.from_dlpack(x_np)
-    self.assertAllClose(x_np, x_jax)
+    platform = "gpu" if gpu else "cpu"
+    device = jax.devices(platform)[0]
+    _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
+    if gpu and not copy:
+      self.assertRaisesRegex(
+        ValueError,
+        r"Specified .* which requires a copy",
+        _from_dlpack
+      )
+    else:
+      self.assertAllClose(x_np, _from_dlpack())
 
   @jtu.sample_product(
     shape=all_shapes,
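
A usage sketch of the new `device` and `copy` keywords (illustrative only, not part of the patch): it follows the semantics described in the new docstring, and assumes the patch is applied and that the installed NumPy is recent enough for `ndarray` to provide `__dlpack__`/`__dlpack_device__`. The GPU branch only runs when a GPU backend is available.

# Sketch: exercising jnp.from_dlpack(device=..., copy=...) on a NumPy array.
import numpy as np

import jax
import jax.numpy as jnp

x_np = np.arange(8, dtype=np.float32)
cpu = jax.devices("cpu")[0]

# Same-device unpack: copy=False requests a zero-copy view of the NumPy buffer.
x_view = jnp.from_dlpack(x_np, device=cpu, copy=False)

# copy=True always copies, even when no device transfer is needed.
x_copy = jnp.from_dlpack(x_np, copy=True)

# Cross-device unpack: a transfer implies a copy, so copy must be True or None
# (the default). With copy=False this raises the
# "Specified ... which requires a copy" ValueError exercised in the test above.
try:
  gpu = jax.devices("gpu")[0]
except RuntimeError:  # no GPU backend available
  gpu = None
if gpu is not None:
  x_gpu = jnp.from_dlpack(x_np, device=gpu)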