From a87718237ac726fd3c53614750f2719a97bb83e9 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 12 Mar 2024 18:28:02 +0000 Subject: [PATCH] Update --- .../array_api/_data_type_functions.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index d2bb032b85ab..43a6da8815f0 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations +import builtins import functools -from typing import NamedTuple +from typing import Any, NamedTuple import jax import jax.numpy as jnp +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src.api import device_put from jax.experimental.array_api._dtypes import ( @@ -124,9 +129,24 @@ def _promote_types(t1, t2): raise ValueError("No promotion path for {t1} & {t2}") -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) - +def astype(x, dtype, /, *, copy: builtins.bool | None = True, device: xc.Device | Sharding | None = None): + arr = jnp.array(x, dtype=dtype) + src_device = arr.devices().pop() + # TODO(micky774): refactor into a common utility with _place_array in gh-20175 + if device is not None: + if copy is not None and not copy: + raise ValueError( + f"Specified {device=} which requires a copy since the source device " + f"is {repr(src_device)}, however copy=False. Set copy=True or " + "copy=None to perform the requested operation." + ) + else: + return device_put(arr, device) + if copy: + # TODO(micky774): Remove if clause and replace with jnp.array(arr, copy=copy) + # when we support Numpy 2.0 copy semantics + return jnp.array(arr, copy=True) + return arr def can_cast(from_, to, /): if isinstance(from_, jax.Array):