diff --git a/tests/rigid_body_test.py b/tests/rigid_body_test.py index 09d1d9a..608bbdf 100644 --- a/tests/rigid_body_test.py +++ b/tests/rigid_body_test.py @@ -29,6 +29,8 @@ from jax import test_util as jtu from jax import config as jax_config +jax_config.update('jax_disable_jit', True) + import jax.numpy as jnp from jax_md import quantity