Skip to content

Commit

Permalink
Merge pull request #22181 from jakevdp:xla-abbrevs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648701764
  • Loading branch information
jax authors committed Jul 2, 2024
2 parents da76ebf + 251dfca commit 92ebb53
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,27 @@
apply_primitive as apply_primitive,
)

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import xla_bridge as _xb
from jax._src.lib import xla_client as _xc

xe = xc._xla
Backend = xe.Client
_xe = _xc._xla
Backend = _xe.Client

# Deprecations
_deprecations = {
# Added 2024-06-28
"xb": (
"jax.interpreters.xla.xb is deprecated. Use jax.lib.xla_bridge instead.",
_xb
),
"xc": (
"jax.interpreters.xla.xc is deprecated. Use jax.lib.xla_client instead.",
_xc,
),
"xe": (
"jax.interpreters.xla.xe is deprecated. Use jax.lib.xla_extension instead.",
_xe,
),
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
Expand Down Expand Up @@ -69,6 +82,13 @@
),
}

import typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
if typing.TYPE_CHECKING:
xb = _xb
xc = _xc
xe = _xe
else:
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

0 comments on commit 92ebb53

Please sign in to comment.