Add support for device kwarg in astype, and add matching utility func #21086

Status: Open. Wants to merge 1 commit into base: main.
CHANGELOG.md: 3 changes (3 additions, 0 deletions)

@@ -37,6 +37,9 @@ Remember to align the itemized text with the first line of an item within a list

 ## jax 0.4.28 (May 9, 2024)

+* New Functionality
+  * {func}`jax.numpy.astype` supports a new `device` keyword argument.

Reviewer comment (Collaborator): Move up to 0.4.29

* Bug fixes
   * Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).

jax/_src/dlpack.py: 19 changes (3 additions, 16 deletions)

@@ -24,6 +24,7 @@
from jax._src.lib import xla_client
from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding
+from jax._src.numpy.util import _place_array

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)
@@ -148,19 +149,6 @@ def to_dlpack(x: Array, stream: int | Any | None = None,
f"version ({max_version}) was requested."
)

-def _place_array(_arr, device, dlpack_device, copy):
-  if device and dlpack_device != device:
-    if copy is not None and not copy:
-      raise ValueError(
-          f"Specified {device=} which requires a copy since the source device "
-          f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
-          "copy=None to perform the requested operation."
-      )
-    else:
-      return device_put(_arr, device)
-  if copy:
-    return jnp.array(_arr, copy=True)
-  return _arr

def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,
                        copy: bool | None = None):

@@ -194,8 +182,7 @@ def _legacy_from_dlpack(dlpack, device: xla_client.Device | None = None,

  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, cpu_backend, gpu_backend))
-  dlpack_device, = _arr.devices()
-  return _place_array(_arr, device, dlpack_device, copy)
+  return _place_array(_arr, device, copy)

def _from_dlpack(external_array, device: xla_client.Device | None = None,
                 copy: bool | None = None):

@@ -226,7 +213,7 @@ def _from_dlpack(external_array, device: xla_client.Device | None = None,

  _arr = jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
      dlpack, dlpack_device, stream))
-  return _place_array(_arr, device, dlpack_device, copy)
+  return _place_array(_arr, device, copy)

