From cf856ad4a9c6d659a8708ca10fa28bd0a65cbe5f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 13 Mar 2024 07:00:48 -0700 Subject: [PATCH] Reverts 8e2a8b7b95e838947dcf581d146909d5c4128742 PiperOrigin-RevId: 615401711 --- jax/_src/config.py | 17 +---------------- jax/_src/lax/lax.py | 3 +-- tests/lax_metal_test.py | 6 ------ tests/lax_numpy_test.py | 11 +++-------- 4 files changed, 5 insertions(+), 32 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index c4d5117ece13..3ac3649ed5c4 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -211,7 +211,6 @@ def trace_context(): default_device.value, random_seed_offset.value, threefry_partitionable.value, softmax_custom_jvp.value, - new_select_transpose.value, enable_memories.value, disable_jit.value, jax_xla_profile_version.value, @@ -805,7 +804,6 @@ class _GlobalExtraJitContext(NamedTuple): random_seed_offset: int = 0 threefry_partitionable: bool = False softmax_custom_jvp: bool = False - new_select_transpose: bool = False xla_profile_version: int = 0 @@ -840,7 +838,6 @@ class _ThreadLocalExtraJitContext(NamedTuple): random_seed_offset: int | None = None threefry_partitionable: bool | None = None softmax_custom_jvp: bool | None = None - new_select_transpose: bool | None = None xla_profile_version: int | None = None @@ -1083,7 +1080,7 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_partitionable=val)) -# TODO(mattjj): set default True then remove this flag (or die trying) + softmax_custom_jvp = define_bool_state( name='jax_softmax_custom_jvp', default=False, @@ -1097,18 +1094,6 @@ def _update_jax_memories_thread_local(val): softmax_custom_jvp=val)) -# TODO(mattjj): remove this flag -new_select_transpose = define_bool_state( - name='new_select_transpose', - default=True, - upgrade=True, - help=('Change select_n_p transpose rule to specialize on bools'), - update_global_hook=lambda val: _update_global_jit_state( - new_select_transpose=val), - update_thread_local_hook=lambda val: update_thread_local_jit_state( - new_select_transpose=val)) - - enable_custom_vjp_by_custom_transpose = define_bool_state( name='jax_enable_custom_vjp_by_custom_transpose', default=False, diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index bc321e3306cf..0af178467e4a 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3628,8 +3628,7 @@ def _select_transpose_rule(t, which, *cases): for c in cases] else: zeros = full_like(t, 0) - if (dtypes.dtype(which) == np.dtype(np.bool_) and - config.new_select_transpose.value): + if dtypes.dtype(which) == np.dtype(np.bool_): ct0 = select(which, zeros, t) if ad.is_undefined_primal(cases[0]) else None ct1 = select(which, t, zeros) if ad.is_undefined_primal(cases[1]) else None return (None, ct0, ct1) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 4b96fd150dfc..c84d3b1b66d6 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -4477,12 +4477,6 @@ def f(x): self.assertNotIn('False', str(jaxpr)) self.assertNotIn('True', str(jaxpr)) - # But if we set the option off, we get the old behavior. - with config.new_select_transpose(False): - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) - self.assertIn('False', str(jaxpr)) - self.assertIn('True', str(jaxpr)) - def testWhereScalarPromotion(self): x = jnp.where(jnp.array([True, False]), 3, jnp.ones((2,), dtype=jnp.float32)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e05a538e3d34..9b888bd22532 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4396,7 +4396,8 @@ def args_maker(): return [] self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=all_shapes, dtype=all_dtypes, + shape=all_shapes, + dtype=all_dtypes, ) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) @@ -4432,18 +4433,12 @@ def testWhereExtraCode(self): def f(x): return jnp.where(x > 0, x, -x) + jaxpr = jax.make_jaxpr(jax.grad(f))(3.) # Test no comparison literal True/False in jaxpr, and hence no comparison to # literals - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) self.assertNotIn('False', str(jaxpr)) self.assertNotIn('True', str(jaxpr)) - # But if we set the option off, we get the old behavior. - with config.new_select_transpose(False): - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) - self.assertIn('False', str(jaxpr)) - self.assertIn('True', str(jaxpr)) - def testWhereScalarPromotion(self): x = jnp.where(jnp.array([True, False]), 3, jnp.ones((2,), dtype=jnp.float32))