diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index a84740026e91..1ce108b3916f 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -797,6 +797,20 @@ def f(x): with self.assertRaisesRegex(Exception, "Errors should propagate."): print(np.array(f(2.0)), flush=True) + @unittest.skipIf(xla_extension_version < 250, "jaxlib version too old") + def test_reentrant_error_propagation(self): + reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile() + + @jax.jit + def f(x): + return jax.pure_callback(reentrant_fn, x, x) + + try: + np.array(f(2.0)) + except: + # Only should not deadlock. + pass + def test_can_take_grad_of_pure_callback_with_custom_jvp(self): @jax.custom_jvp