From 48c5fab0233ba0d994131854984d7e1a35af92d3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 14:38:59 -0700 Subject: [PATCH] [array api] fix deprecation to support old import pattern --- .github/workflows/ci-build.yaml | 2 +- jax/BUILD | 6 ++++- jax/experimental/__init__.py | 17 -------------- jax/experimental/array_api/__init__.py | 32 ++++++++++++++++++++++++++ tests/array_api_test.py | 4 ++-- 5 files changed, 40 insertions(+), 21 deletions(-) create mode 100644 jax/experimental/array_api/__init__.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 9869d0256f8d..743aae9b4e1e 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -141,7 +141,7 @@ jobs: PY_COLORS: 1 run: | pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api documentation_render: diff --git a/jax/BUILD b/jax/BUILD index 69be704ba57c..66df0d2f7272 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -987,9 +987,13 @@ pytype_library( pytype_library( name = "experimental_array_api", + srcs = glob( + [ + "experimental/array_api/*.py", + ], + ), visibility = [":internal"] + jax_visibility("array_api"), deps = [ - ":experimental", ":jax", ], ) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index dcca44773921..caf27ec7a8ca 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -25,20 +25,3 @@ from jax._src.earray import ( EArray as EArray ) - -from jax import numpy as _array_api - - -_deprecations = { - # Deprecated 01 Aug 2024 - "array_api": ( - "jax.experimental.array_api import is no longer required as of JAX v0.4.32; " - "jax.numpy supports the array API by default.", - _array_api - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr -del _array_api diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py new file mode 100644 index 000000000000..69b25d0b6ad9 --- /dev/null +++ b/jax/experimental/array_api/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +import sys as _sys +import warnings as _warnings + +import jax.numpy as _array_api + +# Added 2024-08-01 +_warnings.warn( + "jax.experimental.array_api import is no longer required as of JAX v0.4.32; " + "jax.numpy supports the array API by default.", + DeprecationWarning, stacklevel=2 +) + +_sys.modules['jax.experimental.array_api'] = _array_api + +del _array_api, _sys, _warnings diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 41951fea381d..2f3042fa2a33 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -249,8 +249,8 @@ def test_array_namespace_method(self): def test_deprecated_import(self): msg = "jax.experimental.array_api import is no longer required" with self.assertWarnsRegex(DeprecationWarning, msg): - from jax.experimental import array_api - self.assertIs(array_api, ARRAY_API_NAMESPACE) + import jax.experimental.array_api as nx + self.assertIs(nx, ARRAY_API_NAMESPACE) class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):