Skip to content

Commit

Permalink
[array api] fix deprecation to support old import pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 1, 2024
1 parent aa9e1e4 commit 48c5fab
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api
documentation_render:
Expand Down
6 changes: 5 additions & 1 deletion jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,13 @@ pytype_library(

pytype_library(
name = "experimental_array_api",
srcs = glob(
[
"experimental/array_api/*.py",
],
),
visibility = [":internal"] + jax_visibility("array_api"),
deps = [
":experimental",
":jax",
],
)
Expand Down
17 changes: 0 additions & 17 deletions jax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,3 @@
from jax._src.earray import (
EArray as EArray
)

from jax import numpy as _array_api


_deprecations = {
# Deprecated 01 Aug 2024
"array_api": (
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
"jax.numpy supports the array API by default.",
_array_api
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _array_api
32 changes: 32 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

import sys as _sys
import warnings as _warnings

import jax.numpy as _array_api

# Added 2024-08-01
_warnings.warn(
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
"jax.numpy supports the array API by default.",
DeprecationWarning, stacklevel=2
)

_sys.modules['jax.experimental.array_api'] = _array_api

del _array_api, _sys, _warnings
4 changes: 2 additions & 2 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def test_array_namespace_method(self):
def test_deprecated_import(self):
msg = "jax.experimental.array_api import is no longer required"
with self.assertWarnsRegex(DeprecationWarning, msg):
from jax.experimental import array_api
self.assertIs(array_api, ARRAY_API_NAMESPACE)
import jax.experimental.array_api as nx
self.assertIs(nx, ARRAY_API_NAMESPACE)


class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 48c5fab

Please sign in to comment.