From 7891a46a62c6d814a7ef891e4f0f549863415253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Fri, 25 Oct 2024 16:31:32 +0200 Subject: [PATCH] Condense strategies into a single class to simplify the source-code --- probdiffeq/ivpsolvers.py | 188 ++++++++++++++++----------------------- 1 file changed, 76 insertions(+), 112 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 07e97841..a8851e4d 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -545,123 +545,56 @@ class _Strategy: """Estimation strategy.""" name: str + extrapolation: _ExtraImpl + correction: _Correction + ssm: Any is_suitable_for_save_at: int + is_suitable_for_offgrid_marginals: int is_suitable_for_save_every_step: int - prior: _MarkovProcess @property def num_derivatives(self): return self.prior.num_derivatives - initial_condition: Callable - """Construct an initial condition from a set of Taylor coefficients.""" - - init: Callable - """Initialise a state from a posterior.""" - - begin: Callable - """Predict the error of an upcoming step.""" - - complete: Callable - """Complete the step after the error has been predicted.""" - - extract: Callable - """Extract the solution from a state.""" - - case_interpolate_at_t1: Callable - """Process the solution in case t=t_n.""" - - case_interpolate: Callable - - offgrid_marginals: Callable - """Compute offgrid_marginals.""" - - -def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: - """Construct a smoother.""" - extrapolation_impl = _ExtraImplSmoother(prior, ssm=ssm) - return _strategy( - extrapolation_impl, - correction, - ssm=ssm, - is_suitable_for_save_at=False, - is_suitable_for_save_every_step=True, - is_suitable_for_offgrid_marginals=True, - name=f"", - ) - - -def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: - """Construct a fixedpoint-smoother.""" - extrapolation_impl = _ExtraImplFixedPoint(prior, ssm=ssm) - return _strategy( - extrapolation_impl, - correction, - ssm=ssm, - is_suitable_for_save_at=True, - is_suitable_for_save_every_step=False, - is_suitable_for_offgrid_marginals=False, - name=f"", - ) - - -def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: - """Construct a filter.""" - extrapolation_impl = _ExtraImplFilter(prior, ssm=ssm) - return _strategy( - extrapolation_impl, - correction, - name=f"", - is_suitable_for_save_at=True, - is_suitable_for_offgrid_marginals=True, - is_suitable_for_save_every_step=True, - ssm=ssm, - ) - - -def _strategy( - extrapolation: _ExtraImpl, - correction: _Correction, - *, - name, - is_suitable_for_save_at, - is_suitable_for_save_every_step, - is_suitable_for_offgrid_marginals, - ssm, -): - def init(t, posterior, /) -> _StrategyState: - rv, extra = extrapolation.init(posterior) - rv, corr = correction.init(rv) + def init(self, t, posterior, /) -> _StrategyState: + """Initialise a state from a posterior.""" + rv, extra = self.extrapolation.init(posterior) + rv, corr = self.correction.init(rv) return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - def initial_condition(): - return extrapolation.initial_condition() + def initial_condition(self): + """Construct an initial condition from a set of Taylor coefficients.""" + return self.extrapolation.initial_condition() - def begin(state: _StrategyState, /, *, dt, vector_field): - hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt) + def begin(self, state: _StrategyState, /, *, dt, vector_field): + """Predict the error of an upcoming step.""" + hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt) t = state.t + dt - error, observed, corr = correction.estimate_error( + error, observed, corr = self.correction.estimate_error( hidden, vector_field=vector_field, t=t ) state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr) return error, observed, state - def complete(state, /, *, output_scale): - hidden, extra = extrapolation.complete( + def complete(self, state, /, *, output_scale): + """Complete the step after the error has been predicted.""" + hidden, extra = self.extrapolation.complete( state.hidden, state.aux_extra, output_scale=output_scale ) - hidden, corr = correction.complete(hidden, state.aux_corr) + hidden, corr = self.correction.complete(hidden, state.aux_corr) return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr) - def extract(state: _StrategyState, /): - hidden = correction.extract(state.hidden) - sol = extrapolation.extract(hidden, state.aux_extra) + def extract(self, state: _StrategyState, /): + """Extract the solution from a state.""" + hidden = self.correction.extract(state.hidden) + sol = self.extrapolation.extract(hidden, state.aux_extra) return state.t, sol - def case_interpolate_at_t1(state_t1: _StrategyState) -> _InterpRes: - _tmp = extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) + def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes: + """Process the solution in case t=t_n.""" + _tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) step_from, solution, interp_from = ( _tmp.step_from, _tmp.interpolated, @@ -679,11 +612,11 @@ def _state(x): return _InterpRes(step_from, solution, interp_from) def case_interpolate( - t, *, s0: _StrategyState, s1: _StrategyState, output_scale + self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale ) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate - interp = extrapolation.interpolate( + interp = self.extrapolation.interpolate( state_t0=(s0.hidden, s0.aux_extra), marginal_t1=s1.hidden, dt0=t - s0.t, @@ -704,15 +637,16 @@ def _state(t_, x): step_from=step_from, interpolated=interpolated, interp_from=interp_from ) - def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale): - if not is_suitable_for_offgrid_marginals: + def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): + """Compute offgrid_marginals.""" + if not self.is_suitable_for_offgrid_marginals: raise NotImplementedError dt0 = t - t0 dt1 = t1 - t - state_t0 = init(t0, posterior_t0) + state_t0 = self.init(t0, posterior_t0) - interp = extrapolation.interpolate( + interp = self.extrapolation.interpolate( state_t0=(state_t0.hidden, state_t0.aux_extra), marginal_t1=marginals_t1, dt0=dt0, @@ -721,22 +655,52 @@ def offgrid_marginals(*, t, marginals_t1, posterior_t0, t0, t1, output_scale): ) (marginals, _aux) = interp.interpolated - u = ssm.stats.qoi(marginals) + u = self.ssm.stats.qoi(marginals) return u, marginals + +def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: + """Construct a smoother.""" + extrapolation = _ExtraImplSmoother(prior, ssm=ssm) return _Strategy( - name=name, - init=init, - initial_condition=initial_condition, - begin=begin, - complete=complete, - extract=extract, - case_interpolate_at_t1=case_interpolate_at_t1, - case_interpolate=case_interpolate, - offgrid_marginals=offgrid_marginals, - is_suitable_for_save_at=is_suitable_for_save_at, - is_suitable_for_save_every_step=is_suitable_for_save_every_step, - prior=extrapolation.prior, + extrapolation=extrapolation, + correction=correction, + prior=prior, + ssm=ssm, + is_suitable_for_save_at=False, + is_suitable_for_save_every_step=True, + is_suitable_for_offgrid_marginals=True, + name="Smoother", + ) + + +def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: + """Construct a fixedpoint-smoother.""" + extrapolation = _ExtraImplFixedPoint(prior, ssm=ssm) + return _Strategy( + extrapolation=extrapolation, + correction=correction, + ssm=ssm, + prior=prior, + is_suitable_for_save_at=True, + is_suitable_for_save_every_step=False, + is_suitable_for_offgrid_marginals=False, + name="Fixed-point smoother", + ) + + +def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: + """Construct a filter.""" + extrapolation = _ExtraImplFilter(prior, ssm=ssm) + return _Strategy( + name="Filter", + prior=prior, + extrapolation=extrapolation, + correction=correction, + is_suitable_for_save_at=True, + is_suitable_for_offgrid_marginals=True, + is_suitable_for_save_every_step=True, + ssm=ssm, )