Skip to content

Commit

Permalink
Refactor array_api namespace, relying more directly on jax.numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed May 2, 2024
1 parent 0335487 commit b88e2e8
Show file tree
Hide file tree
Showing 23 changed files with 189 additions and 1,179 deletions.
27 changes: 17 additions & 10 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
}


def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bool:
def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
Args:
Expand All @@ -458,18 +458,25 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo
True or False
"""
the_dtype = np.dtype(dtype)
kind_tuple: tuple[DType | str, ...] = kind if isinstance(kind, tuple) else (kind,)
kind_tuple: tuple[str | DTypeLike, ...] = (
kind if isinstance(kind, tuple) else (kind,)
)
options: set[DType] = set()
for kind in kind_tuple:
if isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
if isinstance(kind, str) and kind in _dtype_kinds:
options.update(_dtype_kinds[kind])
elif isinstance(kind, np.dtype):
options.add(kind)
else:
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
continue
try:
_dtype = np.dtype(kind)
except TypeError as e:
if isinstance(kind, str):
raise ValueError(
f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}, "
"or a compatible input for jnp.dtype()")
raise TypeError(
f"Expected kind to be a dtype, string, or tuple; got {kind=}"
) from e
options.add(_dtype)
return the_dtype in options


Expand Down
200 changes: 90 additions & 110 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2022.12'
'2023.12'
>>> arr = xp.arange(1000)
Expand All @@ -38,68 +38,19 @@

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import (
fft as fft,
linalg as linalg,
)

from jax.experimental.array_api._constants import (
e as e,
inf as inf,
nan as nan,
newaxis as newaxis,
pi as pi,
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
empty as empty,
empty_like as empty_like,
eye as eye,
from_dlpack as from_dlpack,
full as full,
full_like as full_like,
linspace as linspace,
meshgrid as meshgrid,
ones as ones,
ones_like as ones_like,
tril as tril,
triu as triu,
zeros as zeros,
zeros_like as zeros_like,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
can_cast as can_cast,
finfo as finfo,
iinfo as iinfo,
isdtype as isdtype,
result_type as result_type,
)

from jax.experimental.array_api._dtypes import (
bool as bool,
int8 as int8,
int16 as int16,
int32 as int32,
int64 as int64,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
float32 as float32,
float64 as float64,
complex64 as complex64,
complex128 as complex128,
)
from jax.experimental.array_api import fft as fft
from jax.experimental.array_api import linalg as linalg

from jax.experimental.array_api._elementwise_functions import (
from jax.numpy import (
abs as abs,
acos as acos,
acosh as acosh,
add as add,
all as all,
any as any,
argmax as argmax,
argmin as argmin,
argsort as argsort,
asin as asin,
asinh as asinh,
atan as atan,
Expand All @@ -111,22 +62,43 @@
bitwise_or as bitwise_or,
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
clip as clip,
bool as bool,
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
can_cast as can_cast,
complex128 as complex128,
complex64 as complex64,
concat as concat,
conj as conj,
copysign as copysign,
cos as cos,
cosh as cosh,
cumulative_sum as cumulative_sum,
divide as divide,
e as e,
empty as empty,
empty_like as empty_like,
equal as equal,
exp as exp,
expand_dims as expand_dims,
expm1 as expm1,
floor as floor,
flip as flip,
float32 as float32,
float64 as float64,
floor_divide as floor_divide,
from_dlpack as from_dlpack,
full as full,
full_like as full_like,
greater as greater,
greater_equal as greater_equal,
hypot as hypot,
iinfo as iinfo,
imag as imag,
inf as inf,
int16 as int16,
int32 as int32,
int64 as int64,
int8 as int8,
isdtype as isdtype,
isfinite as isfinite,
isinf as isinf,
isnan as isnan,
Expand All @@ -141,91 +113,99 @@
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
matmul as matmul,
matrix_transpose as matrix_transpose,
max as max,
maximum as maximum,
mean as mean,
meshgrid as meshgrid,
min as min,
minimum as minimum,
moveaxis as moveaxis,
multiply as multiply,
nan as nan,
negative as negative,
newaxis as newaxis,
nonzero as nonzero,
not_equal as not_equal,
ones as ones,
ones_like as ones_like,
permute_dims as permute_dims,
pi as pi,
positive as positive,
pow as pow,
prod as prod,
real as real,
remainder as remainder,
repeat as repeat,
result_type as result_type,
roll as roll,
round as round,
searchsorted as searchsorted,
sign as sign,
signbit as signbit,
sin as sin,
sinh as sinh,
sort as sort,
sqrt as sqrt,
square as square,
squeeze as squeeze,
stack as stack,
subtract as subtract,
sum as sum,
take as take,
tan as tan,
tanh as tanh,
trunc as trunc,
)

from jax.experimental.array_api._indexing_functions import (
take as take,
tensordot as tensordot,
tile as tile,
tril as tril,
triu as triu,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
uint8 as uint8,
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
unstack as unstack,
vecdot as vecdot,
where as where,
zeros as zeros,
zeros_like as zeros_like,
)

from jax.experimental.array_api._manipulation_functions import (
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
concat as concat,
expand_dims as expand_dims,
flip as flip,
moveaxis as moveaxis,
permute_dims as permute_dims,
repeat as repeat,
reshape as reshape,
roll as roll,
squeeze as squeeze,
stack as stack,
tile as tile,
unstack as unstack,
)

from jax.experimental.array_api._searching_functions import (
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
searchsorted as searchsorted,
where as where,
from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
eye as eye,
linspace as linspace,
)

from jax.experimental.array_api._set_functions import (
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
from jax.experimental.array_api._data_type_functions import (
astype as astype,
finfo as finfo,
)

from jax.experimental.array_api._sorting_functions import (
argsort as argsort,
sort as sort,
from jax.experimental.array_api._elementwise_functions import (
ceil as ceil,
clip as clip,
floor as floor,
hypot as hypot,
trunc as trunc,
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
prod as prod,
std as std,
sum as sum,
var as var
var as var,
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)

from jax.experimental.array_api._linear_algebra_functions import (
matmul as matmul,
matrix_transpose as matrix_transpose,
tensordot as tensordot,
vecdot as vecdot,
)

from jax.experimental.array_api import _array_methods
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Any, Callable
from typing import Any

import jax
from jax._src.array import ArrayImpl
Expand Down
21 changes: 0 additions & 21 deletions jax/experimental/array_api/_constants.py

This file was deleted.

Loading

0 comments on commit b88e2e8

Please sign in to comment.