From 7cae272390f15353733e628ec1084572c26c807d Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 13 Mar 2024 19:11:39 -0700 Subject: [PATCH] Propagate effects errors to the results (only if effects are enabled). This will now happen when results of effectful computations are converted to numpy arrays. PiperOrigin-RevId: 615618732 --- tests/python_callback_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index b25e33ca2d40..c33b38634b77 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,6 +30,7 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.maps import xmap @@ -716,6 +717,18 @@ def f(x): ValueError, "Pure callbacks do not support JVP."): f(2.) + @unittest.skipIf(xla_extension_version < 245, 'jaxlib version too old') + def test_error_propagation(self): + def throws_error_fn(x): + raise RuntimeError("Errors should propagate.") + + @jax.jit + def f(x): + return jax.pure_callback(throws_error_fn, x, x) + + with self.assertRaisesRegex(RuntimeError, "Errors should propagate."): + print(np.array(f(2.0)), flush=True) + def test_can_take_grad_of_pure_callback_with_custom_jvp(self): @jax.custom_jvp