Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal: Condense strategies into a single class to simplify the source-code #793

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 76 additions & 112 deletions probdiffeq/ivpsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,123 +545,56 @@ class _Strategy:
"""Estimation strategy."""

name: str
extrapolation: _ExtraImpl
correction: _Correction
ssm: Any

is_suitable_for_save_at: int
is_suitable_for_offgrid_marginals: int
is_suitable_for_save_every_step: int

prior: _MarkovProcess

@property
def num_derivatives(self):
return self.prior.num_derivatives

initial_condition: Callable
"""Construct an initial condition from a set of Taylor coefficients."""

init: Callable
"""Initialise a state from a posterior."""

begin: Callable
"""Predict the error of an upcoming step."""

complete: Callable
"""Complete the step after the error has been predicted."""

extract: Callable
"""Extract the solution from a state."""

case_interpolate_at_t1: Callable
"""Process the solution in case t=t_n."""

case_interpolate: Callable

offgrid_marginals: Callable
"""Compute offgrid_marginals."""


def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a smoother."""
extrapolation_impl = _ExtraImplSmoother(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
ssm=ssm,
is_suitable_for_save_at=False,
is_suitable_for_save_every_step=True,
is_suitable_for_offgrid_marginals=True,
name=f"<Smoother with {extrapolation_impl}, {correction}>",
)


def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a fixedpoint-smoother."""
extrapolation_impl = _ExtraImplFixedPoint(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
ssm=ssm,
is_suitable_for_save_at=True,
is_suitable_for_save_every_step=False,
is_suitable_for_offgrid_marginals=False,
name=f"<Fixedpoint smoother with {extrapolation_impl}, {correction}>",
)


def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy:
"""Construct a filter."""
extrapolation_impl = _ExtraImplFilter(prior, ssm=ssm)
return _strategy(
extrapolation_impl,
correction,
name=f"<Filter with {extrapolation_impl}, {correction}>",
is_suitable_for_save_at=True,
is_suitable_for_offgrid_marginals=True,
is_suitable_for_save_every_step=True,
ssm=ssm,
)


def _strategy(
extrapolation: _ExtraImpl,
correction: _Correction,
*,
name,
is_suitable_for_save_at,
is_suitable_for_save_every_step,
is_suitable_for_offgrid_marginals,
ssm,
):
def init(t, posterior, /) -> _StrategyState:
rv, extra = extrapolation.init(posterior)
rv, corr = correction.init(rv)
def init(self, t, posterior, /) -> _StrategyState:
"""Initialise a state from a posterior."""
rv, extra = self.extrapolation.init(posterior)
rv, corr = self.correction.init(rv)
return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr)

def initial_condition():
return extrapolation.initial_condition()
def initial_condition(self):
"""Construct an initial condition from a set of Taylor coefficients."""
return self.extrapolation.initial_condition()

def begin(state: _StrategyState, /, *, dt, vector_field):
hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
def begin(self, state: _StrategyState, /, *, dt, vector_field):
"""Predict the error of an upcoming step."""
hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt)
t = state.t + dt
error, observed, corr = correction.estimate_error(
error, observed, corr = self.correction.estimate_error(
hidden, vector_field=vector_field, t=t
)
state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr)
return error, observed, state

def complete(state, /, *, output_scale):
hidden, extra = extrapolation.complete(
def complete(self, state, /, *, output_scale):
"""Complete the step after the error has been predicted."""
hidden, extra = self.extrapolation.complete(
state.hidden, state.aux_extra, output_scale=output_scale
)
hidden, corr = correction.complete(hidden, state.aux_corr)
hidden, corr = self.correction.complete(hidden, state.aux_corr)
return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr)

def extract(state: _StrategyState, /):
hidden = correction.extract(state.hidden)
sol = extrapolation.extract(hidden, state.aux_extra)
def extract(self, state: _StrategyState, /):
"""Extract the solution from a state."""
hidden = self.correction.extract(state.hidden)
sol = self.extrapolation.extract(hidden, state.aux_extra)
return state.t, sol

def case_interpolate_at_t1(state_t1: _StrategyState) -> _InterpRes:
_tmp = extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra)
def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes:
"""Process the solution in case t=t_n."""
_tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra)
step_from, solution, interp_from = (
_tmp.step_from,
_tmp.interpolated,
Expand All @@ -679,11 +612,11 @@ def _state(x):
return _InterpRes(step_from, solution, interp_from)

def case_interpolate(
t, *, s0: _StrategyState, s1: _StrategyState, output_scale
self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale
) -> _InterpRes:
"""Process the solution in case t>t_n."""
# Interpolate
interp = extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(s0.hidden, s0.aux_extra),
marginal_t1=s1.hidden,
dt0=t - s0.t,
Expand All @@ -704,15 +637,16 @@ def _state(t_, x):
step_from=step_from, interpolated=interpolated, interp_from=interp_from
)

def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale):
if not is_suitable_for_offgrid_marginals:
def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
"""Compute offgrid_marginals."""
if not self.is_suitable_for_offgrid_marginals:
raise NotImplementedError

dt0 = t - t0
dt1 = t1 - t
state_t0 = init(t0, posterior_t0)
state_t0 = self.init(t0, posterior_t0)

interp = extrapolation.interpolate(
interp = self.extrapolation.interpolate(
state_t0=(state_t0.hidden, state_t0.aux_extra),
marginal_t1=marginals_t1,
dt0=dt0,
Expand All @@ -721,22 +655,52 @@ def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale):
)

(marginals, _aux) = interp.interpolated
u = ssm.stats.qoi(marginals)
u = self.ssm.stats.qoi(marginals)
return u, marginals


def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a smoother."""
extrapolation = _ExtraImplSmoother(prior, ssm=ssm)
return _Strategy(
name=name,
init=init,
initial_condition=initial_condition,
begin=begin,
complete=complete,
extract=extract,
case_interpolate_at_t1=case_interpolate_at_t1,
case_interpolate=case_interpolate,
offgrid_marginals=offgrid_marginals,
is_suitable_for_save_at=is_suitable_for_save_at,
is_suitable_for_save_every_step=is_suitable_for_save_every_step,
prior=extrapolation.prior,
extrapolation=extrapolation,
correction=correction,
prior=prior,
ssm=ssm,
is_suitable_for_save_at=False,
is_suitable_for_save_every_step=True,
is_suitable_for_offgrid_marginals=True,
name="Smoother",
)


def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy:
"""Construct a fixedpoint-smoother."""
extrapolation = _ExtraImplFixedPoint(prior, ssm=ssm)
return _Strategy(
extrapolation=extrapolation,
correction=correction,
ssm=ssm,
prior=prior,
is_suitable_for_save_at=True,
is_suitable_for_save_every_step=False,
is_suitable_for_offgrid_marginals=False,
name="Fixed-point smoother",
)


def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy:
"""Construct a filter."""
extrapolation = _ExtraImplFilter(prior, ssm=ssm)
return _Strategy(
name="Filter",
prior=prior,
extrapolation=extrapolation,
correction=correction,
is_suitable_for_save_at=True,
is_suitable_for_offgrid_marginals=True,
is_suitable_for_save_every_step=True,
ssm=ssm,
)


Expand Down
Loading