From ae0194220102ae619b81eebdd0bb916a814bcd2f Mon Sep 17 00:00:00 2001 From: Sam McCallum Date: Wed, 19 Feb 2025 10:11:51 +0000 Subject: [PATCH] AbstractReversibleSolver + ReversibleAdjoint add reversible testing testing AbstractReversibleSolver + ReversibleAdjoint allow arbitrary interpolation unpacking over indexing jax while loop collapse saveat ValueErrors remove statonovich solver condition remove unused returns from AbstractReversibleSolver backward_step add test and remove messy benchmark add wrapped solver + tests made_jump=True for both solver steps improve docstrings AbstractSolver and docstring note about SDEs add AbstractReversibleSolver to public API newline in docstrings return RESULTS from reversible backward_step restrict Reversible to AbstractERK and check result in adjoint correct tprev and tnext of solver init switch to linear interpolation and y0,y1 dense_info name UReversible various doc formatting changes AbstractReversibleSolver check add disable_fsal property to AbstractRungeKutta and use in UReversible allow t0 != 0 Handle StepTo controller t0==t1 branch --- diffrax/__init__.py | 3 + diffrax/_adjoint.py | 306 ++++++++++++++++++++++++ diffrax/_integrate.py | 69 +++++- diffrax/_solver/__init__.py | 2 + diffrax/_solver/base.py | 58 ++++- diffrax/_solver/leapfrog_midpoint.py | 43 +++- diffrax/_solver/reversible.py | 186 +++++++++++++++ diffrax/_solver/reversible_heun.py | 39 +++- diffrax/_solver/runge_kutta.py | 2 + diffrax/_solver/semi_implicit_euler.py | 35 ++- test/test_reversible.py | 310 +++++++++++++++++++++++++ 11 files changed, 1040 insertions(+), 13 deletions(-) create mode 100644 diffrax/_solver/reversible.py create mode 100644 test/test_reversible.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d35a7fac..b9d24b76 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -7,6 +7,7 @@ ForwardMode as ForwardMode, ImplicitAdjoint as ImplicitAdjoint, RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint, + ReversibleAdjoint as ReversibleAdjoint, ) from ._autocitation import citation as citation, citation_rules as citation_rules from ._brownian import ( @@ -75,6 +76,7 @@ AbstractFosterLangevinSRK as AbstractFosterLangevinSRK, AbstractImplicitSolver as AbstractImplicitSolver, AbstractItoSolver as AbstractItoSolver, + AbstractReversibleSolver as AbstractReversibleSolver, AbstractRungeKutta as AbstractRungeKutta, AbstractSDIRK as AbstractSDIRK, AbstractSolver as AbstractSolver, @@ -117,6 +119,7 @@ StochasticButcherTableau as StochasticButcherTableau, StratonovichMilstein as StratonovichMilstein, Tsit5 as Tsit5, + UReversible as UReversible, ) from ._step_size_controller import ( AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 7bc081b9..9dfc75c1 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -16,8 +16,10 @@ from ._heuristics import is_sde, is_unsafe_sde from ._saveat import save_y, SaveAt, SubSaveAt +from ._solution import RESULTS from ._solver import ( AbstractItoSolver, + AbstractReversibleSolver, AbstractRungeKutta, AbstractSRK, AbstractStratonovichSolver, @@ -918,3 +920,307 @@ def loop( ForwardMode.__init__.__doc__ = """**Arguments:** None""" + +# Reversible Adjoint custom vjp computes gradients w.r.t. +# - y, corresponding to the initial state; +# - args, corresponding to explicit parameters; +# - terms, corresponding to implicit parameters as part of the vector field. 
+ + +@eqx.filter_custom_vjp +def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs): + del throw + y, args, terms = y__args__terms + init_state = eqx.tree_at(lambda s: s.y, init_state, y) + del y + return self._loop( + args=args, + terms=terms, + max_steps=max_steps, + init_state=init_state, + inner_while_loop=ft.partial(_inner_loop, kind="lax"), + outer_while_loop=ft.partial(_outer_loop, kind="lax"), + **kwargs, + ) + + +@_loop_reversible.def_fwd +def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs): + del perturbed + final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs) + init_ts = final_state.reversible_init_ts + ts = final_state.reversible_ts + ts_final_index = final_state.reversible_save_index + y1 = final_state.y + save_state = final_state.save_state + solver_state = final_state.solver_state + return (final_state, aux_stats), ( + init_ts, + ts, + ts_final_index, + y1, + save_state, + solver_state, + ) + + +@_loop_reversible.def_bwd +def _loop_reversible_bwd( + residuals, + grad_final_state__aux_stats, + perturbed, + y__args__terms, + *, + self, + saveat, + init_state, + solver, + event, + **kwargs, +): + assert event is None + + del perturbed, self, init_state, kwargs + init_ts, ts, ts_final_index, y1, save_state, solver_state = residuals + del residuals + + grad_final_state, _ = grad_final_state__aux_stats + saveat_ts = save_state.ts + ys = save_state.ys + saveat_ts_index = save_state.saveat_ts_index - 1 + grad_ys = grad_final_state.save_state.ys + grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys) + + if saveat.subs.t1: + grad_y1 = (ω(grad_ys)[-1]).ω + else: + grad_y1 = jtu.tree_map(jnp.zeros_like, y1) + + if saveat.subs.t0: + saveat_ts_index = saveat_ts_index + 1 + + del grad_final_state, grad_final_state__aux_stats + + y, args, terms = y__args__terms + del y__args__terms + + diff_state = eqx.filter(solver_state, eqx.is_inexact_array) + diff_args = eqx.filter(args, eqx.is_inexact_array) + diff_terms = eqx.filter(terms, eqx.is_inexact_array) + grad_state = jtu.tree_map(jnp.zeros_like, diff_state) + grad_args = jtu.tree_map(jnp.zeros_like, diff_args) + grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms) + del diff_args, diff_terms + + def grad_step(state): + def forward_step(y0, solver_state, args, terms): + y1, _, dense_info, new_solver_state, result = solver.step( + terms, t0, t1, y0, args, solver_state, False + ) + assert result == RESULTS.successful + return y1, dense_info, new_solver_state + + ( + saveat_ts_index, + ts_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) = state + + t1 = ts[ts_index] + t0 = ts[ts_index - 1] + + # Any ts state required to reverse the forward step + # e.g. 
LeapfrogMidpoint requires tm1 + tm1_index = ts_index - 2 + tm1 = ts[tm1_index] + tm1 = jnp.where(tm1_index >= 0, tm1, t0) + ts_state = (tm1,) + + y0, dense_info, solver_state, result = solver.backward_step( + terms, t0, t1, y1, args, ts_state, solver_state, False + ) + assert result == RESULTS.successful + + # Pull gradients back through interpolation + + def interpolate(t, t0, t1, dense_info): + interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info) + return interpolator.evaluate(t) + + def _cond_fun(inner_state): + saveat_ts_index, _ = inner_state + return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0) + + def _body_fun(inner_state): + saveat_ts_index, grad_dense_info = inner_state + t = saveat_ts[saveat_ts_index] + grad_y = (ω(grad_ys)[saveat_ts_index]).ω + _, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info) + _, _, _, dgrad_dense_info = interp_vjp(grad_y) + grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info) + saveat_ts_index = saveat_ts_index - 1 + return saveat_ts_index, grad_dense_info + + grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info) + inner_state = (saveat_ts_index, grad_dense_info) + inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax") + saveat_ts_index, grad_dense_info = inner_state + + # Pull gradients back through forward step + + _, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms) + grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn( + (grad_y1, grad_dense_info, grad_state) + ) + + grad_args = eqx.apply_updates(grad_args, dgrad_args) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) + + ts_index = ts_index - 1 + + return ( + saveat_ts_index, + ts_index, + y0, + solver_state, + grad_y0, + grad_state, + grad_args, + grad_terms, + ) + + def cond_fun(state): + ts_index = state[1] + return ts_index > 0 + + state = ( + saveat_ts_index, + ts_final_index, + y1, + solver_state, + grad_y1, + grad_state, + grad_args, + grad_terms, + ) + + state = jax.lax.while_loop(cond_fun, grad_step, state) + _, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state + + # Pull solver_state gradients back onto y0, args, terms. + + init_t0, init_t1 = init_ts + _, init_vjp = eqx.filter_vjp(solver.init, terms, init_t0, init_t1, y0, args) + dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state) + grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0) + grad_terms = eqx.apply_updates(grad_terms, dgrad_terms) + grad_args = eqx.apply_updates(grad_args, dgrad_args) + + return grad_y0, grad_args, grad_terms + + +class ReversibleAdjoint(AbstractAdjoint): + """Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver + [`diffrax.AbstractReversibleSolver`][]. + + Gradient calculation is exact (up to floating point errors) and backpropagation + becomes linear in time $O(n)$ and constant in memory $O(1)$, for $n$ time steps. + + !!! note + + This adjoint can be less numerically stable than + [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][]. + Stability can be largely improved by using [double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) + and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/). + + ??? 
cite "References" + + For an introduction to reversible backpropagation, see these references: + + ```bibtex + @article{mccallum2024efficient, + title={Efficient, Accurate and Stable Gradients for Neural ODEs}, + author={McCallum, Sam and Foster, James}, + journal={arXiv preprint arXiv:2410.11648}, + year={2024} + } + + @phdthesis{kidger2021on, + title={{O}n {N}eural {D}ifferential {E}quations}, + author={Patrick Kidger}, + year={2021}, + school={University of Oxford}, + } + ``` + """ + + def loop( + self, + *, + args, + terms, + solver, + saveat, + max_steps, + init_state, + passed_solver_state, + passed_controller_state, + event, + **kwargs, + ): + if not isinstance(solver, AbstractReversibleSolver): + raise ValueError( + "`ReversibleAdjoint` can only be used with an " + "`AbstractReversibleSolver`" + ) + if max_steps is None: + raise ValueError( + "`max_steps=None` is incompatible with `ReversibleAdjoint`." + ) + + if ( + jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) + != jtu.tree_structure(0) + or saveat.dense + or saveat.subs.steps + or (saveat.subs.fn is not save_y) + ): + raise ValueError( + "`ReversibleAdjoint` is only compatible with the following `SaveAt` " + "properties: `t0`, `t1`, `ts`, `fn=save_y` (default)." + ) + + if event is not None: + raise NotImplementedError( + "`ReversibleAdjoint` is not compatible with events." + ) + + if is_unsafe_sde(terms): + raise ValueError( + "`ReversibleAdjoint` does not support `UnsafeBrownianPath`. " + "Consider using `VirtualBrownianTree` instead." + ) + + y = init_state.y + init_state = eqx.tree_at(lambda s: s.y, init_state, object()) + init_state = _nondiff_solver_controller_state( + self, init_state, passed_solver_state, passed_controller_state + ) + + final_state, aux_stats = _loop_reversible( + (y, args, terms), + self=self, + saveat=saveat, + max_steps=max_steps, + init_state=init_state, + solver=solver, + event=event, + **kwargs, + ) + final_state = _only_transpose_ys(final_state) + return final_state, aux_stats diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6fc38ce3..fbcbc184 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -6,6 +6,7 @@ cast, get_args, get_origin, + Optional, Tuple, ) @@ -108,6 +109,12 @@ class State(eqx.Module): event_dense_info: DenseInfo | None event_values: PyTree[BoolScalarLike | RealScalarLike] | None event_mask: PyTree[BoolScalarLike] | None + # + # Information for reversible adjoint (save ts) + # + reversible_init_ts: Optional[PyTree[FloatScalarLike]] + reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] + reversible_save_index: Optional[IntScalarLike] def _is_none(x: Any) -> bool: @@ -220,7 +227,7 @@ def _outer_buffers(state): return ( [s.ts for s in save_states] + [s.ys for s in save_states] - + [state.dense_ts, state.dense_infos] + + [state.dense_ts, state.dense_infos, state.reversible_ts] ) @@ -309,6 +316,11 @@ def loop( dense_ts = dense_ts.at[0].set(t0) init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) + if init_state.reversible_ts is not None: + reversible_ts = init_state.reversible_ts + reversible_ts = reversible_ts.at[0].set(t0) + init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts) + def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.t0: save_state = _save( @@ -621,6 +633,16 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i, direction_i): result, ) + reversible_init_ts = state.reversible_init_ts + reversible_ts = state.reversible_ts + reversible_save_index 
= state.reversible_save_index + + if state.reversible_ts is not None: + reversible_ts = eqxi.buffer_at_set( + reversible_ts, reversible_save_index + 1, tprev, pred=keep_step + ) + reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0) + new_state = State( y=y, tprev=tprev, @@ -642,6 +664,9 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i, direction_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_init_ts=reversible_init_ts, + reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType] + reversible_save_index=reversible_save_index, ) return ( @@ -864,6 +889,35 @@ def _save_t1(subsaveat, save_state): final_state = eqx.tree_at( lambda s: s.save_state, final_state, save_state, is_leaf=_is_none ) + + # if t0 == t1 and we are using diffrax.ReversibleAdjoint then we need to update the + # reversible_ts and reversible_ts_index to get correct gradients + def _reversible_info_if_t0_equals_t1(reversible_ts, reversible_save_index): + reversible_ts = eqxi.buffer_at_set(final_state.reversible_ts, 1, t0) + reversible_save_index += 1 + return reversible_ts, reversible_save_index + + reversible_ts, reversible_save_index = jax.lax.cond( + eqxi.unvmap_any(t0 == t1), + lambda __ts, __index: jax.lax.cond( + t0 == t1, + lambda _ts, _index: _reversible_info_if_t0_equals_t1(_ts, _index), + lambda _ts, _index: (_ts, _index), + __ts, + __index, + ), + lambda __ts, __index: (__ts, __index), + final_state.reversible_ts, + final_state.reversible_save_index, + ) + + final_state = eqx.tree_at( + lambda s: (s.reversible_ts, s.reversible_save_index), + final_state, + (reversible_ts, reversible_save_index), + is_leaf=_is_none, + ) + final_state = _handle_static(final_state) result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result) aux_stats = dict() # TODO: put something in here? 
@@ -1399,6 +1453,16 @@ def _outer_cond_fn(cond_fn_i): ) del had_event, event_structure, event_mask_leaves, event_values__mask + # Reversible info + if max_steps is None: + reversible_init_ts = None + reversible_ts = None + reversible_save_index = None + else: + reversible_init_ts = (tprev, tnext) + reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype) + reversible_save_index = 0 + # Initialise state init_state = State( y=y0, @@ -1421,6 +1485,9 @@ def _outer_cond_fn(cond_fn_i): event_dense_info=event_dense_info, event_values=event_values, event_mask=event_mask, + reversible_init_ts=reversible_init_ts, + reversible_ts=reversible_ts, + reversible_save_index=reversible_save_index, ) # diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py index 0a840413..d3048088 100644 --- a/diffrax/_solver/__init__.py +++ b/diffrax/_solver/__init__.py @@ -3,6 +3,7 @@ AbstractAdaptiveSolver as AbstractAdaptiveSolver, AbstractImplicitSolver as AbstractImplicitSolver, AbstractItoSolver as AbstractItoSolver, + AbstractReversibleSolver as AbstractReversibleSolver, AbstractSolver as AbstractSolver, AbstractStratonovichSolver as AbstractStratonovichSolver, AbstractWrappedSolver as AbstractWrappedSolver, @@ -30,6 +31,7 @@ ) from .quicsort import QUICSORT as QUICSORT from .ralston import Ralston as Ralston +from .reversible import UReversible as UReversible from .reversible_heun import ReversibleHeun as ReversibleHeun from .runge_kutta import ( AbstractDIRK as AbstractDIRK, diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index d4a476e3..7699ebf6 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -22,7 +22,7 @@ else: from equinox import AbstractClassVar, AbstractVar from equinox.internal import ω -from jaxtyping import PyTree +from jaxtyping import Array, PyTree from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y from .._heuristics import is_sde @@ -350,3 +350,59 @@ def func( - `solver`: The solver to wrap. """ + + +class AbstractReversibleSolver(AbstractSolver[_SolverState]): + """Indicates that this is a reversible differential equation solver. This means + that the state at `t0` can be reconstructed (in closed form) from the state at `t1`. + + The reconstruction must be implemented by + [`diffrax.AbstractReversibleSolver.backward_step`][]. + + This solver can be combined with `adjoint=diffrax.ReversibleAdjoint` for exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory, for $n$ time steps. + """ + + @abc.abstractmethod + def backward_step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + ts_state: PyTree[RealScalarLike], + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState, RESULTS]: + """ + Make a single backward step with the reversible solver. + + Each step is made over the specified interval $[t_1, t_0]$. + + **Arguments:** + + - `terms`: The PyTree of terms representing the vector fields and controls. + - `t0`: The end of the interval that the backward step is made over. + - `t1`: The start of the interval that the backward step is made over. + - `y1`: The current value of the solution at `t1`. + - `args`: Any extra arguments passed to the vector field. + - `ts_state`: Any `ts` state required to reverse the forward `step`. + - `solver_state`: Any evolving state for the solver itself, at `t1`. + - `made_jump`: Whether there was a discontinuity in the vector field at `t1`. 
+ Some solvers (notably FSAL Runge--Kutta solvers) usually assume that there + are no jumps and for efficiency re-use information between steps; this + indicates that a jump has just occurred and this assumption is not true. + + **Returns:** + + A tuple of four objects: + + - The value of the solution at `t0`. + - Some dictionary of information that is passed to the solver's interpolation + routine to calculate dense output. Note that this is assumed to be the same + information returned on the forward step. + - The value of the solver state at `t0`. + - An integer (corresponding to `diffrax.RESULTS`) indicating whether the step + happened successfully, or if (unusually) it failed for some reason. + """ diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index ddcaa12e..a5928075 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -1,6 +1,9 @@ from collections.abc import Callable from typing import ClassVar, TypeAlias +import jax +import jax.numpy as jnp +import jax.tree_util as jtu from equinox.internal import ω from jaxtyping import PyTree @@ -8,15 +11,15 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None -_SolverState: TypeAlias = tuple[RealScalarLike, PyTree] +_SolverState: TypeAlias = tuple[RealScalarLike, PyTree, RealScalarLike] # TODO: support arbitrary linear multistep methods -class LeapfrogMidpoint(AbstractSolver): +class LeapfrogMidpoint(AbstractReversibleSolver): r"""Leapfrog/midpoint method. 2nd order linear multistep method. Uses 1st order local linear interpolation for @@ -28,6 +31,13 @@ class LeapfrogMidpoint(AbstractSolver): (which is usually taken to refer to the explicit Runge--Kutta method [`diffrax.Midpoint`][]). + !!! note + + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? cite "Reference" ```bibtex @@ -59,9 +69,8 @@ def init( y0: Y, args: Args, ) -> _SolverState: - del terms, t1, args # Corresponds to making an explicit Euler step on the first step. 
- return t0, y0 + return t0, y0, t0 def step( self, @@ -74,13 +83,33 @@ def step( made_jump: BoolScalarLike, ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del made_jump - tm1, ym1 = solver_state + tm1, ym1, init_t0 = solver_state control = terms.contr(tm1, t1) y1 = (ym1**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω dense_info = dict(y0=y0, y1=y1) - solver_state = (t0, y0) + solver_state = (t0, y0, init_t0) return y1, None, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + ts_state: PyTree[RealScalarLike], + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState, RESULTS]: + del made_jump + t0, y0, init_t0 = solver_state + (tm1,) = ts_state + control = terms.contr(tm1, t1) + ym1 = (y1**ω - terms.vf_prod(t0, y0, args, control) ** ω).ω + dense_info = dict(y0=y0, y1=y1) + solver_state = (tm1, ym1, init_t0) + return y0, dense_info, solver_state, RESULTS.successful + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/reversible.py b/diffrax/_solver/reversible.py new file mode 100644 index 00000000..bc776ffc --- /dev/null +++ b/diffrax/_solver/reversible.py @@ -0,0 +1,186 @@ +from collections.abc import Callable +from typing import cast, ClassVar, Optional + +import equinox as eqx +from equinox.internal import ω +from jaxtyping import PyTree + +from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y +from .._local_interpolation import LocalLinearInterpolation +from .._solution import RESULTS, update_result +from .._solver.base import ( + AbstractReversibleSolver, + AbstractWrappedSolver, +) +from .._term import AbstractTerm +from .runge_kutta import AbstractERK + + +ω = cast(Callable, ω) +_SolverState = Y + + +class UReversible( + AbstractReversibleSolver[_SolverState], AbstractWrappedSolver[_SolverState] +): + """ + U-Reversible solver method. + + Allows any explicit Runge-Kutta solver ([`diffrax.AbstractERK`][]) to be made + algebraically reversible. + + **Arguments:** + + - `solver`: base solver to be made reversible + - `coupling_parameter`: determines coupling between the two evolving solutions. + Must be within the range `0 < coupling_parameter < 1`. Unless you need finer control + over stability, the default value of `0.999` should be sufficient. + + !!! note + + When solving SDEs, the base `solver` must converge to the Statonovich solution. + + ??? 
cite "References" + + This method was developed in: + + ```bibtex + @article{mccallum2024efficient, + title={Efficient, Accurate and Stable Gradients for Neural ODEs}, + author={McCallum, Sam and Foster, James}, + journal={arXiv preprint arXiv:2410.11648}, + year={2024} + } + ``` + + And built on previous work by: + + ```bibtex + @article{kidger2021efficient, + title={Efficient and accurate gradients for neural sdes}, + author={Kidger, Patrick and Foster, James and Li, Xuechen Chen and Lyons, + Terry}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={18747--18761}, + year={2021} + } + + @article{zhuang2021mali, + title={Mali: A memory efficient and reverse accurate integrator for neural + odes}, + author={Zhuang, Juntang and Dvornek, Nicha C and Tatikonda, Sekhar and + Duncan, James S}, + journal={arXiv preprint arXiv:2102.04668}, + year={2021} + } + ``` + """ + + solver: AbstractERK + coupling_parameter: float + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) + + @property + def term_structure(self): + return self.solver.term_structure + + @property + def term_compatible_contr_kwargs(self): + return self.solver.term_compatible_contr_kwargs + + @property + def root_finder(self): + return self.solver.root_finder + + @property + def root_find_max_steps(self): + return self.solver.root_find_max_steps + + def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: + return self.solver.order(terms) + + def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]: + return self.solver.strong_order(terms) + + def __init__(self, solver: AbstractERK, coupling_parameter: float = 0.999): + self.solver = eqx.tree_at(lambda s: s.disable_fsal, solver, True) + self.coupling_parameter = coupling_parameter + + def init( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + ) -> _SolverState: + if not isinstance(self.solver, AbstractERK): + raise ValueError( + "`UReversible` is only compatible with `AbstractERK` base solvers." 
+ ) + return y0 + + def step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: Args, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]: + del made_jump + z0 = solver_state + + step_z0, _, _, _, result1 = self.solver.step( + terms, t0, t1, z0, args, None, True + ) + y1 = (self.coupling_parameter * (ω(y0) - ω(z0)) + ω(step_z0)).ω + + step_y1, y_error, _, _, result2 = self.solver.step( + terms, t1, t0, y1, args, None, True + ) + z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω + + solver_state = z1 + dense_info = dict(y0=y0, y1=y1) + result = update_result(result1, result2) + + return y1, y_error, dense_info, solver_state, result + + def backward_step( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + ts_state: PyTree[RealScalarLike], + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState, RESULTS]: + del made_jump, ts_state + z1 = solver_state + step_y1, _, _, _, result1 = self.solver.step( + terms, t1, t0, y1, args, None, True + ) + z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω + step_z0, _, _, _, result2 = self.solver.step( + terms, t0, t1, z0, args, None, True + ) + y0 = ((1 / self.coupling_parameter) * (ω(y1) - ω(step_z0)) + ω(z0)).ω + + solver_state = z0 + dense_info = dict(y0=y0, y1=y1) + result = update_result(result1, result2) + + return y0, dense_info, solver_state, result + + def func( + self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args + ) -> VF: + return self.solver.func(terms, t0, y0, args) diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 91617d4f..de48f9a9 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -9,13 +9,19 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver +from .base import ( + AbstractAdaptiveSolver, + AbstractReversibleSolver, + AbstractStratonovichSolver, +) _SolverState: TypeAlias = tuple[PyTree, PyTree] -class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): +class ReversibleHeun( + AbstractReversibleSolver, AbstractAdaptiveSolver, AbstractStratonovichSolver +): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for @@ -23,6 +29,12 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): When used to solve SDEs, converges to the Stratonovich solution. + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. + ??? 
cite "Reference" ```bibtex @@ -82,6 +94,29 @@ def step( solver_state = (yhat1, vf1) return y1, y1_error, dense_info, solver_state, RESULTS.successful + def backward_step( + self, + terms: AbstractTerm, + t0: RealScalarLike, + t1: RealScalarLike, + y1: Y, + args: Args, + ts_state: PyTree[RealScalarLike], + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, DenseInfo, _SolverState, RESULTS]: + del made_jump, ts_state + yhat1, vf1 = solver_state + + control = terms.contr(t0, t1) + yhat0 = (2 * y1**ω - yhat1**ω - terms.prod(vf1, control) ** ω).ω + vf0 = terms.vf(t0, yhat0, args) + y0 = (y1**ω - 0.5 * terms.prod((vf0**ω + vf1**ω).ω, control) ** ω).ω + + dense_info = dict(y0=y0, y1=y1) + solver_state = (yhat0, vf0) + return y0, dense_info, solver_state, RESULTS.successful + def func(self, terms: AbstractTerm, t0: RealScalarLike, y0: Y, args: Args) -> VF: return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index 9473ab44..b95f1539 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -354,6 +354,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): """ scan_kind: None | Literal["lax", "checkpointed", "bounded"] = None + disable_fsal: bool = False tableau: AbstractClassVar[ButcherTableau | MultiButcherTableau] calculate_jacobian: AbstractClassVar[CalculateJacobian] @@ -401,6 +402,7 @@ def _common(self, terms, t0, t1, y0, args): # FSAL implies evaluating just the vector field, since we need to contract # the same vector field evaluation against two different controls. fsal = fsal and not vf_expensive + fsal = fsal and not self.disable_fsal return vf_expensive, fsal def func( diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 34122fe3..6313fec9 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -8,7 +8,7 @@ from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS from .._term import AbstractTerm -from .base import AbstractSolver +from .base import AbstractReversibleSolver _ErrorEstimate: TypeAlias = None @@ -18,11 +18,17 @@ Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] # pyright: ignore[reportUndefinedVariable] -class SemiImplicitEuler(AbstractSolver): +class SemiImplicitEuler(AbstractReversibleSolver): """Semi-implicit Euler's method. Symplectic method. Does not support adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output. + + !!! note + This solver is algebraically reversible, meaning that the state at `t0` can be + reconstructed (in closed form) from the state at `t1`. This allows exact + gradient backpropagation in $O(n)$ time and $O(1)$ memory when using + [`diffrax.ReversibleAdjoint`][]. 
""" term_structure: ClassVar = (AbstractTerm, AbstractTerm) @@ -67,6 +73,31 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + def backward_step( + self, + terms: tuple[AbstractTerm, AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y1: tuple[Ya, Yb], + args: Args, + ts_state: PyTree[RealScalarLike], + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[tuple[Ya, Yb], DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump, ts_state + + term_1, term_2 = terms + y1_1, y1_2 = y1 + + control1 = term_1.contr(t0, t1) + control2 = term_2.contr(t0, t1) + y0_2 = (y1_2**ω - term_2.vf_prod(t0, y1_1, args, control2) ** ω).ω + y0_1 = (y1_1**ω - term_1.vf_prod(t0, y0_2, args, control1) ** ω).ω + + y0 = (y0_1, y0_2) + dense_info = dict(y0=y0, y1=y1) + return y0, dense_info, None, RESULTS.successful + def func( self, terms: tuple[AbstractTerm, AbstractTerm], diff --git a/test/test_reversible.py b/test/test_reversible.py new file mode 100644 index 00000000..e88c9b24 --- /dev/null +++ b/test/test_reversible.py @@ -0,0 +1,310 @@ +from typing import cast + +import diffrax +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jr +import pytest +from jaxtyping import Array + +from .helpers import tree_allclose + + +jax.config.update("jax_enable_x64", True) + + +class VectorField(eqx.Module): + mlp: eqx.nn.MLP + + def __init__(self, in_size, out_size, width_size, depth, key): + self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key) + + def __call__(self, t, y, args): + return args * self.mlp(y) + + +@eqx.filter_value_and_grad +def _loss( + y0__args__term, + solver, + saveat, + adjoint, + stepsize_controller, + dual_y0, + t0_equals_t1, +): + y0, args, term = y0__args__term + + if isinstance(stepsize_controller, diffrax.StepTo): + dt0 = None + else: + dt0 = 0.01 + + if t0_equals_t1: + t1 = 0 + else: + t1 = 5 + + sol = diffrax.diffeqsolve( + term, + solver, + t0=0, + t1=t1, + dt0=dt0, + y0=y0, + args=args, + saveat=saveat, + max_steps=4096, + adjoint=adjoint, + stepsize_controller=stepsize_controller, + ) + if dual_y0: + y1 = sol.ys[0] # pyright: ignore + else: + y1 = sol.ys + return jnp.sum(cast(Array, y1)) + + +def _compare_grads( + y0__args__term, + base_solver, + solver, + saveat, + stepsize_controller, + dual_y0=False, + t0_equals_t1=False, +): + loss, grads_base = _loss( + y0__args__term, + base_solver, + saveat, + adjoint=diffrax.RecursiveCheckpointAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + t0_equals_t1=t0_equals_t1, + ) + loss, grads_reversible = _loss( + y0__args__term, + solver, + saveat, + adjoint=diffrax.ReversibleAdjoint(), + stepsize_controller=stepsize_controller, + dual_y0=dual_y0, + t0_equals_t1=t0_equals_t1, + ) + assert tree_allclose(grads_base, grads_reversible, atol=1e-5) + + +@pytest.mark.parametrize( + "stepsize_controller", + [ + diffrax.StepTo(jnp.linspace(0, 5, 50)), + diffrax.ConstantStepSize(), + ], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_semi_implicit_euler(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, gkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = VectorField(n, n, n, depth=4, key=gkey) + terms = (diffrax.ODETerm(f), diffrax.ODETerm(g)) + y0 = (y0, y0) + args = jnp.array([0.5]) + solver = diffrax.SemiImplicitEuler() + + 
_compare_grads( + (y0, args, terms), solver, solver, saveat, stepsize_controller, dual_y0=True + ) + + +@pytest.mark.parametrize( + "stepsize_controller", + [ + diffrax.StepTo(jnp.linspace(0, 5, 50)), + diffrax.ConstantStepSize(), + diffrax.PIDController(rtol=1e-8, atol=1e-8), + ], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_heun_ode(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + args = jnp.array([0.5]) + solver = diffrax.ReversibleHeun() + + _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller) + + +@pytest.mark.parametrize( + "stepsize_controller", + [ + diffrax.StepTo(jnp.linspace(0, 5, 50)), + diffrax.ConstantStepSize(), + ], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_heun_sde(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + fkey, Wkey = jr.split(key, 2) + f = VectorField(n, n, n, depth=4, key=fkey) + g = lambda t, y, args: jnp.ones((n,)) + W = diffrax.VirtualBrownianTree(t0=0, t1=5, tol=1e-3, shape=(n,), key=Wkey) + terms = diffrax.MultiTerm(diffrax.ODETerm(f), diffrax.ControlTerm(g, W)) + args = jnp.array([0.5]) + solver = diffrax.ReversibleHeun() + + _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller) + + +@pytest.mark.parametrize( + "stepsize_controller", + [ + diffrax.StepTo(jnp.linspace(0, 5, 50)), + diffrax.ConstantStepSize(), + ], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_leapfrog_midpoint(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + args = jnp.array([0.5]) + solver = diffrax.LeapfrogMidpoint() + + _compare_grads((y0, args, terms), solver, solver, saveat, stepsize_controller) + + +@pytest.mark.parametrize( + "stepsize_controller", + [ + diffrax.StepTo(jnp.linspace(0, 5, 50)), + diffrax.ConstantStepSize(), + diffrax.PIDController(rtol=1e-8, atol=1e-8), + ], +) +@pytest.mark.parametrize( + "saveat", + [ + diffrax.SaveAt(t0=True, t1=True), + diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True), + ], +) +def test_reversible_explicit(stepsize_controller, saveat): + n = 10 + y0 = jnp.linspace(1, 10, num=n) + key = jr.PRNGKey(10) + f = VectorField(n, n, n, depth=4, key=key) + terms = diffrax.ODETerm(f) + args = jnp.array([0.5]) + base_solver = diffrax.Tsit5() + solver = diffrax.UReversible(base_solver) + + # If we're using SaveAt(ts=...) then we can only compare the grads from: + # Reversible solver + ReversibleAdjoint, and + # Reversible solver + RecursiveCheckpointAdjoint. + # as the interpolation scheme is different for Tsit5() and Reversible(). 
+    if saveat.subs.ts is not None:
+        base_solver = solver
+
+    _compare_grads((y0, args, terms), base_solver, solver, saveat, stepsize_controller)
+
+
+@pytest.mark.parametrize(
+    "stepsize_controller",
+    [
+        diffrax.StepTo(jnp.linspace(0, 5, 50)),
+        diffrax.ConstantStepSize(),
+    ],
+)
+@pytest.mark.parametrize(
+    "saveat",
+    [
+        diffrax.SaveAt(t0=True, t1=True),
+        diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True),
+    ],
+)
+def test_reversible_sde(stepsize_controller, saveat):
+    n = 10
+    y0 = jnp.linspace(1, 10, num=n)
+    key = jr.PRNGKey(10)
+    fkey, Wkey = jr.split(key, 2)
+    f = VectorField(n, n, n, depth=4, key=fkey)
+    g = lambda t, y, args: jnp.ones((n, n))
+    W = diffrax.VirtualBrownianTree(t0=0, t1=5, tol=1e-3, shape=(n,), key=Wkey)
+    terms = diffrax.MultiTerm(diffrax.ODETerm(f), diffrax.ControlTerm(g, W))
+    args = jnp.array([0.5])
+    base_solver = diffrax.Heun()
+    solver = diffrax.UReversible(base_solver)
+
+    # If we're using SaveAt(ts=...) then we can only compare the grads from:
+    # Reversible solver + ReversibleAdjoint, and
+    # Reversible solver + RecursiveCheckpointAdjoint,
+    # as the interpolation scheme is different for Heun() and UReversible().
+    if saveat.subs.ts is not None:
+        base_solver = solver
+
+    _compare_grads((y0, args, terms), base_solver, solver, saveat, stepsize_controller)
+
+
+@pytest.mark.parametrize(
+    "saveat",
+    [
+        diffrax.SaveAt(t0=True),
+        diffrax.SaveAt(t1=True),
+        diffrax.SaveAt(t0=True, t1=True),
+    ],
+)
+def test_reversible_t0_equals_t1(saveat):
+    n = 10
+    y0 = jnp.linspace(1, 10, num=n)
+    key = jr.PRNGKey(10)
+    f = VectorField(n, n, n, depth=4, key=key)
+    terms = diffrax.ODETerm(f)
+    args = jnp.array([0.5])
+    base_solver = diffrax.Tsit5()
+    solver = diffrax.UReversible(base_solver)
+    stepsize_controller = diffrax.ConstantStepSize()
+
+    _compare_grads(
+        (y0, args, terms),
+        base_solver,
+        solver,
+        saveat,
+        stepsize_controller,
+        t0_equals_t1=True,
+    )
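
Usage illustration (not part of the diff itself): a minimal sketch of how the pieces added by this patch fit together, wrapping an explicit Runge--Kutta solver in `UReversible` and differentiating through the solve with `ReversibleAdjoint`. The vector field, loss function, and numbers below are hypothetical stand-ins; the `diffrax` calls mirror those exercised in `test/test_reversible.py`.

```python
# Minimal usage sketch (assumes this patch is applied). The vector field and
# loss here are hypothetical; the diffrax API calls follow test/test_reversible.py.
import diffrax
import equinox as eqx
import jax.numpy as jnp


def vector_field(t, y, args):
    # Hypothetical linear ODE: dy/dt = -args * y.
    return -args * y


@eqx.filter_value_and_grad
def loss(y0, args):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        # Any AbstractERK solver can be wrapped to make it algebraically reversible.
        solver=diffrax.UReversible(diffrax.Tsit5()),
        t0=0.0,
        t1=5.0,
        dt0=0.01,
        y0=y0,
        args=args,
        saveat=diffrax.SaveAt(t1=True),
        # Exact gradients in O(n) time and O(1) memory; requires a finite
        # max_steps and a reversible solver, as enforced in ReversibleAdjoint.loop.
        adjoint=diffrax.ReversibleAdjoint(),
        max_steps=4096,
    )
    return jnp.sum(sol.ys)


value, grad_y0 = loss(jnp.ones(3), jnp.array(0.5))
```

Gradients with respect to `args` and the terms themselves also flow through the custom vjp defined in `_loop_reversible`; the tests obtain them by differentiating with respect to a `(y0, args, terms)` tuple rather than `y0` alone.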