|
1 | 1 | from typing import Literal
|
2 | 2 |
|
3 | 3 | import diffrax
|
4 |
| -import jax |
5 | 4 | import jax.numpy as jnp
|
6 | 5 | import jax.random as jr
|
7 |
| -import jax.tree_util as jtu |
8 |
| -import lineax as lx |
9 | 6 | import pytest
|
10 |
| -from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm |
11 | 7 |
|
12 | 8 | from .helpers import (
|
13 | 9 | get_mlp_sde,
|
@@ -119,10 +115,7 @@ def get_dt_and_controller(level):
|
119 | 115 | # using a single reference solution. We use Euler if the solver is Ito
|
120 | 116 | # and Heun if the solver is Stratonovich.
|
121 | 117 | @pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
|
122 |
| -@pytest.mark.parametrize( |
123 |
| - "dtype", |
124 |
| - (jnp.float64,), |
125 |
| -) |
| 118 | +@pytest.mark.parametrize("dtype", (jnp.float64,)) |
126 | 119 | def test_sde_strong_limit(
|
127 | 120 | solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype
|
128 | 121 | ):
|
@@ -184,152 +177,3 @@ def test_sde_strong_limit(
|
184 | 177 | )
|
185 | 178 | error = path_l2_dist(correct_sol, sol)
|
186 | 179 | assert error < 0.05
|
187 |
| - |
188 |
| - |
189 |
| -def _solvers(): |
190 |
| - yield diffrax.SPaRK |
191 |
| - yield diffrax.GeneralShARK |
192 |
| - yield diffrax.SlowRK |
193 |
| - yield diffrax.ShARK |
194 |
| - yield diffrax.SRA1 |
195 |
| - yield diffrax.SEA |
196 |
| - |
197 |
| - |
198 |
| -# Define the SDE |
199 |
| -def dict_drift(t, y, args): |
200 |
| - pytree, _ = args |
201 |
| - return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y) |
202 |
| - |
203 |
| - |
204 |
| -def dict_diffusion(t, y, args): |
205 |
| - pytree, additive = args |
206 |
| - |
207 |
| - def get_matrix(y_leaf): |
208 |
| - if additive: |
209 |
| - return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64) |
210 |
| - else: |
211 |
| - return 2.0 * jnp.broadcast_to( |
212 |
| - jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,) |
213 |
| - ) |
214 |
| - |
215 |
| - return jtu.tree_map(get_matrix, y) |
216 |
| - |
217 |
| - |
218 |
| -@pytest.mark.parametrize("shape", [(), (5, 2)]) |
219 |
| -@pytest.mark.parametrize("solver_ctr", _solvers()) |
220 |
| -@pytest.mark.parametrize( |
221 |
| - "dtype", |
222 |
| - (jnp.float64, jnp.complex128), |
223 |
| -) |
224 |
| -def test_sde_solver_shape(shape, solver_ctr, dtype): |
225 |
| - pytree = ({"a": 0, "b": [0, 0]}, 0, 0) |
226 |
| - key = jr.PRNGKey(0) |
227 |
| - y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree) |
228 |
| - t0, t1, dt0 = 0.0, 1.0, 0.3 |
229 |
| - |
230 |
| - # Some solvers only work with additive noise |
231 |
| - additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA] |
232 |
| - args = (pytree, additive) |
233 |
| - solver = solver_ctr() |
234 |
| - bmkey = jr.key(1) |
235 |
| - struct = jax.ShapeDtypeStruct((3,), dtype) |
236 |
| - bm_shape = jtu.tree_map(lambda _: struct, pytree) |
237 |
| - bm = diffrax.VirtualBrownianTree( |
238 |
| - t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea |
239 |
| - ) |
240 |
| - terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm)) |
241 |
| - solution = diffrax.diffeqsolve( |
242 |
| - terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True) |
243 |
| - ) |
244 |
| - assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0) |
245 |
| - for leaf in jtu.tree_leaves(solution.ys): |
246 |
| - assert leaf[0].shape == shape |
247 |
| - |
248 |
| - |
249 |
| -def _weakly_diagonal_noise_helper(solver, dtype): |
250 |
| - w_shape = (3,) |
251 |
| - args = (0.5, 1.2) |
252 |
| - |
253 |
| - def _diffusion(t, y, args): |
254 |
| - a, b = args |
255 |
| - return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype) |
256 |
| - |
257 |
| - def _drift(t, y, args): |
258 |
| - a, b = args |
259 |
| - return -a * y |
260 |
| - |
261 |
| - y0 = jnp.ones(w_shape, dtype) |
262 |
| - |
263 |
| - bm = diffrax.VirtualBrownianTree( |
264 |
| - 0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea |
265 |
| - ) |
266 |
| - |
267 |
| - terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm)) |
268 |
| - saveat = diffrax.SaveAt(t1=True) |
269 |
| - solution = diffrax.diffeqsolve( |
270 |
| - terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat |
271 |
| - ) |
272 |
| - assert solution.ys is not None |
273 |
| - assert solution.ys.shape == (1, 3) |
274 |
| - |
275 |
| - |
276 |
| -def _lineax_weakly_diagonal_noise_helper(solver, dtype): |
277 |
| - w_shape = (3,) |
278 |
| - args = (0.5, 1.2) |
279 |
| - |
280 |
| - def _diffusion(t, y, args): |
281 |
| - a, b = args |
282 |
| - return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype)) |
283 |
| - |
284 |
| - def _drift(t, y, args): |
285 |
| - a, b = args |
286 |
| - return -a * y |
287 |
| - |
288 |
| - y0 = jnp.ones(w_shape, dtype) |
289 |
| - |
290 |
| - bm = diffrax.VirtualBrownianTree( |
291 |
| - 0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea |
292 |
| - ) |
293 |
| - |
294 |
| - terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm)) |
295 |
| - saveat = diffrax.SaveAt(t1=True) |
296 |
| - solution = diffrax.diffeqsolve( |
297 |
| - terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat |
298 |
| - ) |
299 |
| - assert solution.ys is not None |
300 |
| - assert solution.ys.shape == (1, 3) |
301 |
| - |
302 |
| - |
303 |
| -@pytest.mark.parametrize("solver_ctr", _solvers()) |
304 |
| -@pytest.mark.parametrize( |
305 |
| - "dtype", |
306 |
| - (jnp.float64, jnp.complex128), |
307 |
| -) |
308 |
| -@pytest.mark.parametrize( |
309 |
| - "weak_type", |
310 |
| - ("old", "lineax"), |
311 |
| -) |
312 |
| -def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type): |
313 |
| - if weak_type == "old": |
314 |
| - _weakly_diagonal_noise_helper(solver_ctr(), dtype) |
315 |
| - elif weak_type == "lineax": |
316 |
| - _lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype) |
317 |
| - else: |
318 |
| - raise ValueError("Invalid weak_type") |
319 |
| - |
320 |
| - |
321 |
| -@pytest.mark.parametrize( |
322 |
| - "dtype", |
323 |
| - (jnp.float64, jnp.complex128), |
324 |
| -) |
325 |
| -@pytest.mark.parametrize( |
326 |
| - "weak_type", |
327 |
| - ("old", "lineax"), |
328 |
| -) |
329 |
| -def test_halfsolver_term_compatible(dtype, weak_type): |
330 |
| - if weak_type == "old": |
331 |
| - _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) |
332 |
| - elif weak_type == "lineax": |
333 |
| - _lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) |
334 |
| - else: |
335 |
| - raise ValueError("Invalid weak_type") |
0 commit comments