def from_dlpack(external_array,
                device: xla_client.Device | Sharding | None = None,
jax/_src/numpy/lax_numpy.py: 14 changes (5 additions, 9 deletions)

@@ -2853,18 +2853,14 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
  # to issue our warning.
  with warnings.catch_warnings():
    warnings.simplefilter("ignore", ComplexWarning)
-    return _place_array(
+    return util._place_array(
        lax.convert_element_type(x_arr, dtype),
-        device=device, copy=copy,
+        device=device,
+        # We translate between array API semantics of copy in _place_array, and
+        # the NumPy semantics of copy in astype.
+        copy=True if copy else None,
    )

-def _place_array(x, device=None, copy=None):
-  # TODO(micky774): Implement in future PRs as we formalize device placement
-  # semantics
-  if copy:
-    return _array_copy(x)
-  return x


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
jax/_src/numpy/util.py: 25 changes (24 additions, 1 deletion)

@@ -18,14 +18,16 @@
import re
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar

import warnings

+from jax.sharding import Sharding

+from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax
+from jax._src.lib import xla_client as xc
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape

@@ -117,6 +119,27 @@ def _parse_extra_params(extra_params: str) -> dict[str, str]:
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


+def _place_array(x: Array, device: xc.Device | Sharding | None = None, copy=None) -> Array:
"""Helper utility for copying an array, or placing it on a device or sharding.

This utility uses `jax.device_put` for device placement.
"""
out = x
Reviewer comment (Collaborator): why out=x?

+  if device is not None:
+    # TODO(micky774): Add check to avoid error if no actual device transfer is
+    # necessary
+    if copy is not None and not copy:
+      raise ValueError(
+          f"Specified {device=} which requires a copy, however copy=False. Set "
+          "copy=True or copy=None to perform the requested operation."
+      )
+    out = api.device_put(out, device)

+  # TODO(micky774): Avoid copy if data has already been copied via device
+  # transfer
+  return lax._array_copy(out) if copy else out

Reviewer comment (Collaborator), on the TODO above: This todo doesn't make sense? In this branch the device is None. Remove the todo?


def implements(
original_fun: Callable[..., Any] | None,
update_doc: bool = True,
jax/experimental/array_api/__init__.py: 6 changes (6 additions, 0 deletions)

@@ -32,6 +32,12 @@
 and implements most of the API listed in the standard.

 .. _Python array API standard: https://data-apis.org/array-api/latest/
+
+
+Note that JAX may not always strictly adhere to array API device semantics when
+using ``jax.jit``. In particular, specifying the ``device`` argument is
+equivalent to calling ``jax.device_put(x, device)``. For up-to-date details on
+device placement, see the documentation of ``jax.device_put``.
Reviewer comment (Collaborator): Usually references go at the bottom of the docstring. You can put this paragraph above the ".. _Python..." line.
"""

from __future__ import annotations
tests/array_interoperability_test.py: 7 changes (4 additions, 3 deletions)

@@ -195,13 +195,14 @@ def testTensorFlowToJaxInt64(self):
      shape=all_shapes,
      dtype=numpy_dtypes,
      copy=[False, True],
+      device_transfer=[False, True],
  )
-  def testNumpyToJax(self, shape, dtype, copy):
+  def testNumpyToJax(self, shape, dtype, copy, device_transfer):
    rng = jtu.rand_default(self.rng())
    x_np = rng(shape, dtype)
-    device = jax.devices()[0]
+    device = jax.devices()[0] if device_transfer else None
    _from_dlpack = lambda: jnp.from_dlpack(x_np, device=device, copy=copy)
-    if jax.default_backend() == 'gpu' and not copy:
+    if device_transfer and not copy:
      self.assertRaisesRegex(
          ValueError,
          r"Specified .* which requires a copy",
tests/lax_numpy_test.py: 40 changes (32 additions, 8 deletions)

@@ -41,7 +41,7 @@
import jax.ops
from jax import lax
from jax import numpy as jnp
-from jax.sharding import SingleDeviceSharding
+from jax.sharding import SingleDeviceSharding, PartitionSpec as P
from jax.test_util import check_grads

from jax._src import array
@@ -3931,19 +3931,43 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'):
self._CompileAndCheck(jnp_op, args_maker)

  @jtu.sample_product(
-      change_dtype=[True, False],
+      [dict(dtype=dtype, new_dtype=new_dtype)
+       for dtype in all_dtypes
+       for new_dtype in (
+           complex_dtypes
+           if np.issubdtype(dtype, np.complexfloating)
+           else all_dtypes
+       )
+      ],
+      shape=array_shapes,
      copy=[True, False],
+      device_type=[None, "single", "shard"],
  )
-  def testAstypeCopy(self, change_dtype, copy):
-    dtype = 'float32' if change_dtype else 'int32'
-    expect_copy = change_dtype or copy
-    x = jnp.arange(5, dtype='int32')
-    y = x.astype(dtype, copy=copy)
+  @jtu.run_on_devices("gpu")
+  def testAstypePlacement(self, shape, dtype, new_dtype, copy, device_type):
+    rng = jtu.rand_default(self.rng())
+    x = jnp.asarray(rng(shape, dtype))
+
+    if device_type is None:
+      device = None
+      expected_sharding = x.sharding
+    elif device_type == "single":
+      device = jax.devices("cpu")[0]
+      expected_sharding = SingleDeviceSharding(device)
+    else:
+      global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))

Reviewer comment (Collaborator): Can we use a (2, 2) mesh to get coverage across more hardware?

+      device = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))

Reviewer comment (Collaborator): It would be nice to try other PartitionSpecs too (can you parameterize the test on that too?).

+      expected_sharding = device
+
+    expect_copy = (dtype != new_dtype) or copy or device
+
-    self.assertEqual(y.dtype, dtype)
+    y = x.astype(new_dtype, copy=copy, device=device)
+    self.assertEqual(y.dtype, new_dtype)
+    self.assertEqual(y.sharding, expected_sharding)
    y.delete()
    self.assertNotEqual(x.is_deleted(), expect_copy)


def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
msg = "Casting from complex to non-complex dtypes will soon raise "