diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index c0b7a28e79ae..ed936be4e98c 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -85,6 +85,7 @@ expand_dims as expand_dims, expm1 as expm1, eye as eye, + finfo as finfo, flip as flip, float32 as float32, float64 as float64, @@ -193,7 +194,6 @@ from jax.experimental.array_api._data_type_functions import ( astype as astype, - finfo as finfo, ) from jax.experimental.array_api._elementwise_functions import ( diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 248c1c6dd0fe..3ff95befc6fe 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -14,34 +14,15 @@ from __future__ import annotations -import builtins -from typing import NamedTuple -import numpy as np - import jax.numpy as jnp from jax._src.lib import xla_client as xc from jax._src.sharding import Sharding from jax._src import dtypes as _dtypes -# TODO(micky774): Update jax.numpy dtypes to dtype *objects* -bool = np.dtype('bool') -int8 = np.dtype('int8') -int16 = np.dtype('int16') -int32 = np.dtype('int32') -int64 = np.dtype('int64') -uint8 = np.dtype('uint8') -uint16 = np.dtype('uint16') -uint32 = np.dtype('uint32') -uint64 = np.dtype('uint64') -float32 = np.dtype('float32') -float64 = np.dtype('float64') -complex64 = np.dtype('complex64') -complex128 = np.dtype('complex128') - # TODO(micky774): Remove when jax.numpy.astype is deprecation is completed -def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): +def astype(x, dtype, /, *, copy: bool = True, device: xc.Device | Sharding | None = None): src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) if ( src_dtype is not None @@ -54,25 +35,3 @@ def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Shard "your input." ) return jnp.astype(x, dtype, copy=copy, device=device) - - -class FInfo(NamedTuple): - bits: int - eps: float - max: float - min: float - smallest_normal: float - dtype: jnp.dtype - -# TODO(micky774): Update jax.numpy.finfo so that its attributes are python -# floats -def finfo(type, /) -> FInfo: - info = jnp.finfo(type) - return FInfo( - bits=info.bits, - eps=float(info.eps), - max=float(info.max), - min=float(info.min), - smallest_normal=float(info.smallest_normal), - dtype=jnp.dtype(type) - ) diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index c865fabcfb55..f7d80d94f96f 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -1,5 +1,8 @@ # Known failures for the array api tests. +# finfo return type misalignment (https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32] + # Test suite attempts in-place mutation: array_api_tests/test_special_cases.py::test_iop array_api_tests/test_special_cases.py::test_nan_propagation