Skip to content

Commit 984ee55

Browse files
Test that exception is raised when size is not static
1 parent 328864b commit 984ee55

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tests/link/jax/test_random.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -621,12 +621,10 @@ def test_random_concrete_shape_subtensor_tuple():
621621
assert jax_fn(np.ones((2, 3))).shape == (2,)
622622

623623

624-
@pytest.mark.xfail(
625-
reason="`size_at` should be specified as a static argument", strict=True
626-
)
627624
def test_random_concrete_shape_graph_input():
625+
"""JAX cannot JIT-compile random variables whose `size` argument is not static."""
628626
rng = shared(np.random.RandomState(123))
629627
size_at = at.scalar()
630628
out = at.random.normal(0, 1, size=size_at, rng=rng)
631-
jax_fn = function([size_at], out, mode=jax_mode)
632-
assert jax_fn(10).shape == (10,)
629+
with pytest.raises(NotImplementedError, match=r".* concrete values .*"):
630+
function([size_at], out, mode=jax_mode)

0 commit comments

Comments
 (0)