diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index b25e33ca2d40..d846377236c6 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,6 +30,7 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.maps import xmap @@ -639,7 +640,6 @@ def h(x, y): out = h(jnp.arange(4.)[None], 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.) - def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self): def cb(x): @@ -716,6 +716,18 @@ def f(x): ValueError, "Pure callbacks do not support JVP."): f(2.) + @unittest.skipIf(xla_extension_version < 245, "jaxlib version too old") + def test_error_propagation(self): + def throws_error_fn(x): + raise RuntimeError("Errors should propagate.") + + @jax.jit + def f(x): + return jax.pure_callback(throws_error_fn, x, x) + + with self.assertRaisesRegex(Exception, "Errors should propagate."): + print(np.array(f(2.0)), flush=True) + def test_can_take_grad_of_pure_callback_with_custom_jvp(self): @jax.custom_jvp @@ -833,7 +845,6 @@ def f(self, ys): # callback alive. np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32)) - def test_callback_inside_xmap(self): def _callback(x):