Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array-api] add simple smoketest target for standard CI testing #18685

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ jax_test(
srcs = ["api_util_test.py"],
)

py_test(
name = "array_api_test",
srcs = ["array_api_test.py"],
deps = [
"//jax",
"//jax:experimental_array_api",
] + py_deps("absl/testing"),
)

jax_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
Expand Down
238 changes: 238 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoketest for jax.experimental.array_api
The full test suite for the array API is run via the array-api-tests CI;
this is just a minimal smoke test to catch issues early.
"""
from __future__ import annotations

from types import ModuleType

from absl.testing import absltest
import jax
from jax import config
from jax.experimental import array_api

config.parse_flags_with_absl()

MAIN_NAMESPACE = {
'abs',
'acos',
'acosh',
'add',
'all',
'annotations',
'any',
'arange',
'argmax',
'argmin',
'argsort',
'asarray',
'asin',
'asinh',
'astype',
'atan',
'atan2',
'atanh',
'bitwise_and',
'bitwise_invert',
'bitwise_left_shift',
'bitwise_or',
'bitwise_right_shift',
'bitwise_xor',
'bool',
'broadcast_arrays',
'broadcast_to',
'can_cast',
'ceil',
'complex128',
'complex64',
'concat',
'conj',
'cos',
'cosh',
'divide',
'e',
'empty',
'empty_like',
'equal',
'exp',
'expand_dims',
'expm1',
'eye',
'fft',
'finfo',
'flip',
'float32',
'float64',
'floor',
'floor_divide',
'from_dlpack',
'full',
'full_like',
'greater',
'greater_equal',
'iinfo',
'imag',
'inf',
'int16',
'int32',
'int64',
'int8',
'isdtype',
'isfinite',
'isinf',
'isnan',
'less',
'less_equal',
'linalg',
'linspace',
'log',
'log10',
'log1p',
'log2',
'logaddexp',
'logical_and',
'logical_not',
'logical_or',
'logical_xor',
'matmul',
'matrix_transpose',
'max',
'mean',
'meshgrid',
'min',
'multiply',
'nan',
'negative',
'newaxis',
'nonzero',
'not_equal',
'ones',
'ones_like',
'permute_dims',
'pi',
'positive',
'pow',
'prod',
'real',
'remainder',
'reshape',
'result_type',
'roll',
'round',
'sign',
'sin',
'sinh',
'sort',
'sqrt',
'square',
'squeeze',
'stack',
'std',
'subtract',
'sum',
'take',
'tan',
'tanh',
'tensordot',
'tril',
'triu',
'trunc',
'uint16',
'uint32',
'uint64',
'uint8',
'unique_all',
'unique_counts',
'unique_inverse',
'unique_values',
'var',
'vecdot',
'where',
'zeros',
'zeros_like',
}

LINALG_NAMESPACE = {
'cholesky',
'cross',
'det',
'diagonal',
'eigh',
'eigvalsh',
'inv',
'jax',
'matmul',
'matrix_norm',
'matrix_power',
'matrix_rank',
'matrix_transpose',
'outer',
'pinv',
'qr',
'slogdet',
'solve',
'svd',
'svdvals',
'tensordot',
'trace',
'vecdot',
'vector_norm',
}

FFT_NAMESPACE = {
'fft',
'fftfreq',
'fftn',
'fftshift',
'hfft',
'ifft',
'ifftn',
'ifftshift',
'ihfft',
'irfft',
'irfftn',
'rfft',
'rfftfreq',
'rfftn',
}


def names(module: ModuleType) -> set[str]:
return {name for name in dir(module) if not name.startswith('_')}


class ArrayAPISmokeTest(absltest.TestCase):
"""Smoke test for the array API."""

def test_main_namespace(self):
self.assertSetEqual(names(array_api), MAIN_NAMESPACE)

def test_linalg_namespace(self):
self.assertSetEqual(names(array_api.linalg), LINALG_NAMESPACE)

def test_fft_namespace(self):
self.assertSetEqual(names(array_api.fft), FFT_NAMESPACE)

def test_array_namespace_method(self):
x = array_api.arange(20)
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)


if __name__ == '__main__':
absltest.main()