Skip to content

Commit

Permalink
hack in a backport guard for jax version
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfeickert committed Aug 16, 2023
1 parent 84ced59 commit b440bba
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@

log = logging.getLogger(__name__)

# v0.7.x backport hack
from sys import version_info

if version_info < (3, 8):
import jax

_jax__version__ = jax.__version__
else:
from importlib.metadata import version

_jax__version__ = version("jax")
_jax__version__ = tuple(map(int, (_jax__version__.split("."))))
_old_jax_version = _jax__version__ < (0, 4, 1)

if not _old_jax_version:
from jax import Array


class _BasicPoisson:
def __init__(self, rate):
Expand Down Expand Up @@ -54,10 +71,10 @@ class jax_backend:
__slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

#: The array type for jax
array_type = jnp.DeviceArray
array_type = jnp.DeviceArray if _old_jax_version else Array

#: The array content type for jax
array_subtype = jnp.DeviceArray
array_subtype = jnp.DeviceArray if _old_jax_version else Array

def __init__(self, **kwargs):
self.name = 'jax'
Expand Down

0 comments on commit b440bba

Please sign in to comment.