diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 3169f9667256..0e6b65e0fc75 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -204,6 +204,7 @@ ) from jax.experimental.array_api._utility_functions import ( + __array_namespace_info__ as __array_namespace_info__, all as all, any as any, ) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index 60d739277627..1b4d87f3c852 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax +from __future__ import annotations +import jax +from typing import Tuple +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc +from jax._src import dtypes as _dtypes, config def all(x, /, *, axis=None, keepdims=False): """Tests whether all input array elements evaluate to True along a specified axis.""" @@ -23,3 +28,66 @@ def all(x, /, *, axis=None, keepdims=False): def any(x, /, *, axis=None, keepdims=False): """Tests whether any input array element evaluates to True along a specified axis.""" return jax.numpy.any(x, axis=axis, keepdims=keepdims) + +class __array_namespace_info__: + + def __init__(self): + self._capabilities = { + "boolean indexing": True, + "data-dependent shapes": False, + } + + + def _build_dtype_dict(self): + array_api_types = { + "bool", "int8", "int16", + "int32", "uint8", "uint16", + "uint32", "float32", "complex64" + } + if config.enable_x64.value: + array_api_types |= {"int64", "uint64", "float64", "complex128"} + return {category: {t.name: t for t in types if t.name in array_api_types} + for category, types in _dtypes._dtype_kinds.items()} + + def default_device(self): + # By default JAX arrays are uncommitted (device=None), meaning that + # JAX is free to choose the most efficient device placement. + return None + + def devices(self): + return jax.devices() + + def capabilities(self): + return self._capabilities + + def default_dtypes(self, *, device: xc.Device | Sharding | None = None): + # Array API supported dtypes are device-independent in JAX + del device + default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + return { + dtype_name: _dtypes.canonicalize_dtype( + _dtypes._default_types.get(kind) + ) for dtype_name, kind in default_dtypes.items() + } + + def dtypes( + self, *, + device: xc.Device | Sharding | None = None, + kind: str | Tuple[str, ...] | None = None): + # Array API supported dtypes are device-independent in JAX + del device + data_types = self._build_dtype_dict() + if kind is None: + out_dict = data_types["numeric"] | data_types["bool"] + elif isinstance(kind, tuple): + out_dict = {} + for _kind in kind: + out_dict |= data_types[_kind] + else: + out_dict = data_types[kind] + return out_dict diff --git a/tests/BUILD b/tests/BUILD index 98e48f04e28b..177956d2d347 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -54,6 +54,7 @@ py_test( deps = [ "//jax", "//jax:experimental_array_api", + "//jax:test_util", ] + py_deps("absl/testing"), ) diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 0d4893e4939e..4c3000652e28 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -21,9 +21,11 @@ from types import ModuleType -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax -from jax import config +import jax.numpy as jnp +from jax._src import config, test_util as jtu +from jax._src.dtypes import _default_types, canonicalize_dtype from jax.experimental import array_api config.parse_flags_with_absl() @@ -232,6 +234,86 @@ def test_array_namespace_method(self): self.assertIsInstance(x, jax.Array) self.assertIs(x.__array_namespace__(), array_api) +class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase): + + info = array_api.__array_namespace_info__() + + def setUp(self): + super().setUp() + self._boolean = self.build_dtype_dict(["bool"]) + self._signed = self.build_dtype_dict(["int8", "int16", "int32"]) + self._unsigned = self.build_dtype_dict(["uint8", "uint16", "uint32"]) + self._floating = self.build_dtype_dict(["float32"]) + self._complex = self.build_dtype_dict(["complex64"]) + if config.enable_x64.value: + self._signed["int64"] = jnp.dtype("int64") + self._unsigned["uint64"] = jnp.dtype("uint64") + self._floating["float64"] = jnp.dtype("float64") + self._complex["complex128"] = jnp.dtype("complex128") + self._integral = self._signed | self._unsigned + self._numeric = ( + self._signed | self._unsigned | self._floating | self._complex + ) + def build_dtype_dict(self, dtypes): + out = {} + for name in dtypes: + out[name] = jnp.dtype(name) + return out + + def test_capabilities_info(self): + capabilities = self.info.capabilities() + assert capabilities["boolean indexing"] + assert not capabilities["data-dependent shapes"] + + def test_default_device_info(self): + assert self.info.default_device() is None + + def test_devices_info(self): + assert self.info.devices() == jax.devices() + + def test_default_dtypes_info(self): + _default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + target_dict = { + dtype_name: canonicalize_dtype( + _default_types.get(kind) + ) for dtype_name, kind in _default_dtypes.items() + } + assert self.info.default_dtypes() == target_dict + + @parameterized.parameters( + "bool", "signed integer", "real floating", + "complex floating", "integral", "numeric", None, + (("real floating", "complex floating"),), + (("integral", "signed integer"),), + (("integral", "bool"),), + ) + def test_dtypes_info(self, kind): + + info_dict = self.info.dtypes(kind=kind) + control = { + "bool":self._boolean, + "signed integer":self._signed, + "unsigned integer":self._unsigned, + "real floating":self._floating, + "complex floating":self._complex, + "integral": self._integral, + "numeric": self._numeric + } + target_dict = {} + if kind is None: + target_dict = control["numeric"] | self._boolean + elif isinstance(kind, tuple): + target_dict = {} + for _kind in kind: + target_dict |= control[_kind] + else: + target_dict = control[kind] + assert info_dict == target_dict if __name__ == '__main__': absltest.main()