From a658fdea8d940f81e599cfbf004b0fb040d9c7b9 Mon Sep 17 00:00:00 2001
From: LuggiStruggi <luggistruggi@gmail.com>
Date: Mon, 20 Jan 2025 10:53:05 +0100
Subject: [PATCH] diffrax root finder

---
 test/test_integrate.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/test/test_integrate.py b/test/test_integrate.py
index 590ee86e..af7e1b8d 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -10,7 +10,6 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import lineax as lx
-import optimistix as optx
 import pytest
 import scipy.stats
 from diffrax import ControlTerm, MultiTerm, ODETerm
@@ -796,20 +795,18 @@ def func(self, terms, t0, y0, args):
 
 
 def test_vmap_backprop():
-
     def dynamics(t, y, args):
         param = args
         return param - y
-
+
     def event_fn(t, y, args, **kwargs):
         return y - 1.5
-
+
     def single_loss_fn(param):
         solver = diffrax.Euler()
-        root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
+        root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
         event = diffrax.Event(event_fn, root_finder)
-        term = diffrax.ODETerm(dynamics)
-
+        term = diffrax.ODETerm(dynamics)
         sol = diffrax.diffeqsolve(
             term,
             solver=solver,
@@ -821,18 +818,18 @@ def single_loss_fn(param):
             event=event,
             max_steps=1000,
         )
-
+
         assert sol.ys is not None
         final_y = sol.ys[-1]
         return param**2 + final_y**2
-
+
     def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.vmap(single_loss_fn)(params)
-
+
     def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)
     batch = jnp.array([1.0, 2.0, 3.0])
-
+
     try:
         grad = grad_fn(batch)
     except NotImplementedError as e: