diff --git a/tests/BUILD b/tests/BUILD index 8e60580d40bc..55eaf75fee84 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -48,6 +48,15 @@ jax_test( srcs = ["api_util_test.py"], ) +py_test( + name = "array_api_test", + srcs = ["array_api_test.py"], + deps = [ + "//jax", + "//jax:experimental_array_api", + ] + py_deps("absl/testing"), +) + jax_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], diff --git a/tests/array_api_test.py b/tests/array_api_test.py new file mode 100644 index 000000000000..97fc682398e4 --- /dev/null +++ b/tests/array_api_test.py @@ -0,0 +1,238 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Smoketest for jax.experimental.array_api + +The full test suite for the array API is run via the array-api-tests CI; +this is just a minimal smoke test to catch issues early. +""" +from __future__ import annotations + +from types import ModuleType + +from absl.testing import absltest +import jax +from jax import config +from jax.experimental import array_api + +config.parse_flags_with_absl() + +MAIN_NAMESPACE = { + 'abs', + 'acos', + 'acosh', + 'add', + 'all', + 'annotations', + 'any', + 'arange', + 'argmax', + 'argmin', + 'argsort', + 'asarray', + 'asin', + 'asinh', + 'astype', + 'atan', + 'atan2', + 'atanh', + 'bitwise_and', + 'bitwise_invert', + 'bitwise_left_shift', + 'bitwise_or', + 'bitwise_right_shift', + 'bitwise_xor', + 'bool', + 'broadcast_arrays', + 'broadcast_to', + 'can_cast', + 'ceil', + 'complex128', + 'complex64', + 'concat', + 'conj', + 'cos', + 'cosh', + 'divide', + 'e', + 'empty', + 'empty_like', + 'equal', + 'exp', + 'expand_dims', + 'expm1', + 'eye', + 'fft', + 'finfo', + 'flip', + 'float32', + 'float64', + 'floor', + 'floor_divide', + 'from_dlpack', + 'full', + 'full_like', + 'greater', + 'greater_equal', + 'iinfo', + 'imag', + 'inf', + 'int16', + 'int32', + 'int64', + 'int8', + 'isdtype', + 'isfinite', + 'isinf', + 'isnan', + 'less', + 'less_equal', + 'linalg', + 'linspace', + 'log', + 'log10', + 'log1p', + 'log2', + 'logaddexp', + 'logical_and', + 'logical_not', + 'logical_or', + 'logical_xor', + 'matmul', + 'matrix_transpose', + 'max', + 'mean', + 'meshgrid', + 'min', + 'multiply', + 'nan', + 'negative', + 'newaxis', + 'nonzero', + 'not_equal', + 'ones', + 'ones_like', + 'permute_dims', + 'pi', + 'positive', + 'pow', + 'prod', + 'real', + 'remainder', + 'reshape', + 'result_type', + 'roll', + 'round', + 'sign', + 'sin', + 'sinh', + 'sort', + 'sqrt', + 'square', + 'squeeze', + 'stack', + 'std', + 'subtract', + 'sum', + 'take', + 'tan', + 'tanh', + 'tensordot', + 'tril', + 'triu', + 'trunc', + 'uint16', + 'uint32', + 'uint64', + 'uint8', + 'unique_all', + 'unique_counts', + 'unique_inverse', + 'unique_values', + 'var', + 'vecdot', + 'where', + 'zeros', + 'zeros_like', +} + +LINALG_NAMESPACE = { + 'cholesky', + 'cross', + 'det', + 'diagonal', + 'eigh', + 'eigvalsh', + 'inv', + 'jax', + 'matmul', + 'matrix_norm', + 'matrix_power', + 'matrix_rank', + 'matrix_transpose', + 'outer', + 'pinv', + 'qr', + 'slogdet', + 'solve', + 'svd', + 'svdvals', + 'tensordot', + 'trace', + 'vecdot', + 'vector_norm', +} + +FFT_NAMESPACE = { + 'fft', + 'fftfreq', + 'fftn', + 'fftshift', + 'hfft', + 'ifft', + 'ifftn', + 'ifftshift', + 'ihfft', + 'irfft', + 'irfftn', + 'rfft', + 'rfftfreq', + 'rfftn', +} + + +def names(module: ModuleType) -> set[str]: + return {name for name in dir(module) if not name.startswith('_')} + + +class ArrayAPISmokeTest(absltest.TestCase): + """Smoke test for the array API.""" + + def test_main_namespace(self): + self.assertSetEqual(names(array_api), MAIN_NAMESPACE) + + def test_linalg_namespace(self): + self.assertSetEqual(names(array_api.linalg), LINALG_NAMESPACE) + + def test_fft_namespace(self): + self.assertSetEqual(names(array_api.fft), FFT_NAMESPACE) + + def test_array_namespace_method(self): + x = array_api.arange(20) + self.assertIsInstance(x, jax.Array) + self.assertIs(x.__array_namespace__(), array_api) + + +if __name__ == '__main__': + absltest.main()