File tree Expand file tree Collapse file tree 1 file changed +17
-5
lines changed
test_autolens/point/model Expand file tree Collapse file tree 1 file changed +17
-5
lines changed Original file line number Diff line number Diff line change
1
+ try :
2
+ import jax
3
+
4
+ JAX_INSTALLED = True
5
+ except ImportError :
6
+ JAX_INSTALLED = False
7
+
1
8
import numpy as np
2
9
import pytest
3
10
@@ -14,23 +21,23 @@ def noise_map():
14
21
return np .array ([1.0 , 1.0 ])
15
22
16
23
17
- def test_andrew_implementation (
18
- data ,
19
- noise_map ,
20
- ):
24
+ @pytest .fixture
25
+ def fit (data , noise_map ):
21
26
model_positions = np .array (
22
27
[
23
28
(- 1.0749 , - 1.1 ),
24
29
(1.19117 , 1.175 ),
25
30
]
26
31
)
27
32
28
- fit = Fit (
33
+ return Fit (
29
34
data = data ,
30
35
noise_map = noise_map ,
31
36
model_positions = model_positions ,
32
37
)
33
38
39
+
40
+ def test_andrew_implementation (fit ):
34
41
assert np .allclose (
35
42
fit .all_permutations_log_likelihoods (),
36
43
[
@@ -41,6 +48,11 @@ def test_andrew_implementation(
41
48
assert fit .log_likelihood () == - 4.40375330990644
42
49
43
50
51
+ @pytest .mark .skipif (not JAX_INSTALLED , reason = "JAX is not installed" )
52
+ def test_jax (fit ):
53
+ assert jax .jit (fit .log_likelihood )() == - 4.40375330990644
54
+
55
+
44
56
def test_nan_model_positions (
45
57
data ,
46
58
noise_map ,
You can’t perform that action at this time.
0 commit comments