diff --git a/CHANGELOG.md b/CHANGELOG.md index c1dd67b58d32..b00a43549930 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. + * {func}`jax.clear_backends` is deprecated as it does not necessarily do what + its name suggests and can lead to unexpected consequences, e.g., it will not + destroy existing backends and release corresponding owned resources. Use + {func}`jax.clear_caches` if you only want to clean up compilation caches. + For backward compatibilty or you really need to switch/reinitialize the + default backend, use {func}`jax.extend.backend.clear_backends`. * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. diff --git a/jax/__init__.py b/jax/__init__.py index 7086b9e9c66a..9a71ab6a5c2a 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -81,7 +81,7 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies -from jax._src.api import clear_backends as clear_backends +from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -218,10 +218,16 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), + # Added Mar 18, 2024 + "clear_backends": ( + "jax.clear_backends is deprecated.", + _deprecated_tree_map + ), } import typing as _typing if _typing.TYPE_CHECKING: + from jax._src.api import clear_backends as clear_backends from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves