diff --git a/jax/BUILD b/jax/BUILD index 8386660f3c19..0ebfca2ba35f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -1024,6 +1024,7 @@ pytype_strict_library( visibility = [":jax_extend_users"], deps = [ "//jax/extend", + "//jax/extend:backend", "//jax/extend:core", "//jax/extend:linear_util", "//jax/extend:random", diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 7b367d3a2599..9b24200e047d 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -26,6 +26,7 @@ pytype_strict_library( name = "extend", srcs = ["__init__.py"], deps = [ + ":backend", ":core", ":linear_util", ":random", @@ -45,6 +46,12 @@ pytype_strict_library( deps = ["//jax:core"], ) +pytype_strict_library( + name = "backend", + srcs = ["backend.py"], + deps = ["//jax"], +) + pytype_strict_library( name = "random", srcs = ["random.py"], diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index 77c81488c5cb..3f4327dde917 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -29,6 +29,7 @@ """ from jax.extend import ( + backend as backend, core as core, linear_util as linear_util, random as random, diff --git a/jax/extend/backend.py b/jax/extend/backend.py new file mode 100644 index 000000000000..7aa2c8a06ba8 --- /dev/null +++ b/jax/extend/backend.py @@ -0,0 +1,20 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.api import ( + clear_backends as clear_backends, +) diff --git a/tests/extend_test.py b/tests/extend_test.py index 37f5c911d821..b49c1ac09214 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -18,6 +18,7 @@ import jax.extend as jex import jax.numpy as jnp +from jax._src import api from jax._src import abstract_arrays from jax._src import linear_util from jax._src import prng @@ -39,6 +40,7 @@ def test_symbols(self): self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl) # Assume these are tested elsewhere, only check equivalence + self.assertIs(jex.backend.clear_backends, api.clear_backends) self.assertIs(jex.core.array_types, abstract_arrays.array_types) self.assertIs(jex.linear_util.StoreException, linear_util.StoreException) self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)