Skip to content

Commit

Permalink
Deprecate jax.clear_backends.
Browse files Browse the repository at this point in the history
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616212663
  • Loading branch information
yueshengys authored and jax authors committed Mar 18, 2024
1 parent c2bbf9c commit 1d7ed6c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
* {func}`jax.clear_backends` is deprecated as it does not necessarily do what
its name suggests and can lead to unexpected consequences, e.g., it will not
destroy existing backends and release corresponding owned resources. Use
{func}`jax.clear_caches` if you only want to clean up compilation caches.
For backward compatibilty or you really need to switch/reinitialize the
default backend, use {func}`jax.extend.backend.clear_backends`.
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
Expand Down
8 changes: 7 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_backends as _deprecated_clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.custom_derivatives import custom_gradient as custom_gradient
Expand Down Expand Up @@ -218,10 +218,16 @@
"or jax.tree_util.tree_map (any JAX version).",
_deprecated_tree_map
),
# Added Mar 18, 2024
"clear_backends": (
"jax.clear_backends is deprecated.",
_deprecated_tree_map
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.api import clear_backends as clear_backends
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves
Expand Down

0 comments on commit 1d7ed6c

Please sign in to comment.