Skip to content

Commit 07f27f6

Browse files
committed
test jitting
1 parent 62f7c54 commit 07f27f6

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

test_autolens/point/model/test_andrew_implementation.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
try:
2+
import jax
3+
4+
JAX_INSTALLED = True
5+
except ImportError:
6+
JAX_INSTALLED = False
7+
18
import numpy as np
29
import pytest
310

@@ -14,23 +21,23 @@ def noise_map():
1421
return np.array([1.0, 1.0])
1522

1623

17-
def test_andrew_implementation(
18-
data,
19-
noise_map,
20-
):
24+
@pytest.fixture
25+
def fit(data, noise_map):
2126
model_positions = np.array(
2227
[
2328
(-1.0749, -1.1),
2429
(1.19117, 1.175),
2530
]
2631
)
2732

28-
fit = Fit(
33+
return Fit(
2934
data=data,
3035
noise_map=noise_map,
3136
model_positions=model_positions,
3237
)
3338

39+
40+
def test_andrew_implementation(fit):
3441
assert np.allclose(
3542
fit.all_permutations_log_likelihoods(),
3643
[
@@ -41,6 +48,11 @@ def test_andrew_implementation(
4148
assert fit.log_likelihood() == -4.40375330990644
4249

4350

51+
@pytest.mark.skipif(not JAX_INSTALLED, reason="JAX is not installed")
52+
def test_jax(fit):
53+
assert jax.jit(fit.log_likelihood)() == -4.40375330990644
54+
55+
4456
def test_nan_model_positions(
4557
data,
4658
noise_map,

0 commit comments

Comments
 (0)