Skip to content

Commit 376ce9b

Browse files
Split SDE tests in half, to try and avoid GitHub runner issues?
1 parent 0f809d0 commit 376ce9b

File tree

2 files changed

+155
-157
lines changed

2 files changed

+155
-157
lines changed

test/test_sde.py renamed to test/test_sde1.py

Lines changed: 1 addition & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
from typing import Literal
22

33
import diffrax
4-
import jax
54
import jax.numpy as jnp
65
import jax.random as jr
7-
import jax.tree_util as jtu
8-
import lineax as lx
96
import pytest
10-
from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm
117

128
from .helpers import (
139
get_mlp_sde,
@@ -119,10 +115,7 @@ def get_dt_and_controller(level):
119115
# using a single reference solution. We use Euler if the solver is Ito
120116
# and Heun if the solver is Stratonovich.
121117
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
122-
@pytest.mark.parametrize(
123-
"dtype",
124-
(jnp.float64,),
125-
)
118+
@pytest.mark.parametrize("dtype", (jnp.float64,))
126119
def test_sde_strong_limit(
127120
solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
128121
):
@@ -184,152 +177,3 @@ def test_sde_strong_limit(
184177
)
185178
error = path_l2_dist(correct_sol, sol)
186179
assert error < 0.05
187-
188-
189-
def _solvers():
190-
yield diffrax.SPaRK
191-
yield diffrax.GeneralShARK
192-
yield diffrax.SlowRK
193-
yield diffrax.ShARK
194-
yield diffrax.SRA1
195-
yield diffrax.SEA
196-
197-
198-
# Define the SDE
199-
def dict_drift(t, y, args):
200-
pytree, _ = args
201-
return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y)
202-
203-
204-
def dict_diffusion(t, y, args):
205-
pytree, additive = args
206-
207-
def get_matrix(y_leaf):
208-
if additive:
209-
return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64)
210-
else:
211-
return 2.0 * jnp.broadcast_to(
212-
jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,)
213-
)
214-
215-
return jtu.tree_map(get_matrix, y)
216-
217-
218-
@pytest.mark.parametrize("shape", [(), (5, 2)])
219-
@pytest.mark.parametrize("solver_ctr", _solvers())
220-
@pytest.mark.parametrize(
221-
"dtype",
222-
(jnp.float64, jnp.complex128),
223-
)
224-
def test_sde_solver_shape(shape, solver_ctr, dtype):
225-
pytree = ({"a": 0, "b": [0, 0]}, 0, 0)
226-
key = jr.PRNGKey(0)
227-
y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree)
228-
t0, t1, dt0 = 0.0, 1.0, 0.3
229-
230-
# Some solvers only work with additive noise
231-
additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA]
232-
args = (pytree, additive)
233-
solver = solver_ctr()
234-
bmkey = jr.key(1)
235-
struct = jax.ShapeDtypeStruct((3,), dtype)
236-
bm_shape = jtu.tree_map(lambda _: struct, pytree)
237-
bm = diffrax.VirtualBrownianTree(
238-
t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea
239-
)
240-
terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm))
241-
solution = diffrax.diffeqsolve(
242-
terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True)
243-
)
244-
assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0)
245-
for leaf in jtu.tree_leaves(solution.ys):
246-
assert leaf[0].shape == shape
247-
248-
249-
def _weakly_diagonal_noise_helper(solver, dtype):
250-
w_shape = (3,)
251-
args = (0.5, 1.2)
252-
253-
def _diffusion(t, y, args):
254-
a, b = args
255-
return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype)
256-
257-
def _drift(t, y, args):
258-
a, b = args
259-
return -a * y
260-
261-
y0 = jnp.ones(w_shape, dtype)
262-
263-
bm = diffrax.VirtualBrownianTree(
264-
0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea
265-
)
266-
267-
terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm))
268-
saveat = diffrax.SaveAt(t1=True)
269-
solution = diffrax.diffeqsolve(
270-
terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
271-
)
272-
assert solution.ys is not None
273-
assert solution.ys.shape == (1, 3)
274-
275-
276-
def _lineax_weakly_diagonal_noise_helper(solver, dtype):
277-
w_shape = (3,)
278-
args = (0.5, 1.2)
279-
280-
def _diffusion(t, y, args):
281-
a, b = args
282-
return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype))
283-
284-
def _drift(t, y, args):
285-
a, b = args
286-
return -a * y
287-
288-
y0 = jnp.ones(w_shape, dtype)
289-
290-
bm = diffrax.VirtualBrownianTree(
291-
0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea
292-
)
293-
294-
terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm))
295-
saveat = diffrax.SaveAt(t1=True)
296-
solution = diffrax.diffeqsolve(
297-
terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
298-
)
299-
assert solution.ys is not None
300-
assert solution.ys.shape == (1, 3)
301-
302-
303-
@pytest.mark.parametrize("solver_ctr", _solvers())
304-
@pytest.mark.parametrize(
305-
"dtype",
306-
(jnp.float64, jnp.complex128),
307-
)
308-
@pytest.mark.parametrize(
309-
"weak_type",
310-
("old", "lineax"),
311-
)
312-
def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type):
313-
if weak_type == "old":
314-
_weakly_diagonal_noise_helper(solver_ctr(), dtype)
315-
elif weak_type == "lineax":
316-
_lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype)
317-
else:
318-
raise ValueError("Invalid weak_type")
319-
320-
321-
@pytest.mark.parametrize(
322-
"dtype",
323-
(jnp.float64, jnp.complex128),
324-
)
325-
@pytest.mark.parametrize(
326-
"weak_type",
327-
("old", "lineax"),
328-
)
329-
def test_halfsolver_term_compatible(dtype, weak_type):
330-
if weak_type == "old":
331-
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
332-
elif weak_type == "lineax":
333-
_lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
334-
else:
335-
raise ValueError("Invalid weak_type")

