From d77cd9a0f42dc666cf55a1854b88c3975061dd0d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 30 Nov 2023 15:50:22 -0800 Subject: [PATCH] Add jax.numpy.astype function --- docs/jax.numpy.rst | 1 + jax/_src/numpy/array_methods.py | 5 +---- jax/_src/numpy/lax_numpy.py | 14 ++++++++++++++ jax/_src/numpy/util.py | 2 ++ jax/numpy/__init__.py | 1 + jax/numpy/__init__.pyi | 1 + tests/lax_numpy_test.py | 27 +++++++++++++++------------ 7 files changed, 35 insertions(+), 16 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index e75204766a91..d4cfd63c0f0a 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -78,6 +78,7 @@ namespace; they are listed below. array_split array_str asarray + astype atleast_1d atleast_2d atleast_3d diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 067854ea4f94..98ceacf04bf3 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -61,10 +61,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - if dtype is None: - dtype = dtypes.canonicalize_dtype(lax_numpy.float_) - dtypes.check_user_dtype_supported(dtype, "astype") - return lax.convert_element_type(arr, dtype) + return lax_numpy.astype(arr, dtype) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e388934ae3d7..a5dc5c882fb7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -2179,6 +2179,20 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x +@util._wraps(getattr(np, "astype", None), lax_description=""" +This is implemented via :func:`jax.lax.convert_element_type`, which may +have slightly different behavior than :func:`numpy.astype` in some cases. +In particular, the details of float-to-int and int-to-float casts are +implementation dependent. +""") +def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: + del copy # unused in JAX + if dtype is None: + dtype = dtypes.canonicalize_dtype(float_) + dtypes.check_user_dtype_supported(dtype, "astype") + return lax.convert_element_type(x, dtype) + + @util._wraps(np.asarray, lax_description=_ARRAY_DOC) def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None) -> Array: dtypes.check_user_dtype_supported(dtype, "asarray") diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 8c78f7702509..616648120f0d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -156,6 +156,8 @@ def wrap(op): op.__np_wrapped__ = fun # Allows this pattern: @wraps(getattr(np, 'new_function', None)) if fun is None: + if lax_description: + op.__doc__ = lax_description return op docstr = getattr(fun, "__doc__", None) name = getattr(fun, "__name__", getattr(op, "__name__", str(op))) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index f408117db770..8e33c8d0efc4 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -39,6 +39,7 @@ array_repr as array_repr, array_split as array_split, array_str as array_str, + astype as astype, asarray as asarray, atleast_1d as atleast_1d, atleast_2d as atleast_2d, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 876d764480bf..2d460a6ab8b3 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -103,6 +103,7 @@ array_str = _np.array_str def asarray( a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ... ) -> Array: ... +def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: bool = ...) -> Array: ... @overload def atleast_1d() -> list[Array]: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9ce9a1cf62a2..75e8521becff 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3531,19 +3531,22 @@ def np_fun(index, shape): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testAstype(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - np_op = lambda x: np.asarray(x).astype(jnp.int32) - jnp_op = lambda x: jnp.asarray(x).astype(jnp.int32) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testAstypeNone(self): + @jtu.sample_product( + from_dtype=['int32', 'float32'], + to_dtype=['int32', 'float32', None], + use_method=[True, False], + ) + def testAstype(self, from_dtype, to_dtype, use_method): rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("int32")] - np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None)) - jnp_op = lambda x: jnp.asarray(x).astype(None) + args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] + if (not use_method) and hasattr(np, "astype"): # Added in numpy 2.0 + np_op = lambda x: np.astype(x, to_dtype) + else: + np_op = lambda x: np.asarray(x).astype(to_dtype) + if use_method: + jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) + else: + jnp_op = lambda x: jnp.astype(x, to_dtype) self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker)