Skip to content

Commit

Permalink
Deprecate jax.interpreters xb, xc, xe abbreviations.
Browse files Browse the repository at this point in the history
Instead, import directly as jax.lib.xla_bridge, jax.lib.xla_client, jax.lib.xla_extension.
  • Loading branch information
jakevdp committed Jun 28, 2024
1 parent 8c889b5 commit 251dfca
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,27 @@
apply_primitive as apply_primitive,
)

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import xla_bridge as _xb
from jax._src.lib import xla_client as _xc

xe = xc._xla
Backend = xe.Client
_xe = _xc._xla
Backend = _xe.Client

# Deprecations
_deprecations = {
# Added 2024-06-28
"xb": (
"jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.",
_xb
),
"xc": (
"jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.",
_xc,
),
"xe": (
"jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.",
_xe,
),
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
Expand Down Expand Up @@ -69,6 +82,13 @@
),
}

import typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if typing.TYPE_CHECKING:
xb = _xb
xc = _xc
xe = _xe
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

0 comments on commit 251dfca

Please sign in to comment.