Skip to content

Commit

Permalink
Reverts 8e2a8b7
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615401711
  • Loading branch information
hawkinsp authored and jax authors committed Mar 13, 2024
1 parent 642f20d commit cf856ad
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 32 deletions.
17 changes: 1 addition & 16 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def trace_context():
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
softmax_custom_jvp.value,
new_select_transpose.value,
enable_memories.value,
disable_jit.value,
jax_xla_profile_version.value,
Expand Down Expand Up @@ -805,7 +804,6 @@ class _GlobalExtraJitContext(NamedTuple):
random_seed_offset: int = 0
threefry_partitionable: bool = False
softmax_custom_jvp: bool = False
new_select_transpose: bool = False
xla_profile_version: int = 0


Expand Down Expand Up @@ -840,7 +838,6 @@ class _ThreadLocalExtraJitContext(NamedTuple):
random_seed_offset: int | None = None
threefry_partitionable: bool | None = None
softmax_custom_jvp: bool | None = None
new_select_transpose: bool | None = None
xla_profile_version: int | None = None


Expand Down Expand Up @@ -1083,7 +1080,7 @@ def _update_jax_memories_thread_local(val):
update_thread_local_hook=lambda val: update_thread_local_jit_state(
threefry_partitionable=val))

# TODO(mattjj): set default True then remove this flag (or die trying)

softmax_custom_jvp = define_bool_state(
name='jax_softmax_custom_jvp',
default=False,
Expand All @@ -1097,18 +1094,6 @@ def _update_jax_memories_thread_local(val):
softmax_custom_jvp=val))


# TODO(mattjj): remove this flag
new_select_transpose = define_bool_state(
name='new_select_transpose',
default=True,
upgrade=True,
help=('Change select_n_p transpose rule to specialize on bools'),
update_global_hook=lambda val: _update_global_jit_state(
new_select_transpose=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
new_select_transpose=val))


enable_custom_vjp_by_custom_transpose = define_bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3628,8 +3628,7 @@ def _select_transpose_rule(t, which, *cases):
for c in cases]
else:
zeros = full_like(t, 0)
if (dtypes.dtype(which) == np.dtype(np.bool_) and
config.new_select_transpose.value):
if dtypes.dtype(which) == np.dtype(np.bool_):
ct0 = select(which, zeros, t) if ad.is_undefined_primal(cases[0]) else None
ct1 = select(which, t, zeros) if ad.is_undefined_primal(cases[1]) else None
return (None, ct0, ct1)
Expand Down
6 changes: 0 additions & 6 deletions tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4477,12 +4477,6 @@ def f(x):
self.assertNotIn('False', str(jaxpr))
self.assertNotIn('True', str(jaxpr))

# But if we set the option off, we get the old behavior.
with config.new_select_transpose(False):
jaxpr = jax.make_jaxpr(jax.grad(f))(3.)
self.assertIn('False', str(jaxpr))
self.assertIn('True', str(jaxpr))

def testWhereScalarPromotion(self):
x = jnp.where(jnp.array([True, False]), 3,
jnp.ones((2,), dtype=jnp.float32))
Expand Down
11 changes: 3 additions & 8 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4396,7 +4396,8 @@ def args_maker(): return []
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=all_shapes, dtype=all_dtypes,
shape=all_shapes,
dtype=all_dtypes,
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
Expand Down Expand Up @@ -4432,18 +4433,12 @@ def testWhereExtraCode(self):
def f(x):
return jnp.where(x > 0, x, -x)

jaxpr = jax.make_jaxpr(jax.grad(f))(3.)
# Test no comparison literal True/False in jaxpr, and hence no comparison to
# literals
jaxpr = jax.make_jaxpr(jax.grad(f))(3.)
self.assertNotIn('False', str(jaxpr))
self.assertNotIn('True', str(jaxpr))

# But if we set the option off, we get the old behavior.
with config.new_select_transpose(False):
jaxpr = jax.make_jaxpr(jax.grad(f))(3.)
self.assertIn('False', str(jaxpr))
self.assertIn('True', str(jaxpr))

def testWhereScalarPromotion(self):
x = jnp.where(jnp.array([True, False]), 3,
jnp.ones((2,), dtype=jnp.float32))
Expand Down

0 comments on commit cf856ad

Please sign in to comment.