Commit a658fde: diffrax root finder

LuggiStruggi committed Jan 20, 2025 · 1 parent 91d710d
Showing 1 changed file with 8 additions and 11 deletions.
test/test_integrate.py
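In test_vmap_backprop, the Newton root finder from optimistix is replaced by diffrax's built-in VeryChord root finder for localising the terminating event, so the optimistix import is no longer needed. The remaining visible changes appear to be blank-line cleanup.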
@@ -10,7 +10,6 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import lineax as lx
-import optimistix as optx
 import pytest
 import scipy.stats
 from diffrax import ControlTerm, MultiTerm, ODETerm
@@ -796,20 +795,18 @@ def func(self, terms, t0, y0, args):


 def test_vmap_backprop():
-
     def dynamics(t, y, args):
         param = args
         return param - y
 
     def event_fn(t, y, args, **kwargs):
         return y - 1.5
 
     def single_loss_fn(param):
         solver = diffrax.Euler()
-        root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
+        root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
         event = diffrax.Event(event_fn, root_finder)
-        term = diffrax.ODETerm(dynamics)
-
+        term = diffrax.ODETerm(dynamics)
         sol = diffrax.diffeqsolve(
             term,
             solver=solver,
@@ -821,18 +818,18 @@ def single_loss_fn(param):
             event=event,
             max_steps=1000,
         )
-
         assert sol.ys is not None
         final_y = sol.ys[-1]
         return param**2 + final_y**2
 
     def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.vmap(single_loss_fn)(params)
 
     def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)
 
     batch = jnp.array([1.0, 2.0, 3.0])
+
     try:
         grad = grad_fn(batch)
     except NotImplementedError as e:
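For context, below is a self-contained sketch of the pattern this test exercises: backpropagating through an event-terminated solve under jax.vmap, with diffrax.VeryChord localising the event time. The dynamics, event condition, root finder, and max_steps mirror the diff above; the values of t0, t1, dt0, and y0 are not visible in this diff and are assumed here purely for illustration.

import diffrax
import jax
import jax.numpy as jnp

def dynamics(t, y, args):
    # dy/dt = param - y: y relaxes towards param.
    return args - y

def event_fn(t, y, args, **kwargs):
    # Continuous event condition: the solve terminates when y crosses 1.5.
    return y - 1.5

def single_loss_fn(param):
    # VeryChord is the chord-method root finder shipped with diffrax,
    # used here in place of optimistix's Newton to locate the event time.
    root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dynamics),
        solver=diffrax.Euler(),
        t0=0.0,    # assumed; not visible in the diff
        t1=5.0,    # assumed; not visible in the diff
        dt0=0.01,  # assumed; not visible in the diff
        y0=0.0,    # assumed; not visible in the diff
        args=param,
        event=diffrax.Event(event_fn, root_finder),
        max_steps=1000,
    )
    assert sol.ys is not None
    return param**2 + sol.ys[-1] ** 2

# Batch over parameters and differentiate through the solves.
batch = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(lambda p: jnp.sum(jax.vmap(single_loss_fn)(p)))(batch)

The test itself wraps the gradient call in try/except NotImplementedError, guarding the case where this combination of vmap, event root-finding, and backpropagation is not supported.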
