Skip to content

Commit

Permalink
langevin -> underdamped_langevin
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Aug 31, 2024
1 parent 52e9f44 commit ae72aed
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 201 deletions.
41 changes: 25 additions & 16 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import LangevinLeaf, LangevinTuple, LangevinX
from .langevin_srk import (
_LangevinArgs,
from .._term import (
UnderdampedLangevinLeaf,
UnderdampedLangevinTuple,
UnderdampedLangevinX,
)
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
ULDArgs,
)


# For an explanation of the coefficients, see langevin_srk.py
# For an explanation of the coefficients, see foster_langevin_srk.py
class _ALIGNCoeffs(AbstractCoeffs):
beta: PyTree[ArrayLike]
a1: PyTree[ArrayLike]
Expand All @@ -36,7 +40,7 @@ def __init__(self, beta, a1, b1, aa, chh):
self.dtype = jnp.result_type(*all_leaves)


_ErrorEstimate = LangevinTuple
_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
Expand Down Expand Up @@ -85,7 +89,7 @@ def strong_order(self, terms):
return 2.0

def _directly_compute_coeffs_leaf(
self, h: RealScalarLike, c: LangevinLeaf
self, h: RealScalarLike, c: UnderdampedLangevinLeaf
) -> _ALIGNCoeffs:
del self
# c is a leaf of gamma
Expand All @@ -107,7 +111,7 @@ def _directly_compute_coeffs_leaf(
chh=chh,
)

def _tay_coeffs_single(self, c: LangevinLeaf) -> _ALIGNCoeffs:
def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _ALIGNCoeffs:
del self
# c is a leaf of gamma
zero = jnp.zeros_like(c)
Expand Down Expand Up @@ -142,18 +146,23 @@ def _compute_step(
self,
h: RealScalarLike,
levy: AbstractSpaceTimeLevyArea,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: ULDArgs,
coeffs: _ALIGNCoeffs,
rho: LangevinX,
prev_f: LangevinX,
) -> tuple[LangevinX, LangevinX, LangevinX, LangevinTuple]:
rho: UnderdampedLangevinX,
prev_f: UnderdampedLangevinX,
) -> tuple[
UnderdampedLangevinX,
UnderdampedLangevinX,
UnderdampedLangevinX,
UnderdampedLangevinTuple,
]:
dtypes = jtu.tree_map(jnp.result_type, x0)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

gamma, u, f = langevin_args
gamma, u, f = uld_args

uh = (u**ω * h).ω
f0 = prev_f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,36 @@
from .._solution import RESULTS
from .._term import (
AbstractTerm,
LangevinDiffusionTerm,
LangevinDriftTerm,
LangevinLeaf,
LangevinTuple,
LangevinX,
MultiTerm,
UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm,
UnderdampedLangevinLeaf,
UnderdampedLangevinTuple,
UnderdampedLangevinX,
WrapTerm,
)
from .base import AbstractItoSolver, AbstractStratonovichSolver


_ErrorEstimate = TypeVar("_ErrorEstimate", None, LangevinTuple)
_LangevinArgs = tuple[LangevinX, LangevinX, Callable[[LangevinX], LangevinX]]
_ErrorEstimate = TypeVar("_ErrorEstimate", None, UnderdampedLangevinTuple)
ULDArgs = tuple[
UnderdampedLangevinX,
UnderdampedLangevinX,
Callable[[UnderdampedLangevinX], UnderdampedLangevinX],
]


def _get_args_from_terms(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
) -> tuple[PyTree, PyTree, Callable[[LangevinX], LangevinX]]:
) -> tuple[PyTree, PyTree, Callable[[UnderdampedLangevinX], UnderdampedLangevinX]]:
drift, diffusion = terms.terms
if isinstance(drift, WrapTerm):
assert isinstance(diffusion, WrapTerm)
drift = drift.term
diffusion = diffusion.term

assert isinstance(drift, LangevinDriftTerm)
assert isinstance(diffusion, LangevinDiffusionTerm)
assert isinstance(drift, UnderdampedLangevinDriftTerm)
assert isinstance(diffusion, UnderdampedLangevinDiffusionTerm)
gamma = drift.gamma
u = drift.u
f = drift.grad_f
Expand Down Expand Up @@ -93,13 +97,13 @@ class AbstractCoeffs(eqx.Module):


class SolverState(eqx.Module, Generic[_Coeffs]):
gamma: LangevinX
u: LangevinX
gamma: UnderdampedLangevinX
u: UnderdampedLangevinX
h: RealScalarLike
taylor_coeffs: PyTree[_Coeffs, "LangevinX"]
taylor_coeffs: PyTree[_Coeffs, "UnderdampedLangevinX"]
coeffs: _Coeffs
rho: LangevinX
prev_f: LangevinX
rho: UnderdampedLangevinX
prev_f: UnderdampedLangevinX


class AbstractFosterLangevinSRK(
Expand All @@ -126,14 +130,16 @@ class AbstractFosterLangevinSRK(
[`diffrax.ShOULD`][], and [`diffrax.QUIC_SORT`][].
"""

term_structure = MultiTerm[tuple[LangevinDriftTerm, LangevinDiffusionTerm]]
term_structure = MultiTerm[
tuple[UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm]
]
interpolation_cls = LocalLinearInterpolation
minimal_levy_area: eqx.AbstractClassVar[type[AbstractBrownianIncrement]]
taylor_threshold: AbstractVar[RealScalarLike]

@abc.abstractmethod
def _directly_compute_coeffs_leaf(
self, h: RealScalarLike, c: LangevinLeaf
self, h: RealScalarLike, c: UnderdampedLangevinLeaf
) -> _Coeffs:
r"""This method specifies how to compute the SRK coefficients directly
(as opposed to via Taylor expansion). This function is then mapped over the
Expand All @@ -151,7 +157,7 @@ def _directly_compute_coeffs_leaf(
raise NotImplementedError

@abc.abstractmethod
def _tay_coeffs_single(self, c: LangevinLeaf) -> _Coeffs:
def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _Coeffs:
r"""This method specifies how to compute the Taylor coefficients for a
single leaf of gamma. These coefficients are then used to compute the SRK
coefficients using the Taylor expansion. This function is then mapped over
Expand Down Expand Up @@ -184,14 +190,17 @@ def eval_taylor_fun(tay_leaf):
return jtu.tree_map(eval_taylor_fun, tay_coeffs)

def _recompute_coeffs(
self, h: RealScalarLike, gamma: LangevinX, tay_coeffs: PyTree[_Coeffs]
self,
h: RealScalarLike,
gamma: UnderdampedLangevinX,
tay_coeffs: PyTree[_Coeffs],
) -> _Coeffs:
r"""When h changes, the SRK coefficients (which depend on h) are recomputed
using this function."""
# Inner will record the tree structure of the coefficients
inner = sentinel = object()

def recompute_coeffs_leaf(c: LangevinLeaf, _tay_coeffs: _Coeffs):
def recompute_coeffs_leaf(c: UnderdampedLangevinLeaf, _tay_coeffs: _Coeffs):
# c is a leaf of gamma
# Depending on the size of h*gamma choose whether the Taylor expansion or
# direct computation is more accurate.
Expand Down Expand Up @@ -250,7 +259,7 @@ def init(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: LangevinTuple,
y0: UnderdampedLangevinTuple,
args: PyTree,
) -> SolverState:
"""Precompute _SolverState which carries the Taylor coefficients and the
Expand Down Expand Up @@ -300,14 +309,16 @@ def _compute_step(
self,
h: RealScalarLike,
levy,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: ULDArgs,
coeffs: _Coeffs,
rho: LangevinX,
prev_f: LangevinX,
) -> tuple[LangevinX, LangevinX, LangevinX, _ErrorEstimate]:
r"""This method specifies how to compute a single step of the Langevin SRK
rho: UnderdampedLangevinX,
prev_f: UnderdampedLangevinX,
) -> tuple[
UnderdampedLangevinX, UnderdampedLangevinX, UnderdampedLangevinX, _ErrorEstimate
]:
r"""This method specifies how to compute a single step of the ULD SRK
method. This holds just the computation that differs between the different
SRK methods. The common bits are handled by the `step` method."""
raise NotImplementedError
Expand All @@ -317,11 +328,13 @@ def step(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: LangevinTuple,
y0: UnderdampedLangevinTuple,
args: PyTree,
solver_state: SolverState,
made_jump: BoolScalarLike,
) -> tuple[LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]:
) -> tuple[
UnderdampedLangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS
]:
del args
st = solver_state
drift, diffusion = terms.terms
Expand Down Expand Up @@ -352,7 +365,9 @@ def step(
)

x0, v0 = y0
prev_f = lax.cond(made_jump, lambda: grad_f(x0), lambda: st.prev_f)
prev_f = lax.cond(
eqxi.unvmap_any(made_jump), lambda: grad_f(x0), lambda: st.prev_f
)

# The actual step computation, handled by the subclass
x_out, v_out, f_fsal, error = self._compute_step(
Expand Down Expand Up @@ -389,7 +404,7 @@ def func(
self,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
y0: LangevinTuple,
y0: UnderdampedLangevinTuple,
args: PyTree,
):
return terms.vf(t0, y0, args)
32 changes: 16 additions & 16 deletions diffrax/_solver/quicsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import LangevinLeaf, LangevinX
from .langevin_srk import (
_LangevinArgs,
from .._term import UnderdampedLangevinLeaf, UnderdampedLangevinX
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
ULDArgs,
)


# For an explanation of the coefficients, see langevin_srk.py
# For an explanation of the coefficients, see foster_langevin_srk.py
# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1,
# so we need 3 versions of each coefficient
class _QUICSORTCoeffs(AbstractCoeffs):
Expand Down Expand Up @@ -103,7 +103,7 @@ def strong_order(self, terms):
return 3.0

def _directly_compute_coeffs_leaf(
self, h: RealScalarLike, c: LangevinLeaf
self, h: RealScalarLike, c: UnderdampedLangevinLeaf
) -> _QUICSORTCoeffs:
del self
# compute the coefficients directly (as opposed to via Taylor expansion)
Expand Down Expand Up @@ -131,7 +131,7 @@ def _directly_compute_coeffs_leaf(
a_div_h=a_div_h,
)

def _tay_coeffs_single(self, c: LangevinLeaf) -> _QUICSORTCoeffs:
def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _QUICSORTCoeffs:
del self
# c is a leaf of gamma
dtype = jnp.result_type(c)
Expand Down Expand Up @@ -188,19 +188,19 @@ def _compute_step(
self,
h: RealScalarLike,
levy: AbstractSpaceTimeTimeLevyArea,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: ULDArgs,
coeffs: _QUICSORTCoeffs,
rho: LangevinX,
prev_f: LangevinX,
) -> tuple[LangevinX, LangevinX, LangevinX, None]:
rho: UnderdampedLangevinX,
prev_f: UnderdampedLangevinX,
) -> tuple[UnderdampedLangevinX, UnderdampedLangevinX, UnderdampedLangevinX, None]:
dtypes = jtu.tree_map(jnp.result_type, x0)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)
w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

gamma, u, f = langevin_args
gamma, u, f = uld_args

def _extract_coeffs(coeff, index):
return jtu.tree_map(lambda arr: arr[..., index], coeff)
Expand Down
Loading

0 comments on commit ae72aed

Please sign in to comment.