From b0eb8b9c466c1f642900020daa016cd82f473750 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 18 Mar 2024 15:37:40 +0000 Subject: [PATCH] Add __array_namespace_info__ and corresponding utilities --- jax/experimental/array_api/__init__.py | 1 + .../array_api/_utility_functions.py | 91 +++++++++++++++++++ tests/array_api_test.py | 86 ++++++++++++++++++ 3 files changed, 178 insertions(+) diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 3169f9667256..0e6b65e0fc75 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -204,6 +204,7 @@ ) from jax.experimental.array_api._utility_functions import ( + __array_namespace_info__ as __array_namespace_info__, all as all, any as any, ) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index 60d739277627..c12b831e079e 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -11,8 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import jax +from typing import List, Tuple +from jax._src.sharding import Sharding +from jax._src.xla_bridge import backends +from jax._src.lib import xla_client as xc +from jax._src import dtypes as _dtypes, config def all(x, /, *, axis=None, keepdims=False): @@ -23,3 +29,88 @@ def all(x, /, *, axis=None, keepdims=False): def any(x, /, *, axis=None, keepdims=False): """Tests whether any input array element evaluates to True along a specified axis.""" return jax.numpy.any(x, axis=axis, keepdims=keepdims) + +class __array_namespace_info__: + + def __init__(self): + self._default_dtypes = self._build_default_dtype_dict() + self._capabilities = { + "boolean indexing": True, + "data-dependent shapes": False, + } + self._data_types = self._build_dtype_dict() + + def _build_default_dtype_dict(self): + default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + for dtype_name, kind in default_dtypes.items(): + dtype = _dtypes._default_types.get(kind) + dtype = _dtypes.canonicalize_dtype(dtype) + default_dtypes[dtype_name] = dtype + return default_dtypes + + def _build_dtype_dict(self): + data_types = { + "signed integer": ["int8", "int16", "int32", "int64"], + "unsigned integer": ["uint8", "uint16", "uint32", "uint64"], + "real floating": ["float32", "float64"], + "complex floating": ["complex64", "complex128"], + } + if not config.enable_x64.value: + for category in data_types: + data_types[category] = data_types[category][:-1] + + data_types["bool"] = ["bool"] + + for category in data_types: + _dtype_dict = {} + for name in data_types[category]: + _dtype_dict[name] = _dtypes.dtype(name) + data_types[category] = _dtype_dict + data_types["integral"] = ( + data_types["signed integer"] | data_types["unsigned integer"] + ) + data_types["numeric"] = ( + data_types["integral"] + | data_types["real floating"] + | data_types["complex floating"] + ) + return data_types + + def default_device(self): + # Note that since arrays are create uncommitted they technically do not + # have a default device in the classical sense. Functions that accept a + # device parameter generally have a default value of None anyways, so + # to reconcile those two facts we simply return None as our default device + # so that callers e.g. use jax.device_put(x, jnp.default_device()) to + # achieve equivalent results to jax.device_put(x). See gh-20200 for + # details. + return None + + def devices(self): + available_devices: List[xc.Device] = [] + for _backend in backends(): + available_devices.extend(jax.devices(_backend)) + return available_devices + + def capabilities(self): + return self._capabilities + + def default_dtypes(self): + return self._default_dtypes + + def dtypes(self, *, device: xc.Device | Sharding | None = None, kind: str | Tuple[str, ...] | None = None): + # Array API supported dtypes are device-independent in JAX + if kind is None: + out_dict = self._data_types["numeric"] | self._data_types["bool"] + elif isinstance(kind, tuple): + out_dict = {} + for _kind in kind: + out_dict |= self._data_types[_kind] + else: + out_dict = self._data_types[kind] + return out_dict diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 0d4893e4939e..db90d622a26e 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -24,7 +24,11 @@ from absl.testing import absltest import jax from jax import config +import jax.numpy as jnp from jax.experimental import array_api +from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src import test_util as jtu +from absl.testing import parameterized config.parse_flags_with_absl() @@ -233,5 +237,87 @@ def test_array_namespace_method(self): self.assertIs(x.__array_namespace__(), array_api) +class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase): + + info = array_api.__array_namespace_info__() + + def setUp(self): + super().setUp() + self._boolean = self.build_target_dict(["bool"]) + self._signed = self.build_target_dict(["int8", "int16", "int32"]) + self._unsigned = self.build_target_dict(["uint8", "uint16", "uint32"]) + self._floating = self.build_target_dict(["float32"]) + self._complex = self.build_target_dict(["complex64"]) + if config.jax_enable_x64: + self._signed["int64"] = jnp.dtype("int64") + self._unsigned["uint64"] = jnp.dtype("uint64") + self._floating["float64"] = jnp.dtype("float64") + self._complex["complex128"] = jnp.dtype("complex128") + + def build_target_dict(self, dtypes): + out = {} + for name in dtypes: + out[name] = jnp.dtype(name) + return out + + def test_capabilities_info(self): + capabilities = self.info.capabilities() + assert capabilities["boolean indexing"] + assert not capabilities["data-dependent shapes"] + + def test_default_device(self): + assert self.info.default_device() is None + + def test_devices_info(self): + devices = self.info.devices() + x = array_api.arange(5) + # Sinfoty check that the outputs of __array_namespace_info__.devices() can + # be directly passed to Array API creation functions + for device in devices: + self.assertArraysEqual(x, array_api.arange(x.shape[0], device=device)) + + def test_default_dtypes_info(self): + _default_dtypes = { + "real floating": "f", + "complex floating": "c", + "integral": "i", + "indexing": "i", + } + for dtype_name, kind in _default_dtypes.items(): + dtype = _default_types.get(kind) + dtype = canonicalize_dtype(dtype) + _default_dtypes[dtype_name] = dtype + assert self.info.default_dtypes() == _default_dtypes + + @parameterized.parameters("bool", "signed integer", "real floating", + "complex floating", "integral", "numeric", None, + (("real floating", "complex floating"),), + (("integral", "signed integer"),), + (("integral", "bool"),)) + def test_dtypes_info(self, kind): + + info_dict = self.info.dtypes(kind=kind) + control = { + "bool":self._boolean, + "signed integer":self._signed, + "unsigned integer":self._unsigned, + "real floating":self._floating, + "complex floating":self._complex, + } + control["integral"] = self._signed | self._unsigned + control["numeric"] = ( + self._signed | self._unsigned | self._floating | self._complex + ) + target_dict = {} + if kind is None: + target_dict = control["numeric"] | self._boolean + elif isinstance(kind, tuple): + target_dict = {} + for _kind in kind: + target_dict |= control[_kind] + else: + target_dict = control[kind] + assert info_dict == target_dict + if __name__ == '__main__': absltest.main()