Skip to content

Commit

Permalink
Propagate effects errors to the results (only if effects are enabled).
Browse files Browse the repository at this point in the history
This will now happen when results of effectful computations are
converted to numpy arrays.

PiperOrigin-RevId: 615618732
  • Loading branch information
pschuh authored and jax authors committed Mar 14, 2024
1 parent 6046d7d commit 7cae272
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.maps import xmap
Expand Down Expand Up @@ -716,6 +717,18 @@ def f(x):
ValueError, "Pure callbacks do not support JVP."):
f(2.)

@unittest.skipIf(xla_extension_version < 245, 'jaxlib version too old')
def test_error_propagation(self):
def throws_error_fn(x):
raise RuntimeError("Errors should propagate.")

@jax.jit
def f(x):
return jax.pure_callback(throws_error_fn, x, x)

with self.assertRaisesRegex(RuntimeError, "Errors should propagate."):
print(np.array(f(2.0)), flush=True)

def test_can_take_grad_of_pure_callback_with_custom_jvp(self):

@jax.custom_jvp
Expand Down

0 comments on commit 7cae272

Please sign in to comment.