Skip to content

Commit

Permalink
Change jax DeviceArray to ndarray (#591)
Browse files Browse the repository at this point in the history
* Change DeviceArray to jnp.ndarray

* Pump minimum jax version
  • Loading branch information
fehiepsi authored Apr 8, 2022
1 parent 8d69fd3 commit 1d07af1
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 13 deletions.
10 changes: 0 additions & 10 deletions funsor/interpretations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@ def interpret(self, cls, *args):
@staticmethod
def make_hash_key(cls, *args):
backend = get_backend()
if backend == "jax":
# JAX DeviceArray has .__hash__ method but raise the unhashable error there.
from jax.interpreters.xla import DeviceArray

return tuple(
id(arg)
if isinstance(arg, DeviceArray) or not isinstance(arg, Hashable)
else arg
for arg in args
)
if backend == "torch":
# Avoid "ImportError: sys.meta_path is None" on shutdown.
from torch import Tensor
Expand Down
3 changes: 1 addition & 2 deletions funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as onp
from jax import lax
from jax.core import Tracer
from jax.interpreters.xla import DeviceArray
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import expit, gammaln, logsumexp

Expand All @@ -19,7 +18,7 @@
# Register Ops
################################################################################

array = (onp.generic, onp.ndarray, DeviceArray, Tracer)
array = (onp.generic, onp.ndarray, np.ndarray, Tracer)
ops.atanh.register(array)(np.arctanh)
ops.clamp.register(array)(np.clip)
ops.exp.register(array)(np.exp)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
],
extras_require={
"torch": ["pyro-ppl>=1.8.0", "torch>=1.11.0"],
"jax": ["numpyro>=0.7.0", "jax>=0.2.13", "jaxlib>=0.1.65"],
"jax": ["numpyro>=0.7.0", "jax>=0.2.21", "jaxlib>=0.1.71"],
"test": [
"black",
"flake8",
Expand Down

0 comments on commit 1d07af1

Please sign in to comment.