From 1597dbb0f8bf81e07349725734915c01511d8fb1 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sat, 19 Oct 2024 16:10:53 -0400 Subject: [PATCH] Clean up conditional lowering of tan after JAX v0.4.34 release. --- jax/_src/lax/lax.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 113c87b60ee0..152b8462476c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -60,7 +60,6 @@ from jax._src.lax.utils import ( _input_dtype, dtype_to_string, standard_abstract_eval, standard_multi_result_abstract_eval, standard_primitive) -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2367,15 +2366,7 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this -# lowering is mostly supported, but it fails on export or with the PJRT plugin -# because those modes target an older StableHLO version, and the -# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't -# included in the 0.4.33 release. -if jaxlib_version <= (0, 4, 33): - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) -else: - mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating):