3 changes: 3 additions & 0 deletions diffrax/__init__.py
@@ -7,6 +7,7 @@
ForwardMode as ForwardMode,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
ReversibleAdjoint as ReversibleAdjoint,
)
from ._autocitation import citation as citation, citation_rules as citation_rules
from ._brownian import (
@@ -75,6 +76,7 @@
AbstractFosterLangevinSRK as AbstractFosterLangevinSRK,
AbstractImplicitSolver as AbstractImplicitSolver,
AbstractItoSolver as AbstractItoSolver,
AbstractReversibleSolver as AbstractReversibleSolver,
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
@@ -117,6 +119,7 @@
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
UReversible as UReversible,
)
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
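The new exports are the user-facing pieces of this PR: `AbstractReversibleSolver` (the solver base class), `UReversible` (a concrete reversible solver), and `ReversibleAdjoint` (the matching adjoint). A minimal usage sketch, assuming `UReversible` can be constructed with no arguments (its actual signature may differ):

```python
import diffrax as dfx

term = dfx.ODETerm(lambda t, y, args: -y)
sol = dfx.diffeqsolve(
    term,
    dfx.UReversible(),  # assumed no-arg constructor, for illustration
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=1.0,
    adjoint=dfx.ReversibleAdjoint(),
)
```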
306 changes: 306 additions & 0 deletions diffrax/_adjoint.py
@@ -16,8 +16,10 @@

from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solution import RESULTS
from ._solver import (
AbstractItoSolver,
AbstractReversibleSolver,
AbstractRungeKutta,
AbstractSRK,
AbstractStratonovichSolver,
@@ -918,3 +920,307 @@ def loop(


ForwardMode.__init__.__doc__ = """**Arguments:** None"""

# The ReversibleAdjoint custom vjp computes gradients w.r.t.
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
# - terms, corresponding to implicit parameters as part of the vector field.
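#
# Mechanism: on the backward pass, each forward step y0 -> y1 is first inverted
# exactly via `solver.backward_step`, recovering y0 from y1 without storing
# checkpoints. The forward step is then differentiated with `eqx.filter_vjp`,
# pulling the cotangent back as grad_y0 = (d step / d y0)^T grad_y1 and
# accumulating contributions into grad_args and grad_terms along the way.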


@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, max_steps, init_state, **kwargs):
del throw
y, args, terms = y__args__terms
init_state = eqx.tree_at(lambda s: s.y, init_state, y)
del y
return self._loop(
args=args,
terms=terms,
max_steps=max_steps,
init_state=init_state,
inner_while_loop=ft.partial(_inner_loop, kind="lax"),
outer_while_loop=ft.partial(_outer_loop, kind="lax"),
**kwargs,
)


@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
del perturbed
final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
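    # Residuals saved for the backward pass: the initial step interval, the
    # accepted step times, the final solution and saved outputs, and the
    # solver state needed to replay each step in reverse.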
init_ts = final_state.reversible_init_ts
ts = final_state.reversible_ts
ts_final_index = final_state.reversible_save_index
y1 = final_state.y
save_state = final_state.save_state
solver_state = final_state.solver_state
return (final_state, aux_stats), (
init_ts,
ts,
ts_final_index,
y1,
save_state,
solver_state,
)


@_loop_reversible.def_bwd
def _loop_reversible_bwd(
residuals,
grad_final_state__aux_stats,
perturbed,
y__args__terms,
*,
self,
saveat,
init_state,
solver,
event,
**kwargs,
):
assert event is None

del perturbed, self, init_state, kwargs
init_ts, ts, ts_final_index, y1, save_state, solver_state = residuals
del residuals

grad_final_state, _ = grad_final_state__aux_stats
saveat_ts = save_state.ts
ys = save_state.ys
saveat_ts_index = save_state.saveat_ts_index - 1
grad_ys = grad_final_state.save_state.ys
grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)

if saveat.subs.t1:
grad_y1 = (ω(grad_ys)[-1]).ω
else:
grad_y1 = jtu.tree_map(jnp.zeros_like, y1)

if saveat.subs.t0:
saveat_ts_index = saveat_ts_index + 1

del grad_final_state, grad_final_state__aux_stats

y, args, terms = y__args__terms
del y__args__terms

diff_state = eqx.filter(solver_state, eqx.is_inexact_array)
diff_args = eqx.filter(args, eqx.is_inexact_array)
diff_terms = eqx.filter(terms, eqx.is_inexact_array)
grad_state = jtu.tree_map(jnp.zeros_like, diff_state)
grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
del diff_args, diff_terms

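    # Loop invariant: `state` carries the index of the next saved output to
    # process, the index of the current step endpoint in `ts`, the solution
    # `y1` at that endpoint, the solver state, and the running gradient
    # accumulators (grad_y1, grad_state, grad_args, grad_terms).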
def grad_step(state):
def forward_step(y0, solver_state, args, terms):
y1, _, dense_info, new_solver_state, result = solver.step(
terms, t0, t1, y0, args, solver_state, False
)
assert result == RESULTS.successful
return y1, dense_info, new_solver_state

(
saveat_ts_index,
ts_index,
y1,
solver_state,
grad_y1,
grad_state,
grad_args,
grad_terms,
) = state

t1 = ts[ts_index]
t0 = ts[ts_index - 1]

        # Any ts state required to reverse the forward step,
        # e.g. LeapfrogMidpoint requires tm1. (Clamped to t0 on the first
        # step, where no earlier time exists.)
tm1_index = ts_index - 2
tm1 = ts[tm1_index]
tm1 = jnp.where(tm1_index >= 0, tm1, t0)
ts_state = (tm1,)

y0, dense_info, solver_state, result = solver.backward_step(
terms, t0, t1, y1, args, ts_state, solver_state, False
)
assert result == RESULTS.successful

# Pull gradients back through interpolation

def interpolate(t, t0, t1, dense_info):
interpolator = solver.interpolation_cls(t0=t0, t1=t1, **dense_info)
return interpolator.evaluate(t)

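        # Walk backwards over the saved output times that fall within this
        # step, converting each saved output's cotangent into a cotangent on
        # `dense_info` by differentiating through the interpolation.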
def _cond_fun(inner_state):
saveat_ts_index, _ = inner_state
return (saveat_ts[saveat_ts_index] >= t0) & (saveat_ts_index >= 0)

def _body_fun(inner_state):
saveat_ts_index, grad_dense_info = inner_state
t = saveat_ts[saveat_ts_index]
grad_y = (ω(grad_ys)[saveat_ts_index]).ω
_, interp_vjp = eqx.filter_vjp(interpolate, t, t0, t1, dense_info)
_, _, _, dgrad_dense_info = interp_vjp(grad_y)
grad_dense_info = eqx.apply_updates(grad_dense_info, dgrad_dense_info)
saveat_ts_index = saveat_ts_index - 1
return saveat_ts_index, grad_dense_info

grad_dense_info = jtu.tree_map(jnp.zeros_like, dense_info)
inner_state = (saveat_ts_index, grad_dense_info)
inner_state = eqxi.while_loop(_cond_fun, _body_fun, inner_state, kind="lax")
saveat_ts_index, grad_dense_info = inner_state

        # Pull gradients back through the forward step: grad_y1, the dense-info
        # cotangent and grad_state are pulled back to grad_y0 and a new
        # grad_state, accumulating contributions into grad_args and grad_terms.

_, vjp_fn = eqx.filter_vjp(forward_step, y0, solver_state, args, terms)
grad_y0, grad_state, dgrad_args, dgrad_terms = vjp_fn(
(grad_y1, grad_dense_info, grad_state)
)

grad_args = eqx.apply_updates(grad_args, dgrad_args)
grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)

ts_index = ts_index - 1

return (
saveat_ts_index,
ts_index,
y0,
solver_state,
grad_y0,
grad_state,
grad_args,
grad_terms,
)

def cond_fun(state):
ts_index = state[1]
return ts_index > 0

state = (
saveat_ts_index,
ts_final_index,
y1,
solver_state,
grad_y1,
grad_state,
grad_args,
grad_terms,
)

state = jax.lax.while_loop(cond_fun, grad_step, state)
_, _, y0, _, grad_y0, grad_state, grad_args, grad_terms = state

# Pull solver_state gradients back onto y0, args, terms.

init_t0, init_t1 = init_ts
_, init_vjp = eqx.filter_vjp(solver.init, terms, init_t0, init_t1, y0, args)
dgrad_terms, _, _, dgrad_y0, dgrad_args = init_vjp(grad_state)
grad_y0 = eqx.apply_updates(grad_y0, dgrad_y0)
grad_terms = eqx.apply_updates(grad_terms, dgrad_terms)
grad_args = eqx.apply_updates(grad_args, dgrad_args)

return grad_y0, grad_args, grad_terms


class ReversibleAdjoint(AbstractAdjoint):
"""Backpropagate through [`diffrax.diffeqsolve`][] when using a reversible solver
[`diffrax.AbstractReversibleSolver`][].

    Gradient calculation is exact (up to floating-point error), and backpropagation
    runs in $O(n)$ time and $O(1)$ memory, for $n$ time steps.

!!! note

This adjoint can be less numerically stable than
[`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.DirectAdjoint`][].
        Stability can be substantially improved by using [double (64-bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
        and [smaller/adaptive step sizes](https://docs.kidger.site/diffrax/api/stepsize_controller/).

??? cite "References"

For an introduction to reversible backpropagation, see these references:

```bibtex
@article{mccallum2024efficient,
title={Efficient, Accurate and Stable Gradients for Neural ODEs},
author={McCallum, Sam and Foster, James},
journal={arXiv preprint arXiv:2410.11648},
year={2024}
}

@phdthesis{kidger2021on,
title={{O}n {N}eural {D}ifferential {E}quations},
author={Patrick Kidger},
year={2021},
school={University of Oxford},
}
```
"""

def loop(
self,
*,
args,
terms,
solver,
saveat,
max_steps,
init_state,
passed_solver_state,
passed_controller_state,
event,
**kwargs,
):
if not isinstance(solver, AbstractReversibleSolver):
raise ValueError(
"`ReversibleAdjoint` can only be used with an "
"`AbstractReversibleSolver`"
)
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with `ReversibleAdjoint`."
)

if (
jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat)
!= jtu.tree_structure(0)
or saveat.dense
or saveat.subs.steps
or (saveat.subs.fn is not save_y)
):
raise ValueError(
"`ReversibleAdjoint` is only compatible with the following `SaveAt` "
"properties: `t0`, `t1`, `ts`, `fn=save_y` (default)."
)

if event is not None:
raise NotImplementedError(
"`ReversibleAdjoint` is not compatible with events."
)

if is_unsafe_sde(terms):
raise ValueError(
"`ReversibleAdjoint` does not support `UnsafeBrownianPath`. "
"Consider using `VirtualBrownianTree` instead."
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state
)

final_state, aux_stats = _loop_reversible(
(y, args, terms),
self=self,
saveat=saveat,
max_steps=max_steps,
init_state=init_state,
solver=solver,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats
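Putting it together, the adjoint composes with JAX autodiff in the usual way. A sketch of taking gradients through a reversible solve (again assuming a no-arg `UReversible` constructor, which is not confirmed by this diff):

```python
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_grad
def loss(y0):
    term = dfx.ODETerm(lambda t, y, args: -y)
    sol = dfx.diffeqsolve(
        term,
        dfx.UReversible(),  # assumed constructor, for illustration
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        adjoint=dfx.ReversibleAdjoint(),
    )
    return jnp.sum(sol.ys**2)


# For dy/dt = -y we have y(1) = y0 * exp(-1), so dloss/dy0 = 2 * y0 * exp(-2).
print(loss(jnp.array(1.0)))
```

The backward pass reconstructs each step in reverse rather than checkpointing, so memory use stays constant in the number of steps.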