test/test_sde2.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import diffrax
2+
import jax
3+
import jax.numpy as jnp
4+
import jax.random as jr
5+
import jax.tree_util as jtu
6+
import lineax as lx
7+
import pytest
8+
from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm
9+
10+
11+
def _solvers():
12+
yield diffrax.SPaRK
13+
yield diffrax.GeneralShARK
14+
yield diffrax.SlowRK
15+
yield diffrax.ShARK
16+
yield diffrax.SRA1
17+
yield diffrax.SEA
18+
19+
20+
# Define the SDE
21+
def dict_drift(t, y, args):
22+
pytree, _ = args
23+
return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y)
24+
25+
26+
def dict_diffusion(t, y, args):
27+
pytree, additive = args
28+
29+
def get_matrix(y_leaf):
30+
if additive:
31+
return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64)
32+
else:
33+
return 2.0 * jnp.broadcast_to(
34+
jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,)
35+
)
36+
37+
return jtu.tree_map(get_matrix, y)
38+
39+
40+
@pytest.mark.parametrize("shape", [(), (5, 2)])
41+
@pytest.mark.parametrize("solver_ctr", _solvers())
42+
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
43+
def test_sde_solver_shape(shape, solver_ctr, dtype):
44+
pytree = ({"a": 0, "b": [0, 0]}, 0, 0)
45+
key = jr.PRNGKey(0)
46+
y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree)
47+
t0, t1, dt0 = 0.0, 1.0, 0.3
48+
49+
# Some solvers only work with additive noise
50+
additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA]
51+
args = (pytree, additive)
52+
solver = solver_ctr()
53+
bmkey = jr.key(1)
54+
struct = jax.ShapeDtypeStruct((3,), dtype)
55+
bm_shape = jtu.tree_map(lambda _: struct, pytree)
56+
bm = diffrax.VirtualBrownianTree(
57+
t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea
58+
)
59+
terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm))
60+
solution = diffrax.diffeqsolve(
61+
terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True)
62+
)
63+
assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0)
64+
for leaf in jtu.tree_leaves(solution.ys):
65+
assert leaf[0].shape == shape
66+
67+
68+
def _weakly_diagonal_noise_helper(solver, dtype):
69+
w_shape = (3,)
70+
args = (0.5, 1.2)
71+
72+
def _diffusion(t, y, args):
73+
a, b = args
74+
return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype)
75+
76+
def _drift(t, y, args):
77+
a, b = args
78+
return -a * y
79+
80+
y0 = jnp.ones(w_shape, dtype)
81+
82+
bm = diffrax.VirtualBrownianTree(
83+
0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea
84+
)
85+
86+
terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm))
87+
saveat = diffrax.SaveAt(t1=True)
88+
solution = diffrax.diffeqsolve(
89+
terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
90+
)
91+
assert solution.ys is not None
92+
assert solution.ys.shape == (1, 3)
93+
94+
95+
def _lineax_weakly_diagonal_noise_helper(solver, dtype):
96+
w_shape = (3,)
97+
args = (0.5, 1.2)
98+
99+
def _diffusion(t, y, args):
100+
a, b = args
101+
return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype))
102+
103+
def _drift(t, y, args):
104+
a, b = args
105+
return -a * y
106+
107+
y0 = jnp.ones(w_shape, dtype)
108+
109+
bm = diffrax.VirtualBrownianTree(
110+
0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea
111+
)
112+
113+
terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm))
114+
saveat = diffrax.SaveAt(t1=True)
115+
solution = diffrax.diffeqsolve(
116+
terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat
117+
)
118+
assert solution.ys is not None
119+
assert solution.ys.shape == (1, 3)
120+
121+
122+
@pytest.mark.parametrize("solver_ctr", _solvers())
123+
@pytest.mark.parametrize(
124+
"dtype",
125+
(jnp.float64, jnp.complex128),
126+
)
127+
@pytest.mark.parametrize(
128+
"weak_type",
129+
("old", "lineax"),
130+
)
131+
def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type):
132+
if weak_type == "old":
133+
_weakly_diagonal_noise_helper(solver_ctr(), dtype)
134+
elif weak_type == "lineax":
135+
_lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype)
136+
else:
137+
raise ValueError("Invalid weak_type")
138+
139+
140+
@pytest.mark.parametrize(
141+
"dtype",
142+
(jnp.float64, jnp.complex128),
143+
)
144+
@pytest.mark.parametrize(
145+
"weak_type",
146+
("old", "lineax"),
147+
)
148+
def test_halfsolver_term_compatible(dtype, weak_type):
149+
if weak_type == "old":
150+
_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
151+
elif weak_type == "lineax":
152+
_lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype)
153+
else:
154+
raise ValueError("Invalid weak_type")

0 commit comments

Comments
 (0)