Skip to content

Commit

Permalink
Guard host transfers inside pure_callbacks from deadlocking the TPU.
Browse files Browse the repository at this point in the history
Also fix python/callback.cc to not swallow errors in numpy conversions.

PiperOrigin-RevId: 619375128
  • Loading branch information
pschuh authored and jax authors committed Mar 27, 2024
1 parent 4a9c8d1 commit 0b09762
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,20 @@ def f(x):
with self.assertRaisesRegex(Exception, "Errors should propagate."):
print(np.array(f(2.0)), flush=True)

@unittest.skipIf(xla_extension_version < 250, "jaxlib version too old")
def test_reentrant_error_propagation(self):
reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile()

@jax.jit
def f(x):
return jax.pure_callback(reentrant_fn, x, x)

try:
np.array(f(2.0))
except:
# Only should not deadlock.
pass

def test_can_take_grad_of_pure_callback_with_custom_jvp(self):

@jax.custom_jvp
Expand Down

0 comments on commit 0b09762

Please sign in to comment.