From 1d07af18c21894dd56e2f4f877c7845430c3b729 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 8 Apr 2022 13:51:44 -0400 Subject: [PATCH] Change jax DeviceArray to ndarray (#591) * Change DeviceArray to jnp.ndarray * Pump minimum jax version --- funsor/interpretations.py | 10 ---------- funsor/jax/ops.py | 3 +-- setup.py | 2 +- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/funsor/interpretations.py b/funsor/interpretations.py index 1961418e..b349e2ab 100644 --- a/funsor/interpretations.py +++ b/funsor/interpretations.py @@ -53,16 +53,6 @@ def interpret(self, cls, *args): @staticmethod def make_hash_key(cls, *args): backend = get_backend() - if backend == "jax": - # JAX DeviceArray has .__hash__ method but raise the unhashable error there. - from jax.interpreters.xla import DeviceArray - - return tuple( - id(arg) - if isinstance(arg, DeviceArray) or not isinstance(arg, Hashable) - else arg - for arg in args - ) if backend == "torch": # Avoid "ImportError: sys.meta_path is None" on shutdown. from torch import Tensor diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 9111a4de..238227dd 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -9,7 +9,6 @@ import numpy as onp from jax import lax from jax.core import Tracer -from jax.interpreters.xla import DeviceArray from jax.scipy.linalg import cho_solve, solve_triangular from jax.scipy.special import expit, gammaln, logsumexp @@ -19,7 +18,7 @@ # Register Ops ################################################################################ -array = (onp.generic, onp.ndarray, DeviceArray, Tracer) +array = (onp.generic, onp.ndarray, np.ndarray, Tracer) ops.atanh.register(array)(np.arctanh) ops.clamp.register(array)(np.clip) ops.exp.register(array)(np.exp) diff --git a/setup.py b/setup.py index ea5da5f6..ab9d9a43 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ ], extras_require={ "torch": ["pyro-ppl>=1.8.0", "torch>=1.11.0"], - "jax": ["numpyro>=0.7.0", "jax>=0.2.13", "jaxlib>=0.1.65"], + "jax": ["numpyro>=0.7.0", "jax>=0.2.21", "jaxlib>=0.1.71"], "test": [ "black", "flake8",