From 59a188dfb443b736d9f68f9042deef5102bf8c2e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 22 Sep 2023 18:46:22 -0400 Subject: [PATCH] avoid breakage in old jax version without jax.extend (#1647) * avoid breakage in old jax version without jax.extend * fix lint --- numpyro/ops/provenance.py | 6 +++++- test/ops/test_provenance.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index e7900eeca..9e8eecc2a 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -5,7 +5,11 @@ from jax.api_util import flatten_fun, shaped_abstractify import jax.core as core from jax.experimental.pjit import pjit_p -import jax.extend.linear_util as lu + +try: + import jax.extend.linear_util as lu +except ImportError: + import jax.linear_util as lu from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic from jax.interpreters.pxla import xla_pmap_p import jax.numpy as jnp diff --git a/test/ops/test_provenance.py b/test/ops/test_provenance.py index 5f9361422..a64fcaadc 100644 --- a/test/ops/test_provenance.py +++ b/test/ops/test_provenance.py @@ -8,7 +8,11 @@ import jax from jax.api_util import flatten_fun_nokwargs import jax.core as core -import jax.extend.linear_util as lu + +try: + import jax.extend.linear_util as lu +except ImportError: + import jax.linear_util as lu import jax.numpy as jnp from numpyro.ops.provenance import eval_provenance