Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array API] clean up some superseded definitions #22663

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
expand_dims as expand_dims,
expm1 as expm1,
eye as eye,
finfo as finfo,
flip as flip,
float32 as float32,
float64 as float64,
Expand Down Expand Up @@ -193,7 +194,6 @@

from jax.experimental.array_api._data_type_functions import (
astype as astype,
finfo as finfo,
)

from jax.experimental.array_api._elementwise_functions import (
Expand Down
43 changes: 1 addition & 42 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,15 @@

from __future__ import annotations

import builtins
from typing import NamedTuple
import numpy as np

import jax.numpy as jnp

from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src import dtypes as _dtypes

# TODO(micky774): Update jax.numpy dtypes to dtype *objects*
bool = np.dtype('bool')
int8 = np.dtype('int8')
int16 = np.dtype('int16')
int32 = np.dtype('int32')
int64 = np.dtype('int64')
uint8 = np.dtype('uint8')
uint16 = np.dtype('uint16')
uint32 = np.dtype('uint32')
uint64 = np.dtype('uint64')
float32 = np.dtype('float32')
float64 = np.dtype('float64')
complex64 = np.dtype('complex64')
complex128 = np.dtype('complex128')


# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
def astype(x, dtype, /, *, copy: bool = True, device: xc.Device | Sharding | None = None):
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
if (
src_dtype is not None
Expand All @@ -54,25 +35,3 @@ def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Shard
"your input."
)
return jnp.astype(x, dtype, copy=copy, device=device)


class FInfo(NamedTuple):
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: jnp.dtype

# TODO(micky774): Update jax.numpy.finfo so that its attributes are python
# floats
def finfo(type, /) -> FInfo:
info = jnp.finfo(type)
return FInfo(
bits=info.bits,
eps=float(info.eps),
max=float(info.max),
min=float(info.min),
smallest_normal=float(info.smallest_normal),
dtype=jnp.dtype(type)
)
3 changes: 3 additions & 0 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Known failures for the array api tests.

# finfo return type misalignment (https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_iop
array_api_tests/test_special_cases.py::test_nan_propagation
Expand Down