From a658fdea8d940f81e599cfbf004b0fb040d9c7b9 Mon Sep 17 00:00:00 2001
From: LuggiStruggi <luggistruggi@gmail.com>
Date: Mon, 20 Jan 2025 10:53:05 +0100
Subject: [PATCH] diffrax root finder

---
 test/test_integrate.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/test/test_integrate.py b/test/test_integrate.py
index 590ee86e..af7e1b8d 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -10,7 +10,6 @@
 import jax.random as jr
 import jax.tree_util as jtu
 import lineax as lx
-import optimistix as optx
 import pytest
 import scipy.stats
 from diffrax import ControlTerm, MultiTerm, ODETerm
@@ -796,20 +795,18 @@ def func(self, terms, t0, y0, args):
 
 
 def test_vmap_backprop():
-    
     def dynamics(t, y, args):
         param = args
         return param - y
-    
+
     def event_fn(t, y, args, **kwargs):
         return y - 1.5
-    
+
     def single_loss_fn(param):
         solver = diffrax.Euler()
-        root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
+        root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
         event = diffrax.Event(event_fn, root_finder)
-        term = diffrax.ODETerm(dynamics) 
-        
+        term = diffrax.ODETerm(dynamics)
         sol = diffrax.diffeqsolve(
             term,
             solver=solver,
@@ -821,18 +818,18 @@ def single_loss_fn(param):
             event=event,
             max_steps=1000,
         )
-    
+        assert sol.ys is not None
         final_y = sol.ys[-1]
         return param**2 + final_y**2
-    
+
     def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.vmap(single_loss_fn)(params)
-    
+
     def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
         return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)
 
     batch = jnp.array([1.0, 2.0, 3.0])
-    
+
     try:
         grad = grad_fn(batch)
     except NotImplementedError as e: