diff --git a/jax/__init__.py b/jax/__init__.py index 7086b9e9c66a..457b13ebf5b9 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -81,7 +81,7 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies -from jax._src.api import clear_backends as clear_backends +from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -218,10 +218,16 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), + # Added Mar 15, 2024 + "clear_backends": ( + "jax.clear_backends is deprecated", + _deprecated_clear_backends + ), } import typing as _typing if _typing.TYPE_CHECKING: + from jax._src.api import clear_backends as clear_backends from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves