Skip to content

Commit

Permalink
Propagate effects errors to the results (only if effects are enabled).
Browse files Browse the repository at this point in the history
This will now happen when results of effectful computations are
converted to numpy arrays.

PiperOrigin-RevId: 615883363
  • Loading branch information
pschuh authored and jax authors committed Mar 14, 2024
1 parent 6f38f27 commit 9a00721
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.maps import xmap
Expand Down Expand Up @@ -639,7 +640,6 @@ def h(x, y):
out = h(jnp.arange(4.)[None], 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)


def test_vmap_vectorized_callback_errors_if_returns_wrong_shape(self):

def cb(x):
Expand Down Expand Up @@ -716,6 +716,18 @@ def f(x):
ValueError, "Pure callbacks do not support JVP."):
f(2.)

@unittest.skipIf(xla_extension_version < 245, "jaxlib version too old")
def test_error_propagation(self):
def throws_error_fn(x):
raise RuntimeError("Errors should propagate.")

@jax.jit
def f(x):
return jax.pure_callback(throws_error_fn, x, x)

with self.assertRaisesRegex(Exception, "Errors should propagate."):
print(np.array(f(2.0)), flush=True)

def test_can_take_grad_of_pure_callback_with_custom_jvp(self):

@jax.custom_jvp
Expand Down Expand Up @@ -833,7 +845,6 @@ def f(self, ys):
# callback alive.
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))


def test_callback_inside_xmap(self):

def _callback(x):
Expand Down

0 comments on commit 9a00721

Please sign in to comment.