From b3e4ca11ab7e85a5bebead3848bb76c2c891e68d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 09:17:29 +0200 Subject: [PATCH 01/24] Move name, is_suitable* from _Strategy to _ExtraImpl --- probdiffeq/ivpsolvers.py | 100 +++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 064d3884..1b5cc43f 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -174,8 +174,12 @@ class _ExtraImpl: """Extrapolation model interface.""" prior: _MarkovProcess + name: str ssm: Any + is_suitable_for_save_at: int + is_suitable_for_save_every_step: int + def initial_condition(self): """Compute an initial condition from a set of Taylor coefficients.""" raise NotImplementedError @@ -569,51 +573,47 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - name: str extrapolation: _ExtraImpl - correction: _Correction ssm: Any - is_suitable_for_save_at: int is_suitable_for_offgrid_marginals: int - is_suitable_for_save_every_step: int prior: _MarkovProcess @property def num_derivatives(self): return self.prior.num_derivatives - def init(self, t, posterior, /) -> _StrategyState: + def init(self, t, posterior, /, correction) -> _StrategyState: """Initialise a state from a posterior.""" rv, extra = self.extrapolation.init(posterior) - rv, corr = self.correction.init(rv) + rv, corr = correction.init(rv) return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) def initial_condition(self): """Construct an initial condition from a set of Taylor coefficients.""" return self.extrapolation.initial_condition() - def begin(self, state: _StrategyState, /, *, dt, vector_field): + def begin(self, state: _StrategyState, /, *, dt, vector_field, correction): """Predict the error of an upcoming step.""" hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt) t = state.t + dt - error, observed, corr = self.correction.estimate_error( + error, observed, corr = correction.estimate_error( hidden, vector_field=vector_field, t=t ) state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr) return error, observed, state - def complete(self, state, /, *, output_scale): + def complete(self, state, /, *, output_scale, correction): """Complete the step after the error has been predicted.""" hidden, extra = self.extrapolation.complete( state.hidden, state.aux_extra, output_scale=output_scale ) - hidden, corr = self.correction.complete(hidden, state.aux_corr) + hidden, corr = correction.complete(hidden, state.aux_corr) return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr) - def extract(self, state: _StrategyState, /): + def extract(self, state: _StrategyState, /, *, correction): """Extract the solution from a state.""" - hidden = self.correction.extract(state.hidden) + hidden = correction.extract(state.hidden) sol = self.extrapolation.extract(hidden, state.aux_extra) return state.t, sol @@ -662,14 +662,16 @@ def _state(t_, x): step_from=step_from, interpolated=interpolated, interp_from=interp_from ) - def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): + def offgrid_marginals( + self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale, correction + ): """Compute offgrid_marginals.""" if not self.is_suitable_for_offgrid_marginals: raise NotImplementedError dt0 = t - t0 dt1 = t1 - t - state_t0 = self.init(t0, posterior_t0) + state_t0 = self.init(t0, posterior_t0, correction=correction) interp = self.extrapolation.interpolate( state_t0=(state_t0.hidden, state_t0.aux_extra), @@ -686,47 +688,50 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a smoother.""" - extrapolation = _ExtraImplSmoother(prior, ssm=ssm) - return _Strategy( + extrapolation = _ExtraImplSmoother(prior=prior, name="Smoother", ssm=ssm) + strategy = _Strategy( extrapolation=extrapolation, - correction=correction, prior=prior, ssm=ssm, is_suitable_for_save_at=False, is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, - name="Smoother", ) + return strategy, correction def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a fixedpoint-smoother.""" - extrapolation = _ExtraImplFixedPoint(prior, ssm=ssm) - return _Strategy( + extrapolation = _ExtraImplFixedPoint( + prior=prior, name="Fixed-point smoother", ssm=ssm + ) + strategy = _Strategy( extrapolation=extrapolation, - correction=correction, ssm=ssm, prior=prior, is_suitable_for_save_at=True, is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, - name="Fixed-point smoother", ) + return strategy, correction def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: """Construct a filter.""" - extrapolation = _ExtraImplFilter(prior, ssm=ssm) - return _Strategy( + extrapolation = _ExtraImplFilter( + prior=prior, name="Filter", + ssm=ssm, + is_suitable_for_save_at=True, + is_suitable_for_save_every_step=True, + ) + strategy = _Strategy( prior=prior, extrapolation=extrapolation, - correction=correction, - is_suitable_for_save_at=True, is_suitable_for_offgrid_marginals=True, - is_suitable_for_save_every_step=True, ssm=ssm, ) + return strategy, correction, extrapolation @containers.dataclass @@ -752,14 +757,19 @@ def t(self): @containers.dataclass class _ProbabilisticSolver: name: str - calibration: _Calibration + requires_rescaling: bool + step_implementation: Callable + + extrapolation: _ExtraImpl + calibration: _Calibration + correction: _Correction strategy: _Strategy - requires_rescaling: bool - @property - def offgrid_marginals(self): - return self.strategy.offgrid_marginals + def offgrid_marginals(self, *args, **kwargs): + return self.strategy.offgrid_marginals( + *args, **kwargs, correction=self.correction + ) @property def error_contraction_rate(self): @@ -767,15 +777,15 @@ def error_contraction_rate(self): @property def is_suitable_for_save_at(self): - return self.strategy.is_suitable_for_save_at + return self.extrapolation.is_suitable_for_save_at @property def is_suitable_for_save_every_step(self): - return self.strategy.is_suitable_for_save_every_step + return self.extrapolation.is_suitable_for_save_every_step def init(self, t, initial_condition) -> _SolverState: posterior, output_scale = initial_condition - state_strategy = self.strategy.init(t, posterior) + state_strategy = self.strategy.init(t, posterior, correction=self.correction) calib_state = self.calibration.init(output_scale) return _SolverState(strategy=state_strategy, output_scale=calib_state) @@ -785,7 +795,7 @@ def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: ) def extract(self, state: _SolverState, /): - t, posterior = self.strategy.extract(state.strategy) + t, posterior = self.strategy.extract(state.strategy, correction=self.correction) _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) @@ -815,21 +825,22 @@ def initial_condition(self): return posterior, self.strategy.prior.output_scale -def solver_mle(strategy, *, ssm): +def solver_mle(inputs, *, ssm): """Create a solver that calibrates the output scale via maximum-likelihood. Warning: needs to be combined with a call to stats.calibrate() after solving if the MLE-calibration shall be *used*. """ + strategy, correction, extrapolation = inputs def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) error, _, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + state.strategy, dt=dt, vector_field=vector_field, correction=correction ) state_strategy = strategy.complete( - state_strategy, output_scale=output_scale_prior + state_strategy, output_scale=output_scale_prior, correction=correction ) observed = state_strategy.aux_corr @@ -844,6 +855,8 @@ def step_mle(state, /, *, dt, vector_field, calibration): name="Probabilistic solver with MLE calibration", calibration=_calibration_running_mean(ssm=ssm), step_implementation=step_mle, + extrapolation=extrapolation, + correction=correction, strategy=strategy, requires_rescaling=True, ) @@ -911,17 +924,18 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver(strategy, /): +def solver(inputs, /): """Create a solver that does not calibrate the output scale automatically.""" + strategy, correction, extrapolation = inputs def step(state: _SolverState, *, vector_field, dt, calibration): del calibration # unused error, _observed, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + state.strategy, dt=dt, vector_field=vector_field, correction=correction ) state_strategy = strategy.complete( - state_strategy, output_scale=state.output_scale + state_strategy, output_scale=state.output_scale, correction=correction ) # Extract and return solution state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) @@ -929,6 +943,8 @@ def step(state: _SolverState, *, vector_field, dt, calibration): return _ProbabilisticSolver( strategy=strategy, + extrapolation=extrapolation, + correction=correction, calibration=_calibration_none(), step_implementation=step, name="Probabilistic solver", From d5a0c37bf2c388cea64cd9d2d158d8067b0e58de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 09:25:50 +0200 Subject: [PATCH 02/24] Move extrapolation from _Strategy to _ProbabilisticSolver --- probdiffeq/ivpsolvers.py | 104 ++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 34 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 1b5cc43f..76432389 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -573,7 +573,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - extrapolation: _ExtraImpl ssm: Any is_suitable_for_offgrid_marginals: int @@ -583,19 +582,21 @@ class _Strategy: def num_derivatives(self): return self.prior.num_derivatives - def init(self, t, posterior, /, correction) -> _StrategyState: + def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: """Initialise a state from a posterior.""" - rv, extra = self.extrapolation.init(posterior) + rv, extra = extrapolation.init(posterior) rv, corr = correction.init(rv) return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - def initial_condition(self): + def initial_condition(self, *, extrapolation): """Construct an initial condition from a set of Taylor coefficients.""" - return self.extrapolation.initial_condition() + return extrapolation.initial_condition() - def begin(self, state: _StrategyState, /, *, dt, vector_field, correction): + def begin( + self, state: _StrategyState, /, *, dt, vector_field, extrapolation, correction + ): """Predict the error of an upcoming step.""" - hidden, extra = self.extrapolation.begin(state.hidden, state.aux_extra, dt=dt) + hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt) t = state.t + dt error, observed, corr = correction.estimate_error( hidden, vector_field=vector_field, t=t @@ -603,23 +604,25 @@ def begin(self, state: _StrategyState, /, *, dt, vector_field, correction): state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr) return error, observed, state - def complete(self, state, /, *, output_scale, correction): + def complete(self, state, /, *, output_scale, extrapolation, correction): """Complete the step after the error has been predicted.""" - hidden, extra = self.extrapolation.complete( + hidden, extra = extrapolation.complete( state.hidden, state.aux_extra, output_scale=output_scale ) hidden, corr = correction.complete(hidden, state.aux_corr) return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr) - def extract(self, state: _StrategyState, /, *, correction): + def extract(self, state: _StrategyState, /, *, extrapolation, correction): """Extract the solution from a state.""" hidden = correction.extract(state.hidden) - sol = self.extrapolation.extract(hidden, state.aux_extra) + sol = extrapolation.extract(hidden, state.aux_extra) return state.t, sol - def case_interpolate_at_t1(self, state_t1: _StrategyState) -> _InterpRes: + def case_interpolate_at_t1( + self, state_t1: _StrategyState, *, extrapolation + ) -> _InterpRes: """Process the solution in case t=t_n.""" - _tmp = self.extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) + _tmp = extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) step_from, solution, interp_from = ( _tmp.step_from, _tmp.interpolated, @@ -637,11 +640,11 @@ def _state(x): return _InterpRes(step_from, solution, interp_from) def case_interpolate( - self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale + self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale, extrapolation ) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate - interp = self.extrapolation.interpolate( + interp = extrapolation.interpolate( state_t0=(s0.hidden, s0.aux_extra), marginal_t1=s1.hidden, dt0=t - s0.t, @@ -663,7 +666,16 @@ def _state(t_, x): ) def offgrid_marginals( - self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale, correction + self, + *, + t, + marginals_t1, + posterior_t0, + t0, + t1, + output_scale, + extrapolation, + correction, ): """Compute offgrid_marginals.""" if not self.is_suitable_for_offgrid_marginals: @@ -671,9 +683,11 @@ def offgrid_marginals( dt0 = t - t0 dt1 = t1 - t - state_t0 = self.init(t0, posterior_t0, correction=correction) + state_t0 = self.init( + t0, posterior_t0, extrapolation=extrapolation, correction=correction + ) - interp = self.extrapolation.interpolate( + interp = extrapolation.interpolate( state_t0=(state_t0.hidden, state_t0.aux_extra), marginal_t1=marginals_t1, dt0=dt0, @@ -725,12 +739,7 @@ def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: is_suitable_for_save_at=True, is_suitable_for_save_every_step=True, ) - strategy = _Strategy( - prior=prior, - extrapolation=extrapolation, - is_suitable_for_offgrid_marginals=True, - ssm=ssm, - ) + strategy = _Strategy(prior=prior, is_suitable_for_offgrid_marginals=True, ssm=ssm) return strategy, correction, extrapolation @@ -768,7 +777,10 @@ class _ProbabilisticSolver: def offgrid_marginals(self, *args, **kwargs): return self.strategy.offgrid_marginals( - *args, **kwargs, correction=self.correction + *args, + **kwargs, + extrapolation=self.extrapolation, + correction=self.correction, ) @property @@ -785,7 +797,9 @@ def is_suitable_for_save_every_step(self): def init(self, t, initial_condition) -> _SolverState: posterior, output_scale = initial_condition - state_strategy = self.strategy.init(t, posterior, correction=self.correction) + state_strategy = self.strategy.init( + t, posterior, correction=self.correction, extrapolation=self.extrapolation + ) calib_state = self.calibration.init(output_scale) return _SolverState(strategy=state_strategy, output_scale=calib_state) @@ -795,7 +809,9 @@ def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: ) def extract(self, state: _SolverState, /): - t, posterior = self.strategy.extract(state.strategy, correction=self.correction) + t, posterior = self.strategy.extract( + state.strategy, extrapolation=self.extrapolation, correction=self.correction + ) _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) @@ -804,7 +820,11 @@ def interpolate( ) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) interp = self.strategy.case_interpolate( - t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale + t, + s0=interp_from.strategy, + s1=interp_to.strategy, + output_scale=output_scale, + extrapolation=self.extrapolation, ) prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) @@ -812,7 +832,9 @@ def interpolate( return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: - x = self.strategy.case_interpolate_at_t1(interp_to.strategy) + x = self.strategy.case_interpolate_at_t1( + interp_to.strategy, extrapolation=self.extrapolation + ) prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) @@ -821,7 +843,7 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: def initial_condition(self): """Construct an initial condition.""" - posterior = self.strategy.initial_condition() + posterior = self.strategy.initial_condition(extrapolation=self.extrapolation) return posterior, self.strategy.prior.output_scale @@ -836,11 +858,18 @@ def solver_mle(inputs, *, ssm): def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) error, _, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field, correction=correction + state.strategy, + dt=dt, + vector_field=vector_field, + extrapolation=extrapolation, + correction=correction, ) state_strategy = strategy.complete( - state_strategy, output_scale=output_scale_prior, correction=correction + state_strategy, + output_scale=output_scale_prior, + extrapolation=extrapolation, + correction=correction, ) observed = state_strategy.aux_corr @@ -932,10 +961,17 @@ def step(state: _SolverState, *, vector_field, dt, calibration): del calibration # unused error, _observed, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field, correction=correction + state.strategy, + dt=dt, + vector_field=vector_field, + extrapolation=extrapolation, + correction=correction, ) state_strategy = strategy.complete( - state_strategy, output_scale=state.output_scale, correction=correction + state_strategy, + output_scale=state.output_scale, + extrapolation=extrapolation, + correction=correction, ) # Extract and return solution state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) From a7e0ae71ef6a97c467e04f0304f70e6570de8e99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 09:33:32 +0200 Subject: [PATCH 03/24] Move is_suitable_for_offgrid_marginals and ssm --- .../use_equinox_bounded_while_loop.py | 2 +- .../neural_ode.py | 8 +++---- .../physics_enhanced_regression_1.py | 2 +- .../physics_enhanced_regression_2.py | 4 ++-- .../conditioning-on-zero-residual.py | 3 ++- probdiffeq/ivpsolvers.py | 24 +++++++++++++------ .../test_save_at_vs_save_every_step.py | 2 +- .../test_strategy_smoother_vs_filter.py | 4 ++-- .../test_strategy_smoother_vs_fixedpoint.py | 8 +++---- ..._strategy_warnings_for_wrong_strategies.py | 4 ++-- .../test_log_marginal_likelihood.py | 4 ++-- ...log_marginal_likelihood_terminal_values.py | 2 +- tests/test_stats/test_offgrid_marginals.py | 4 ++-- tests/test_stats/test_sample.py | 2 +- 14 files changed, 42 insertions(+), 31 deletions(-) diff --git a/docs/examples_misc/use_equinox_bounded_while_loop.py b/docs/examples_misc/use_equinox_bounded_while_loop.py index 2beedf9c..e8b4847e 100644 --- a/docs/examples_misc/use_equinox_bounded_while_loop.py +++ b/docs/examples_misc/use_equinox_bounded_while_loop.py @@ -65,7 +65,7 @@ def vf(y, *, t): # noqa: ARG001 ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_parameter_estimation/neural_ode.py b/docs/examples_parameter_estimation/neural_ode.py index fa7d425c..57e77e85 100644 --- a/docs/examples_parameter_estimation/neural_ode.py +++ b/docs/examples_parameter_estimation/neural_ode.py @@ -71,7 +71,7 @@ def loss_fn(parameters): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_ts0 = ivpsolvers.solver(strategy) + solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -129,7 +129,7 @@ def vf(y, *, t, p): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) init = solver_ts0.initial_condition() # + @@ -169,7 +169,7 @@ def vf(y, *, t, p): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -183,7 +183,7 @@ def vf(y, *, t, p): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy) +solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py index f835ce46..3002971f 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py @@ -73,7 +73,7 @@ def solve(p): ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index e585d249..936d36c2 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -192,7 +192,7 @@ def solve_fixed(theta, *, ts): ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) return sol[-1] @@ -209,7 +209,7 @@ def solve_adaptive(theta, *, save_at): ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.py b/docs/examples_solver_config/conditioning-on-zero-residual.py index ecc9b67d..3e00646a 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.py +++ b/docs/examples_solver_config/conditioning-on-zero-residual.py @@ -79,7 +79,8 @@ def vector_field(y, t): # noqa: ARG001 ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense") slr1 = ivpsolvers.correction_ts1(ssm=ssm) -solver = ivpsolvers.solver(ivpsolvers.strategy_fixedpoint(ibm, slr1, ssm=ssm)) +strategy = ivpsolvers.strategy_fixedpoint(ibm, slr1, ssm=ssm) +solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm) dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,)) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 76432389..a47ca136 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -179,6 +179,7 @@ class _ExtraImpl: is_suitable_for_save_at: int is_suitable_for_save_every_step: int + is_suitable_for_offgrid_marginals: int def initial_condition(self): """Compute an initial condition from a set of Taylor coefficients.""" @@ -573,9 +574,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - ssm: Any - - is_suitable_for_offgrid_marginals: int prior: _MarkovProcess @property @@ -676,9 +674,11 @@ def offgrid_marginals( output_scale, extrapolation, correction, + ssm, + is_suitable_for_offgrid_marginals, ): """Compute offgrid_marginals.""" - if not self.is_suitable_for_offgrid_marginals: + if not is_suitable_for_offgrid_marginals: raise NotImplementedError dt0 = t - t0 @@ -696,7 +696,7 @@ def offgrid_marginals( ) (marginals, _aux) = interp.interpolated - u = self.ssm.stats.qoi(marginals) + u = ssm.stats.qoi(marginals) return u, marginals @@ -738,8 +738,9 @@ def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: ssm=ssm, is_suitable_for_save_at=True, is_suitable_for_save_every_step=True, + is_suitable_for_offgrid_marginals=True, ) - strategy = _Strategy(prior=prior, is_suitable_for_offgrid_marginals=True, ssm=ssm) + strategy = _Strategy(prior=prior) return strategy, correction, extrapolation @@ -770,6 +771,7 @@ class _ProbabilisticSolver: step_implementation: Callable + ssm: Any extrapolation: _ExtraImpl calibration: _Calibration correction: _Correction @@ -779,14 +781,20 @@ def offgrid_marginals(self, *args, **kwargs): return self.strategy.offgrid_marginals( *args, **kwargs, + ssm=self.ssm, extrapolation=self.extrapolation, correction=self.correction, + is_suitable_for_offgrid_marginals=self.is_suitable_for_offgrid_marginals, ) @property def error_contraction_rate(self): return self.strategy.num_derivatives + 1 + @property + def is_suitable_for_offgrid_marginals(self): + return self.extrapolation.is_suitable_for_offgrid_marginals + @property def is_suitable_for_save_at(self): return self.extrapolation.is_suitable_for_save_at @@ -881,6 +889,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): return dt * error, state return _ProbabilisticSolver( + ssm=ssm, name="Probabilistic solver with MLE calibration", calibration=_calibration_running_mean(ssm=ssm), step_implementation=step_mle, @@ -953,7 +962,7 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver(inputs, /): +def solver(inputs, /, *, ssm): """Create a solver that does not calibrate the output scale automatically.""" strategy, correction, extrapolation = inputs @@ -979,6 +988,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): return _ProbabilisticSolver( strategy=strategy, + ssm=ssm, extrapolation=extrapolation, correction=correction, calibration=_calibration_none(), diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index edb82a18..ffc80402 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -16,7 +16,7 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py index 3ab46254..30277d50 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py @@ -21,7 +21,7 @@ def fixture_filter_solution(solver_setup): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( @@ -35,7 +35,7 @@ def fixture_smoother_solution(solver_setup): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py index 2852edd0..3bcbcc8f 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py @@ -23,7 +23,7 @@ def fixture_solution_smoother(solver_setup): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() @@ -44,7 +44,7 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) save_at = solution_smoother.t @@ -70,7 +70,7 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_smoother = ivpsolvers.solver(strategy) + solver_smoother = ivpsolvers.solver(strategy, ssm=ssm) # Compute the offgrid-marginals ts = np.linspace(save_at[0], save_at[-1], num=7, endpoint=True) @@ -83,7 +83,7 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py index bbbcf1a8..354eea55 100644 --- a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py +++ b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py @@ -14,7 +14,7 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(fact): ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -33,7 +33,7 @@ def test_warning_for_smoother_in_save_at_mode(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_stats/test_log_marginal_likelihood.py b/tests/test_stats/test_log_marginal_likelihood.py index e1f54bd9..d1e67c3a 100644 --- a/tests/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_stats/test_log_marginal_likelihood.py @@ -15,7 +15,7 @@ def fixture_solution(fact): ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -97,7 +97,7 @@ def test_raises_error_for_filter(fact): ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) grid = np.linspace(t0, t1, num=3) init = solver.initial_condition() diff --git a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py index 172a391a..8776bb4a 100644 --- a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -30,7 +30,7 @@ def fixture_solution(strategy_func, fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = strategy_func(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_stats/test_offgrid_marginals.py b/tests/test_stats/test_offgrid_marginals.py index 5ae6c420..2ebcaf58 100644 --- a/tests/test_stats/test_offgrid_marginals.py +++ b/tests/test_stats/test_offgrid_marginals.py @@ -14,7 +14,7 @@ def test_filter_marginals_close_only_to_left_boundary(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) @@ -38,7 +38,7 @@ def test_smoother_marginals_close_to_both_boundaries(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) diff --git a/tests/test_stats/test_sample.py b/tests/test_stats/test_sample.py index 9beddef9..fe6c75eb 100644 --- a/tests/test_stats/test_sample.py +++ b/tests/test_stats/test_sample.py @@ -13,7 +13,7 @@ def fixture_approximation(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy) + solver = ivpsolvers.solver(strategy, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() From c4ad501133a47984491c657746406ce4986f9345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 09:39:15 +0200 Subject: [PATCH 04/24] Move the prior --- probdiffeq/ivpsolvers.py | 31 ++++++++----------- .../test_calibration_mle_vs_none.py | 12 +++---- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index a47ca136..07154f83 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -574,12 +574,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - prior: _MarkovProcess - - @property - def num_derivatives(self): - return self.prior.num_derivatives - def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: """Initialise a state from a posterior.""" rv, extra = extrapolation.init(posterior) @@ -717,17 +711,15 @@ def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a fixedpoint-smoother.""" extrapolation = _ExtraImplFixedPoint( - prior=prior, name="Fixed-point smoother", ssm=ssm - ) - strategy = _Strategy( - extrapolation=extrapolation, - ssm=ssm, prior=prior, + name="Fixed-point smoother", + ssm=ssm, is_suitable_for_save_at=True, is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, ) - return strategy, correction + strategy = _Strategy() + return strategy, correction, extrapolation, prior def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: @@ -740,8 +732,8 @@ def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - strategy = _Strategy(prior=prior) - return strategy, correction, extrapolation + strategy = _Strategy() + return strategy, correction, extrapolation, prior @containers.dataclass @@ -771,6 +763,7 @@ class _ProbabilisticSolver: step_implementation: Callable + prior: _MarkovProcess ssm: Any extrapolation: _ExtraImpl calibration: _Calibration @@ -789,7 +782,7 @@ def offgrid_marginals(self, *args, **kwargs): @property def error_contraction_rate(self): - return self.strategy.num_derivatives + 1 + return self.prior.num_derivatives + 1 @property def is_suitable_for_offgrid_marginals(self): @@ -852,7 +845,7 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: def initial_condition(self): """Construct an initial condition.""" posterior = self.strategy.initial_condition(extrapolation=self.extrapolation) - return posterior, self.strategy.prior.output_scale + return posterior, self.prior.output_scale def solver_mle(inputs, *, ssm): @@ -861,7 +854,7 @@ def solver_mle(inputs, *, ssm): Warning: needs to be combined with a call to stats.calibrate() after solving if the MLE-calibration shall be *used*. """ - strategy, correction, extrapolation = inputs + strategy, correction, extrapolation, prior = inputs def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) @@ -891,6 +884,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): return _ProbabilisticSolver( ssm=ssm, name="Probabilistic solver with MLE calibration", + prior=prior, calibration=_calibration_running_mean(ssm=ssm), step_implementation=step_mle, extrapolation=extrapolation, @@ -964,7 +958,7 @@ def extract(state, /): def solver(inputs, /, *, ssm): """Create a solver that does not calibrate the output scale automatically.""" - strategy, correction, extrapolation = inputs + strategy, correction, extrapolation, prior = inputs def step(state: _SolverState, *, vector_field, dt, calibration): del calibration # unused @@ -989,6 +983,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): return _ProbabilisticSolver( strategy=strategy, ssm=ssm, + prior=prior, extrapolation=extrapolation, correction=correction, calibration=_calibration_none(), diff --git a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py index 78bbdfc6..ff746ef1 100644 --- a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py +++ b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py @@ -6,8 +6,8 @@ """ from probdiffeq import ivpsolve, ivpsolvers, stats, taylor -from probdiffeq.backend import functools, ode, testing from probdiffeq.backend import numpy as np +from probdiffeq.backend import ode, testing @testing.case() @@ -22,7 +22,7 @@ def case_solve_fixed_grid(fact): def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + solver = solver_fun(strategy, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid(vf, init, solver=solver, **kwargs) @@ -45,7 +45,7 @@ def case_solve_adaptive_save_at(fact): def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + solver = solver_fun(strategy, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -70,7 +70,7 @@ def case_solve_adaptive_save_every_step(fact): def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + solver = solver_fun(strategy, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -95,7 +95,7 @@ def case_simulate_terminal_values(fact): def solver_to_solution(solver_fun, strategy_fun): strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy) + solver = solver_fun(strategy, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2) @@ -114,7 +114,7 @@ def solver_to_solution(solver_fun, strategy_fun): def fixture_uncalibrated_and_mle_solution(solver_to_solution, strategy_fun): solve, ssm = solver_to_solution uncalib = solve(ivpsolvers.solver, strategy_fun) - mle = solve(functools.partial(ivpsolvers.solver_mle, ssm=ssm), strategy_fun) + mle = solve(ivpsolvers.solver_mle, strategy_fun) return (uncalib, mle), ssm From 3ad0300e3c714593443bea6955b9334d7dc34557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 09:41:05 +0200 Subject: [PATCH 05/24] Update the smoother --- probdiffeq/ivpsolvers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 07154f83..05447894 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -696,16 +696,16 @@ def offgrid_marginals( def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a smoother.""" - extrapolation = _ExtraImplSmoother(prior=prior, name="Smoother", ssm=ssm) - strategy = _Strategy( - extrapolation=extrapolation, + extrapolation = _ExtraImplSmoother( prior=prior, + name="Smoother", ssm=ssm, is_suitable_for_save_at=False, is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - return strategy, correction + strategy = _Strategy() + return strategy, correction, extrapolation, prior def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: From 089f261477a44a014a5c444e1824e9a7290604d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:11:53 +0200 Subject: [PATCH 06/24] Move prior out of ExtraImpl --- probdiffeq/ivpsolvers.py | 141 +++++++++++++++++++++++++-------------- 1 file changed, 90 insertions(+), 51 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 05447894..7f7668a6 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -173,7 +173,6 @@ def _tensor_points(x, /, *, d): class _ExtraImpl: """Extrapolation model interface.""" - prior: _MarkovProcess name: str ssm: Any @@ -181,7 +180,7 @@ class _ExtraImpl: is_suitable_for_save_every_step: int is_suitable_for_offgrid_marginals: int - def initial_condition(self): + def initial_condition(self, *, prior): """Compute an initial condition from a set of Taylor coefficients.""" raise NotImplementedError @@ -189,7 +188,7 @@ def init(self, sol: stats.MarkovSeq, /): """Initialise a state from a solution.""" raise NotImplementedError - def begin(self, rv, _extra, /, dt): + def begin(self, rv, _extra, /, *, prior_discretized): """Begin the extrapolation.""" raise NotImplementedError @@ -201,27 +200,27 @@ def extract(self, hidden_state, extra, /): """Extract a solution from a state.""" raise NotImplementedError - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate.""" raise NotImplementedError - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): """Process the state at checkpoint t=t_n.""" raise NotImplementedError @containers.dataclass class _ExtraImplSmoother(_ExtraImpl): - def initial_condition(self): - rv = self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) - cond = self.ssm.conditional.identity(len(self.prior.tcoeffs)) + def initial_condition(self, *, prior): + rv = self.ssm.normal.from_tcoeffs(prior.tcoeffs) + cond = self.ssm.conditional.identity(len(prior.tcoeffs)) return stats.MarkovSeq(init=rv, conditional=cond) def init(self, sol: stats.MarkovSeq, /): return sol.init, sol.conditional - def begin(self, rv, _extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, _extra, /, *, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -248,7 +247,7 @@ def complete(self, _ssv, extra, /, output_scale): def extract(self, hidden_state, extra, /): return stats.MarkovSeq(init=hidden_state, conditional=extra) - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate. A smoother interpolates by_ @@ -263,9 +262,14 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): Subsequent IVP solver steps continue from the value at 't1'. """ # Extrapolate from t0 to t, and from t to t1. This yields all building blocks. - extrapolated_t = self._extrapolate(*state_t0, dt=dt0, output_scale=output_scale) + prior0 = prior.discretize(dt0) + extrapolated_t = self._extrapolate( + *state_t0, output_scale=output_scale, prior_discretized=prior0 + ) + + prior1 = prior.discretize(dt1) extrapolated_t1 = self._extrapolate( - *extrapolated_t, dt=dt1, output_scale=output_scale + *extrapolated_t, output_scale=output_scale, prior_discretized=prior1 ) # Marginalise from t1 to t to obtain the interpolated solution. @@ -283,11 +287,12 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): interp_from=solution_at_t, ) - def _extrapolate(self, state, extra, /, *, dt, output_scale): - state, cache = self.begin(state, extra, dt=dt) + def _extrapolate(self, state, extra, /, *, output_scale, prior_discretized): + state, cache = self.begin(state, extra, prior_discretized=prior_discretized) return self.complete(state, cache, output_scale=output_scale) - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): + del prior return _InterpRes((rv, extra), (rv, extra), (rv, extra)) @@ -296,11 +301,11 @@ class _ExtraImplFilter(_ExtraImpl): def init(self, sol, /): return sol, None - def initial_condition(self): - return self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) + def initial_condition(self, *, prior): + return self.ssm.normal.from_tcoeffs(prior.tcoeffs) - def begin(self, rv, _extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, _extra, /, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -326,13 +331,14 @@ def complete(self, _ssv, extra, /, output_scale): # Gather and return return extrapolated, None - def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale, *, prior): # todo: by ditching marginal_t1 and dt1, this function _extrapolates # (no *inter*polation happening) del dt1 hidden, extra = state_t0 - hidden, extra = self.begin(hidden, extra, dt=dt0) + prior0 = prior.discretize(dt0) + hidden, extra = self.begin(hidden, extra, prior_discretized=prior0) hidden, extra = self.complete(hidden, extra, output_scale=output_scale) # Consistent state-types in interpolation result. @@ -340,22 +346,23 @@ def interpolate(self, state_t0, marginal_t1, dt0, dt1, output_scale): step_from = (marginal_t1, None) return _InterpRes(step_from=step_from, interpolated=interp, interp_from=interp) - def interpolate_at_t1(self, rv, extra, /): + def interpolate_at_t1(self, rv, extra, /, *, prior): + del prior return _InterpRes((rv, extra), (rv, extra), (rv, extra)) @containers.dataclass class _ExtraImplFixedPoint(_ExtraImpl): - def initial_condition(self): - rv = self.ssm.normal.from_tcoeffs(self.prior.tcoeffs) - cond = self.ssm.conditional.identity(len(self.prior.tcoeffs)) + def initial_condition(self, prior): + rv = self.ssm.normal.from_tcoeffs(prior.tcoeffs) + cond = self.ssm.conditional.identity(len(prior.tcoeffs)) return stats.MarkovSeq(init=rv, conditional=cond) def init(self, sol: stats.MarkovSeq, /): return sol.init, sol.conditional - def begin(self, rv, extra, /, dt): - cond, (p, p_inv) = self.prior.discretize(dt) + def begin(self, rv, extra, /, prior_discretized): + cond, (p, p_inv) = prior_discretized rv_p = self.ssm.normal.preconditioner_apply(rv, p_inv) @@ -385,7 +392,7 @@ def complete(self, _rv, extra, /, output_scale): # Gather and return return extrapolated, cond - def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): + def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale, prior): """Interpolate. A fixed-point smoother interpolates by @@ -424,11 +431,16 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): then don't understand why tests fail.) """ # Extrapolate from t0 to t, and from t to t1. This yields all building blocks. - extrapolated_t = self._extrapolate(*state_t0, dt=dt0, output_scale=output_scale) - conditional_id = self.ssm.conditional.identity(self.prior.num_derivatives + 1) + prior0 = prior.discretize(dt0) + extrapolated_t = self._extrapolate( + *state_t0, output_scale=output_scale, prior_discretized=prior0 + ) + conditional_id = self.ssm.conditional.identity(prior.num_derivatives + 1) previous_new = (extrapolated_t[0], conditional_id) + + prior1 = prior.discretize(dt1) extrapolated_t1 = self._extrapolate( - *previous_new, dt=dt1, output_scale=output_scale + *previous_new, output_scale=output_scale, prior_discretized=prior1 ) # Marginalise from t1 to t to obtain the interpolated solution. @@ -442,12 +454,12 @@ def interpolate(self, state_t0, marginal_t1, *, dt0, dt1, output_scale): interp_from=previous_new, ) - def _extrapolate(self, state, extra, /, *, dt, output_scale): - x, cache = self.begin(state, extra, dt=dt) + def _extrapolate(self, state, extra, /, *, output_scale, prior_discretized): + x, cache = self.begin(state, extra, prior_discretized=prior_discretized) return self.complete(x, cache, output_scale=output_scale) - def interpolate_at_t1(self, rv, extra, /): - cond_identity = self.ssm.conditional.identity(self.prior.num_derivatives + 1) + def interpolate_at_t1(self, rv, extra, /, *, prior): + cond_identity = self.ssm.conditional.identity(prior.num_derivatives + 1) return _InterpRes((rv, cond_identity), (rv, extra), (rv, cond_identity)) @@ -580,15 +592,26 @@ def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: rv, corr = correction.init(rv) return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - def initial_condition(self, *, extrapolation): + def initial_condition(self, *, extrapolation, prior): """Construct an initial condition from a set of Taylor coefficients.""" - return extrapolation.initial_condition() + return extrapolation.initial_condition(prior=prior) def begin( - self, state: _StrategyState, /, *, dt, vector_field, extrapolation, correction + self, + state: _StrategyState, + /, + *, + dt, + vector_field, + prior, + extrapolation, + correction, ): """Predict the error of an upcoming step.""" - hidden, extra = extrapolation.begin(state.hidden, state.aux_extra, dt=dt) + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.hidden, state.aux_extra, prior_discretized=prior_discretized + ) t = state.t + dt error, observed, corr = correction.estimate_error( hidden, vector_field=vector_field, t=t @@ -611,14 +634,16 @@ def extract(self, state: _StrategyState, /, *, extrapolation, correction): return state.t, sol def case_interpolate_at_t1( - self, state_t1: _StrategyState, *, extrapolation + self, state_t1: _StrategyState, *, extrapolation, prior ) -> _InterpRes: """Process the solution in case t=t_n.""" - _tmp = extrapolation.interpolate_at_t1(state_t1.hidden, state_t1.aux_extra) + tmp = extrapolation.interpolate_at_t1( + state_t1.hidden, state_t1.aux_extra, prior=prior + ) step_from, solution, interp_from = ( - _tmp.step_from, - _tmp.interpolated, - _tmp.interp_from, + tmp.step_from, + tmp.interpolated, + tmp.interp_from, ) def _state(x): @@ -632,7 +657,14 @@ def _state(x): return _InterpRes(step_from, solution, interp_from) def case_interpolate( - self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale, extrapolation + self, + t, + *, + s0: _StrategyState, + s1: _StrategyState, + output_scale, + extrapolation, + prior, ) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate @@ -642,6 +674,7 @@ def case_interpolate( dt0=t - s0.t, dt1=s1.t - t, output_scale=output_scale, + prior=prior, ) # Turn outputs into valid states @@ -668,6 +701,7 @@ def offgrid_marginals( output_scale, extrapolation, correction, + prior, ssm, is_suitable_for_offgrid_marginals, ): @@ -681,12 +715,14 @@ def offgrid_marginals( t0, posterior_t0, extrapolation=extrapolation, correction=correction ) + # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 interp = extrapolation.interpolate( state_t0=(state_t0.hidden, state_t0.aux_extra), marginal_t1=marginals_t1, dt0=dt0, dt1=dt1, output_scale=output_scale, + prior=prior, ) (marginals, _aux) = interp.interpolated @@ -697,7 +733,6 @@ def offgrid_marginals( def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a smoother.""" extrapolation = _ExtraImplSmoother( - prior=prior, name="Smoother", ssm=ssm, is_suitable_for_save_at=False, @@ -711,7 +746,6 @@ def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a fixedpoint-smoother.""" extrapolation = _ExtraImplFixedPoint( - prior=prior, name="Fixed-point smoother", ssm=ssm, is_suitable_for_save_at=True, @@ -725,7 +759,6 @@ def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: """Construct a filter.""" extrapolation = _ExtraImplFilter( - prior=prior, name="Filter", ssm=ssm, is_suitable_for_save_at=True, @@ -778,6 +811,7 @@ def offgrid_marginals(self, *args, **kwargs): extrapolation=self.extrapolation, correction=self.correction, is_suitable_for_offgrid_marginals=self.is_suitable_for_offgrid_marginals, + prior=self.prior, ) @property @@ -826,6 +860,7 @@ def interpolate( s1=interp_to.strategy, output_scale=output_scale, extrapolation=self.extrapolation, + prior=self.prior, ) prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) @@ -834,7 +869,7 @@ def interpolate( def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: x = self.strategy.case_interpolate_at_t1( - interp_to.strategy, extrapolation=self.extrapolation + interp_to.strategy, extrapolation=self.extrapolation, prior=self.prior ) prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) @@ -844,7 +879,9 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: def initial_condition(self): """Construct an initial condition.""" - posterior = self.strategy.initial_condition(extrapolation=self.extrapolation) + posterior = self.strategy.initial_condition( + prior=self.prior, extrapolation=self.extrapolation + ) return posterior, self.prior.output_scale @@ -864,6 +901,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): vector_field=vector_field, extrapolation=extrapolation, correction=correction, + prior=prior, ) state_strategy = strategy.complete( @@ -969,6 +1007,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): vector_field=vector_field, extrapolation=extrapolation, correction=correction, + prior=prior, ) state_strategy = strategy.complete( state_strategy, From 889a42ad0cda0a2bc77cc23d1a3a0c9212009dc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:13:29 +0200 Subject: [PATCH 07/24] Delete Strategy.initial_condition --- probdiffeq/ivpsolvers.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 7f7668a6..e1932b7c 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -592,10 +592,6 @@ def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: rv, corr = correction.init(rv) return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - def initial_condition(self, *, extrapolation, prior): - """Construct an initial condition from a set of Taylor coefficients.""" - return extrapolation.initial_condition(prior=prior) - def begin( self, state: _StrategyState, @@ -879,9 +875,7 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: def initial_condition(self): """Construct an initial condition.""" - posterior = self.strategy.initial_condition( - prior=self.prior, extrapolation=self.extrapolation - ) + posterior = self.extrapolation.initial_condition(prior=self.prior) return posterior, self.prior.output_scale From 3882d8ed5dfdf558d210f52e59264829dd6f4f86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:18:16 +0200 Subject: [PATCH 08/24] Move init() and offgrid_marginals() --- probdiffeq/ivpsolvers.py | 96 +++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 55 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index e1932b7c..87467c87 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,11 +586,12 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: - """Initialise a state from a posterior.""" - rv, extra = extrapolation.init(posterior) - rv, corr = correction.init(rv) - return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) + # + # def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: + # """Initialise a state from a posterior.""" + # rv, extra = extrapolation.init(posterior) + # rv, corr = correction.init(rv) + # return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) def begin( self, @@ -686,45 +687,6 @@ def _state(t_, x): step_from=step_from, interpolated=interpolated, interp_from=interp_from ) - def offgrid_marginals( - self, - *, - t, - marginals_t1, - posterior_t0, - t0, - t1, - output_scale, - extrapolation, - correction, - prior, - ssm, - is_suitable_for_offgrid_marginals, - ): - """Compute offgrid_marginals.""" - if not is_suitable_for_offgrid_marginals: - raise NotImplementedError - - dt0 = t - t0 - dt1 = t1 - t - state_t0 = self.init( - t0, posterior_t0, extrapolation=extrapolation, correction=correction - ) - - # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 - interp = extrapolation.interpolate( - state_t0=(state_t0.hidden, state_t0.aux_extra), - marginal_t1=marginals_t1, - dt0=dt0, - dt1=dt1, - output_scale=output_scale, - prior=prior, - ) - - (marginals, _aux) = interp.interpolated - u = ssm.stats.qoi(marginals) - return u, marginals - def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a smoother.""" @@ -799,17 +761,32 @@ class _ProbabilisticSolver: correction: _Correction strategy: _Strategy - def offgrid_marginals(self, *args, **kwargs): - return self.strategy.offgrid_marginals( - *args, - **kwargs, - ssm=self.ssm, - extrapolation=self.extrapolation, - correction=self.correction, - is_suitable_for_offgrid_marginals=self.is_suitable_for_offgrid_marginals, + def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): + """Compute offgrid_marginals.""" + if not self.is_suitable_for_offgrid_marginals: + raise NotImplementedError + + dt0 = t - t0 + dt1 = t1 - t + + rv, extra = self.extrapolation.init(posterior_t0) + rv, corr = self.correction.init(rv) + state_t0 = _StrategyState(t=t0, hidden=rv, aux_extra=extra, aux_corr=corr) + + # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 + interp = self.extrapolation.interpolate( + state_t0=(state_t0.hidden, state_t0.aux_extra), + marginal_t1=marginals_t1, + dt0=dt0, + dt1=dt1, + output_scale=output_scale, prior=self.prior, ) + (marginals, _aux) = interp.interpolated + u = self.ssm.stats.qoi(marginals) + return u, marginals + @property def error_contraction_rate(self): return self.prior.num_derivatives + 1 @@ -826,11 +803,20 @@ def is_suitable_for_save_at(self): def is_suitable_for_save_every_step(self): return self.extrapolation.is_suitable_for_save_every_step + # + # def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: + # """Initialise a state from a posterior.""" + # rv, extra = extrapolation.init(posterior) + # rv, corr = correction.init(rv) + # return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) + def init(self, t, initial_condition) -> _SolverState: posterior, output_scale = initial_condition - state_strategy = self.strategy.init( - t, posterior, correction=self.correction, extrapolation=self.extrapolation - ) + + rv, extra = self.extrapolation.init(posterior) + rv, corr = self.correction.init(rv) + state_strategy = _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) + calib_state = self.calibration.init(output_scale) return _SolverState(strategy=state_strategy, output_scale=calib_state) From bdc898b0aa542b85d319d1801d7685970aeae677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:22:28 +0200 Subject: [PATCH 09/24] Move strategy.begin --- probdiffeq/ivpsolvers.py | 77 ++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 87467c87..763a3609 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,36 +586,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - # - # def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: - # """Initialise a state from a posterior.""" - # rv, extra = extrapolation.init(posterior) - # rv, corr = correction.init(rv) - # return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - - def begin( - self, - state: _StrategyState, - /, - *, - dt, - vector_field, - prior, - extrapolation, - correction, - ): - """Predict the error of an upcoming step.""" - prior_discretized = prior.discretize(dt) - hidden, extra = extrapolation.begin( - state.hidden, state.aux_extra, prior_discretized=prior_discretized - ) - t = state.t + dt - error, observed, corr = correction.estimate_error( - hidden, vector_field=vector_field, t=t - ) - state = _StrategyState(t=t, hidden=hidden, aux_extra=extra, aux_corr=corr) - return error, observed, state - def complete(self, state, /, *, output_scale, extrapolation, correction): """Complete the step after the error has been predicted.""" hidden, extra = extrapolation.complete( @@ -803,13 +773,6 @@ def is_suitable_for_save_at(self): def is_suitable_for_save_every_step(self): return self.extrapolation.is_suitable_for_save_every_step - # - # def init(self, t, posterior, /, *, extrapolation, correction) -> _StrategyState: - # """Initialise a state from a posterior.""" - # rv, extra = extrapolation.init(posterior) - # rv, corr = correction.init(rv) - # return _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) - def init(self, t, initial_condition) -> _SolverState: posterior, output_scale = initial_condition @@ -875,13 +838,19 @@ def solver_mle(inputs, *, ssm): def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) - error, _, state_strategy = strategy.begin( - state.strategy, - dt=dt, - vector_field=vector_field, - extrapolation=extrapolation, - correction=correction, - prior=prior, + + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.strategy.hidden, + state.strategy.aux_extra, + prior_discretized=prior_discretized, + ) + t = state.t + dt + error, _, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t + ) + state_strategy = _StrategyState( + t=t, hidden=hidden, aux_extra=extra, aux_corr=corr ) state_strategy = strategy.complete( @@ -981,14 +950,20 @@ def solver(inputs, /, *, ssm): def step(state: _SolverState, *, vector_field, dt, calibration): del calibration # unused - error, _observed, state_strategy = strategy.begin( - state.strategy, - dt=dt, - vector_field=vector_field, - extrapolation=extrapolation, - correction=correction, - prior=prior, + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.strategy.hidden, + state.strategy.aux_extra, + prior_discretized=prior_discretized, + ) + t = state.t + dt + error, _, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t ) + state_strategy = _StrategyState( + t=t, hidden=hidden, aux_extra=extra, aux_corr=corr + ) + state_strategy = strategy.complete( state_strategy, output_scale=state.output_scale, From ac1b83f4735f1112f8de42d42b2bfdf3619d3b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:31:22 +0200 Subject: [PATCH 10/24] Move strategy.extract --- probdiffeq/ivpsolvers.py | 49 ++++++++++++++++------------------------ 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 763a3609..4ee13e70 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,19 +586,12 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - def complete(self, state, /, *, output_scale, extrapolation, correction): - """Complete the step after the error has been predicted.""" - hidden, extra = extrapolation.complete( - state.hidden, state.aux_extra, output_scale=output_scale - ) - hidden, corr = correction.complete(hidden, state.aux_corr) - return _StrategyState(t=state.t, hidden=hidden, aux_extra=extra, aux_corr=corr) - - def extract(self, state: _StrategyState, /, *, extrapolation, correction): - """Extract the solution from a state.""" - hidden = correction.extract(state.hidden) - sol = extrapolation.extract(hidden, state.aux_extra) - return state.t, sol + # + # def extract(self, state: _StrategyState, /, *, extrapolation, correction): + # """Extract the solution from a state.""" + # hidden = correction.extract(state.hidden) + # sol = extrapolation.extract(hidden, state.aux_extra) + # return state.t, sol def case_interpolate_at_t1( self, state_t1: _StrategyState, *, extrapolation, prior @@ -789,9 +782,10 @@ def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: ) def extract(self, state: _SolverState, /): - t, posterior = self.strategy.extract( - state.strategy, extrapolation=self.extrapolation, correction=self.correction - ) + hidden = self.correction.extract(state.strategy.hidden) + posterior = self.extrapolation.extract(hidden, state.strategy.aux_extra) + t = state.strategy.t + _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) @@ -849,16 +843,14 @@ def step_mle(state, /, *, dt, vector_field, calibration): error, _, corr = correction.estimate_error( hidden, vector_field=vector_field, t=t ) + + hidden, extra = extrapolation.complete( + hidden, extra, output_scale=output_scale_prior + ) + hidden, corr = correction.complete(hidden, corr) state_strategy = _StrategyState( t=t, hidden=hidden, aux_extra=extra, aux_corr=corr ) - - state_strategy = strategy.complete( - state_strategy, - output_scale=output_scale_prior, - extrapolation=extrapolation, - correction=correction, - ) observed = state_strategy.aux_corr # Calibrate @@ -960,16 +952,15 @@ def step(state: _SolverState, *, vector_field, dt, calibration): error, _, corr = correction.estimate_error( hidden, vector_field=vector_field, t=t ) + + hidden, extra = extrapolation.complete( + hidden, extra, output_scale=state.output_scale + ) + hidden, corr = correction.complete(hidden, corr) state_strategy = _StrategyState( t=t, hidden=hidden, aux_extra=extra, aux_corr=corr ) - state_strategy = strategy.complete( - state_strategy, - output_scale=state.output_scale, - extrapolation=extrapolation, - correction=correction, - ) # Extract and return solution state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) return dt * error, state From 8f5828fff5078db5b5234c1f6a418e5e32984c95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:37:12 +0200 Subject: [PATCH 11/24] Migrate interpolate_at_t1 --- probdiffeq/ivpsolvers.py | 50 +++++++++++++++------------------------- 1 file changed, 18 insertions(+), 32 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 4ee13e70..2fceb95c 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,36 +586,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - # - # def extract(self, state: _StrategyState, /, *, extrapolation, correction): - # """Extract the solution from a state.""" - # hidden = correction.extract(state.hidden) - # sol = extrapolation.extract(hidden, state.aux_extra) - # return state.t, sol - - def case_interpolate_at_t1( - self, state_t1: _StrategyState, *, extrapolation, prior - ) -> _InterpRes: - """Process the solution in case t=t_n.""" - tmp = extrapolation.interpolate_at_t1( - state_t1.hidden, state_t1.aux_extra, prior=prior - ) - step_from, solution, interp_from = ( - tmp.step_from, - tmp.interpolated, - tmp.interp_from, - ) - - def _state(x): - t = state_t1.t - corr_like = tree_util.tree_map(np.empty_like, state_t1.aux_corr) - return _StrategyState(t=t, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) - - step_from = _state(step_from) - solution = _state(solution) - interp_from = _state(interp_from) - return _InterpRes(step_from, solution, interp_from) - def case_interpolate( self, t, @@ -807,10 +777,26 @@ def interpolate( return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: - x = self.strategy.case_interpolate_at_t1( - interp_to.strategy, extrapolation=self.extrapolation, prior=self.prior + """Process the solution in case t=t_n.""" + tmp = self.extrapolation.interpolate_at_t1( + interp_to.strategy.hidden, interp_to.strategy.aux_extra, prior=self.prior + ) + step_from_, solution_, interp_from_ = ( + tmp.step_from, + tmp.interpolated, + tmp.interp_from, ) + def _state(s): + t = interp_to.strategy.t + corr_like = tree_util.tree_map(np.empty_like, interp_to.strategy.aux_corr) + return _StrategyState(t=t, hidden=s[0], aux_extra=s[1], aux_corr=corr_like) + + step_from_ = _state(step_from_) + solution_ = _state(solution_) + interp_from_ = _state(interp_from_) + x = _InterpRes(step_from_, solution_, interp_from_) + prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) acc = _SolverState(x.step_from, output_scale=interp_to.output_scale) From 01c921e337fac944ae3900b89d37b95921c0ee21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 10:39:09 +0200 Subject: [PATCH 12/24] Migrate case_interpolate --- probdiffeq/ivpsolvers.py | 70 +++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 2fceb95c..41264e1f 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,40 +586,6 @@ class _StrategyState(containers.NamedTuple): class _Strategy: """Estimation strategy.""" - def case_interpolate( - self, - t, - *, - s0: _StrategyState, - s1: _StrategyState, - output_scale, - extrapolation, - prior, - ) -> _InterpRes: - """Process the solution in case t>t_n.""" - # Interpolate - interp = extrapolation.interpolate( - state_t0=(s0.hidden, s0.aux_extra), - marginal_t1=s1.hidden, - dt0=t - s0.t, - dt1=s1.t - t, - output_scale=output_scale, - prior=prior, - ) - - # Turn outputs into valid states - - def _state(t_, x): - corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr) - return _StrategyState(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) - - step_from = _state(s1.t, interp.step_from) - interpolated = _state(t, interp.interpolated) - interp_from = _state(t, interp.interp_from) - return _InterpRes( - step_from=step_from, interpolated=interpolated, interp_from=interp_from - ) - def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: """Construct a smoother.""" @@ -763,19 +729,41 @@ def interpolate( self, t, *, interp_from: _SolverState, interp_to: _SolverState ) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) - interp = self.strategy.case_interpolate( - t, - s0=interp_from.strategy, - s1=interp_to.strategy, - output_scale=output_scale, - extrapolation=self.extrapolation, - prior=self.prior, + interp = self._case_interpolate( + t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale ) prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) + def _case_interpolate( + self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale + ) -> _InterpRes: + """Process the solution in case t>t_n.""" + # Interpolate + interp = self.extrapolation.interpolate( + state_t0=(s0.hidden, s0.aux_extra), + marginal_t1=s1.hidden, + dt0=t - s0.t, + dt1=s1.t - t, + output_scale=output_scale, + prior=self.prior, + ) + + # Turn outputs into valid states + + def _state(t_, x): + corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr) + return _StrategyState(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) + + step_from = _state(s1.t, interp.step_from) + interpolated = _state(t, interp.interpolated) + interp_from = _state(t, interp.interp_from) + return _InterpRes( + step_from=step_from, interpolated=interpolated, interp_from=interp_from + ) + def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: """Process the solution in case t=t_n.""" tmp = self.extrapolation.interpolate_at_t1( From fa31d3637d95278d7536ee1b7d5a3003e92af3e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:15:03 +0200 Subject: [PATCH 13/24] Update the dynamic solver --- probdiffeq/ivpsolvers.py | 67 +++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 41264e1f..9cdb1d9b 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -575,19 +575,17 @@ def _estimate_error(observed, /, *, ssm): return output_scale * error_estimate_unscaled -class _StrategyState(containers.NamedTuple): - t: Any - hidden: Any - aux_extra: Any - aux_corr: Any - - -@containers.dataclass -class _Strategy: - """Estimation strategy.""" +# msg = ( +# "Next up: " +# "Delete the _Strategy." +# "Then slowly migrate the strategy-state attributes to the solver state. " +# "Then update the signature of strategy_* functions " +# "by removing correction, prior, and so on and fix all tests" +# ) +# raise RuntimeError(msg) -def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: +def strategy_smoother(prior, correction: _Correction, /, ssm): """Construct a smoother.""" extrapolation = _ExtraImplSmoother( name="Smoother", @@ -596,11 +594,11 @@ def strategy_smoother(prior, correction: _Correction, /, ssm) -> _Strategy: is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - strategy = _Strategy() + strategy = None return strategy, correction, extrapolation, prior -def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: +def strategy_fixedpoint(prior, correction: _Correction, /, ssm): """Construct a fixedpoint-smoother.""" extrapolation = _ExtraImplFixedPoint( name="Fixed-point smoother", @@ -609,11 +607,11 @@ def strategy_fixedpoint(prior, correction: _Correction, /, ssm) -> _Strategy: is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, ) - strategy = _Strategy() + strategy = None return strategy, correction, extrapolation, prior -def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: +def strategy_filter(prior, correction: _Correction, /, *, ssm): """Construct a filter.""" extrapolation = _ExtraImplFilter( name="Filter", @@ -622,7 +620,7 @@ def strategy_filter(prior, correction: _Correction, /, *, ssm) -> _Strategy: is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - strategy = _Strategy() + strategy = None return strategy, correction, extrapolation, prior @@ -635,6 +633,13 @@ class _Calibration: extract: Callable +class _StrategyState(containers.NamedTuple): + t: Any + hidden: Any + aux_extra: Any + aux_corr: Any + + class _SolverState(containers.NamedTuple): """Solver state.""" @@ -658,7 +663,6 @@ class _ProbabilisticSolver: extrapolation: _ExtraImpl calibration: _Calibration correction: _Correction - strategy: _Strategy def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale): """Compute offgrid_marginals.""" @@ -842,7 +846,6 @@ def step_mle(state, /, *, dt, vector_field, calibration): step_implementation=step_mle, extrapolation=extrapolation, correction=correction, - strategy=strategy, requires_rescaling=True, ) @@ -870,25 +873,40 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver_dynamic(strategy, *, ssm): +def solver_dynamic(inputs, *, ssm): """Create a solver that calibrates the output scale dynamically.""" + strategy, correction, extrapolation, prior = inputs def step_dynamic(state, /, *, dt, vector_field, calibration): - error, observed, state_strategy = strategy.begin( - state.strategy, dt=dt, vector_field=vector_field + prior_discretized = prior.discretize(dt) + hidden, extra = extrapolation.begin( + state.strategy.hidden, + state.strategy.aux_extra, + prior_discretized=prior_discretized, + ) + t = state.t + dt + error, observed, corr = correction.estimate_error( + hidden, vector_field=vector_field, t=t ) output_scale = calibration.update(state.output_scale, observed=observed) - prior, _calibrated = calibration.extract(output_scale) - state_strategy = strategy.complete(state_strategy, output_scale=prior) + prior_, _calibrated = calibration.extract(output_scale) + hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_) + hidden, corr = correction.complete(hidden, corr) + state_strategy = _StrategyState( + t=t, hidden=hidden, aux_extra=extra, aux_corr=corr + ) # Return solution state = _SolverState(strategy=state_strategy, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( - strategy=strategy, + prior=prior, + ssm=ssm, + extrapolation=extrapolation, + correction=correction, calibration=_calibration_most_recent(ssm=ssm), name="Dynamic probabilistic solver", step_implementation=step_dynamic, @@ -940,7 +958,6 @@ def step(state: _SolverState, *, vector_field, dt, calibration): return dt * error, state return _ProbabilisticSolver( - strategy=strategy, ssm=ssm, prior=prior, extrapolation=extrapolation, From be6bdaaef4746e4643dcfd9d45addfde88d95efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:16:34 +0200 Subject: [PATCH 14/24] Remove the strategy variable --- probdiffeq/ivpsolvers.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 9cdb1d9b..f22127cc 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -594,8 +594,7 @@ def strategy_smoother(prior, correction: _Correction, /, ssm): is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - strategy = None - return strategy, correction, extrapolation, prior + return correction, extrapolation, prior def strategy_fixedpoint(prior, correction: _Correction, /, ssm): @@ -607,8 +606,7 @@ def strategy_fixedpoint(prior, correction: _Correction, /, ssm): is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, ) - strategy = None - return strategy, correction, extrapolation, prior + return correction, extrapolation, prior def strategy_filter(prior, correction: _Correction, /, *, ssm): @@ -620,8 +618,7 @@ def strategy_filter(prior, correction: _Correction, /, *, ssm): is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - strategy = None - return strategy, correction, extrapolation, prior + return correction, extrapolation, prior @containers.dataclass @@ -806,7 +803,7 @@ def solver_mle(inputs, *, ssm): Warning: needs to be combined with a call to stats.calibrate() after solving if the MLE-calibration shall be *used*. """ - strategy, correction, extrapolation, prior = inputs + correction, extrapolation, prior = inputs def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) @@ -875,7 +872,7 @@ def extract(state, /): def solver_dynamic(inputs, *, ssm): """Create a solver that calibrates the output scale dynamically.""" - strategy, correction, extrapolation, prior = inputs + correction, extrapolation, prior = inputs def step_dynamic(state, /, *, dt, vector_field, calibration): prior_discretized = prior.discretize(dt) @@ -929,7 +926,7 @@ def extract(state, /): def solver(inputs, /, *, ssm): """Create a solver that does not calibrate the output scale automatically.""" - strategy, correction, extrapolation, prior = inputs + correction, extrapolation, prior = inputs def step(state: _SolverState, *, vector_field, dt, calibration): del calibration # unused From fdee8c749dfe3d247f289f10eb35b22560682651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:26:46 +0200 Subject: [PATCH 15/24] Move 't' from StrategyState to SolverState --- probdiffeq/ivpsolvers.py | 93 ++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 42 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index f22127cc..f0c9ce3f 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -575,14 +575,12 @@ def _estimate_error(observed, /, *, ssm): return output_scale * error_estimate_unscaled -# msg = ( +# TODO = ( # "Next up: " -# "Delete the _Strategy." # "Then slowly migrate the strategy-state attributes to the solver state. " # "Then update the signature of strategy_* functions " # "by removing correction, prior, and so on and fix all tests" -# ) -# raise RuntimeError(msg) +# ) def strategy_smoother(prior, correction: _Correction, /, ssm): @@ -631,7 +629,6 @@ class _Calibration: class _StrategyState(containers.NamedTuple): - t: Any hidden: Any aux_extra: Any aux_corr: Any @@ -640,13 +637,10 @@ class _StrategyState(containers.NamedTuple): class _SolverState(containers.NamedTuple): """Solver state.""" + t: Any strategy: Any output_scale: Any - @property - def t(self): - return self.strategy.t - @containers.dataclass class _ProbabilisticSolver: @@ -671,7 +665,7 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca rv, extra = self.extrapolation.init(posterior_t0) rv, corr = self.correction.init(rv) - state_t0 = _StrategyState(t=t0, hidden=rv, aux_extra=extra, aux_corr=corr) + state_t0 = _StrategyState(hidden=rv, aux_extra=extra, aux_corr=corr) # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 interp = self.extrapolation.interpolate( @@ -708,10 +702,10 @@ def init(self, t, initial_condition) -> _SolverState: rv, extra = self.extrapolation.init(posterior) rv, corr = self.correction.init(rv) - state_strategy = _StrategyState(t=t, hidden=rv, aux_extra=extra, aux_corr=corr) + state_strategy = _StrategyState(hidden=rv, aux_extra=extra, aux_corr=corr) calib_state = self.calibration.init(output_scale) - return _SolverState(strategy=state_strategy, output_scale=calib_state) + return _SolverState(t=t, strategy=state_strategy, output_scale=calib_state) def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: return self.step_implementation( @@ -721,7 +715,7 @@ def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: def extract(self, state: _SolverState, /): hidden = self.correction.extract(state.strategy.hidden) posterior = self.extrapolation.extract(hidden, state.strategy.aux_extra) - t = state.strategy.t + t = state.t _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) @@ -731,36 +725,49 @@ def interpolate( ) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) interp = self._case_interpolate( - t, s0=interp_from.strategy, s1=interp_to.strategy, output_scale=output_scale + t, + t0=interp_from.t, + t1=interp_to.t, + s0=interp_from.strategy, + s1=interp_to.strategy, + output_scale=output_scale, + ) + prev = _SolverState( + t=t, strategy=interp.interp_from, output_scale=interp_from.output_scale + ) + sol = _SolverState( + t=t, strategy=interp.interpolated, output_scale=interp_to.output_scale + ) + acc = _SolverState( + t=interp_to.t, + strategy=interp.step_from, + output_scale=interp_to.output_scale, ) - prev = _SolverState(interp.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(interp.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(interp.step_from, output_scale=interp_to.output_scale) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def _case_interpolate( - self, t, *, s0: _StrategyState, s1: _StrategyState, output_scale + self, t, *, t0, t1, s0: _StrategyState, s1: _StrategyState, output_scale ) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate interp = self.extrapolation.interpolate( state_t0=(s0.hidden, s0.aux_extra), marginal_t1=s1.hidden, - dt0=t - s0.t, - dt1=s1.t - t, + dt0=t - t0, + dt1=t1 - t, output_scale=output_scale, prior=self.prior, ) # Turn outputs into valid states - def _state(t_, x): + def _state(x): corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr) - return _StrategyState(t=t_, hidden=x[0], aux_extra=x[1], aux_corr=corr_like) + return _StrategyState(hidden=x[0], aux_extra=x[1], aux_corr=corr_like) - step_from = _state(s1.t, interp.step_from) - interpolated = _state(t, interp.interpolated) - interp_from = _state(t, interp.interp_from) + step_from = _state(interp.step_from) + interpolated = _state(interp.interpolated) + interp_from = _state(interp.interp_from) return _InterpRes( step_from=step_from, interpolated=interpolated, interp_from=interp_from ) @@ -777,18 +784,24 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: ) def _state(s): - t = interp_to.strategy.t corr_like = tree_util.tree_map(np.empty_like, interp_to.strategy.aux_corr) - return _StrategyState(t=t, hidden=s[0], aux_extra=s[1], aux_corr=corr_like) + return _StrategyState(hidden=s[0], aux_extra=s[1], aux_corr=corr_like) step_from_ = _state(step_from_) solution_ = _state(solution_) interp_from_ = _state(interp_from_) x = _InterpRes(step_from_, solution_, interp_from_) - prev = _SolverState(x.interp_from, output_scale=interp_from.output_scale) - sol = _SolverState(x.interpolated, output_scale=interp_to.output_scale) - acc = _SolverState(x.step_from, output_scale=interp_to.output_scale) + t = interp_to.t + prev = _SolverState( + t=t, strategy=x.interp_from, output_scale=interp_from.output_scale + ) + sol = _SolverState( + t=t, strategy=x.interpolated, output_scale=interp_to.output_scale + ) + acc = _SolverState( + t=t, strategy=x.step_from, output_scale=interp_to.output_scale + ) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def initial_condition(self): @@ -823,16 +836,14 @@ def step_mle(state, /, *, dt, vector_field, calibration): hidden, extra, output_scale=output_scale_prior ) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState( - t=t, hidden=hidden, aux_extra=extra, aux_corr=corr - ) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) observed = state_strategy.aux_corr # Calibrate output_scale = calibration.update(state.output_scale, observed=observed) # Return - state = _SolverState(strategy=state_strategy, output_scale=output_scale) + state = _SolverState(t=t, strategy=state_strategy, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( @@ -891,12 +902,10 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): prior_, _calibrated = calibration.extract(output_scale) hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState( - t=t, hidden=hidden, aux_extra=extra, aux_corr=corr - ) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) # Return solution - state = _SolverState(strategy=state_strategy, output_scale=output_scale) + state = _SolverState(t=t, strategy=state_strategy, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( @@ -946,12 +955,12 @@ def step(state: _SolverState, *, vector_field, dt, calibration): hidden, extra, output_scale=state.output_scale ) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState( - t=t, hidden=hidden, aux_extra=extra, aux_corr=corr - ) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) # Extract and return solution - state = _SolverState(strategy=state_strategy, output_scale=state.output_scale) + state = _SolverState( + t=t, strategy=state_strategy, output_scale=state.output_scale + ) return dt * error, state return _ProbabilisticSolver( From 0654908a539c1b584cc14a1b71d994dc57b5e49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:30:45 +0200 Subject: [PATCH 16/24] Delete the aux_corr field --- probdiffeq/ivpsolvers.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index f0c9ce3f..7664c46c 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -631,7 +631,6 @@ class _Calibration: class _StrategyState(containers.NamedTuple): hidden: Any aux_extra: Any - aux_corr: Any class _SolverState(containers.NamedTuple): @@ -665,7 +664,7 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca rv, extra = self.extrapolation.init(posterior_t0) rv, corr = self.correction.init(rv) - state_t0 = _StrategyState(hidden=rv, aux_extra=extra, aux_corr=corr) + state_t0 = _StrategyState(hidden=rv, aux_extra=extra) # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 interp = self.extrapolation.interpolate( @@ -702,7 +701,7 @@ def init(self, t, initial_condition) -> _SolverState: rv, extra = self.extrapolation.init(posterior) rv, corr = self.correction.init(rv) - state_strategy = _StrategyState(hidden=rv, aux_extra=extra, aux_corr=corr) + state_strategy = _StrategyState(hidden=rv, aux_extra=extra) calib_state = self.calibration.init(output_scale) return _SolverState(t=t, strategy=state_strategy, output_scale=calib_state) @@ -762,8 +761,7 @@ def _case_interpolate( # Turn outputs into valid states def _state(x): - corr_like = tree_util.tree_map(np.empty_like, s0.aux_corr) - return _StrategyState(hidden=x[0], aux_extra=x[1], aux_corr=corr_like) + return _StrategyState(hidden=x[0], aux_extra=x[1]) step_from = _state(interp.step_from) interpolated = _state(interp.interpolated) @@ -784,8 +782,7 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: ) def _state(s): - corr_like = tree_util.tree_map(np.empty_like, interp_to.strategy.aux_corr) - return _StrategyState(hidden=s[0], aux_extra=s[1], aux_corr=corr_like) + return _StrategyState(hidden=s[0], aux_extra=s[1]) step_from_ = _state(step_from_) solution_ = _state(solution_) @@ -835,9 +832,8 @@ def step_mle(state, /, *, dt, vector_field, calibration): hidden, extra = extrapolation.complete( hidden, extra, output_scale=output_scale_prior ) - hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) - observed = state_strategy.aux_corr + hidden, observed = correction.complete(hidden, corr) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Calibrate output_scale = calibration.update(state.output_scale, observed=observed) @@ -902,7 +898,7 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): prior_, _calibrated = calibration.extract(output_scale) hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Return solution state = _SolverState(t=t, strategy=state_strategy, output_scale=output_scale) @@ -955,7 +951,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): hidden, extra, output_scale=state.output_scale ) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra, aux_corr=corr) + state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Extract and return solution state = _SolverState( From 29f5f28eb3428c876e9b1dbb29c0343ccfa05d86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:51:13 +0200 Subject: [PATCH 17/24] Delete the strategy state --- probdiffeq/ivpsolve.py | 4 +- probdiffeq/ivpsolvers.py | 127 +++++++++++++++++---------------------- 2 files changed, 58 insertions(+), 73 deletions(-) diff --git a/probdiffeq/ivpsolve.py b/probdiffeq/ivpsolve.py index 727d46c8..a401cdca 100644 --- a/probdiffeq/ivpsolve.py +++ b/probdiffeq/ivpsolve.py @@ -205,8 +205,8 @@ def body_fn(state: _RejectionState) -> _RejectionState: dt=self.control.extract(state_control), ) # Normalise the error - u_proposed = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0] - u_step_from = self.ssm.stats.qoi(state_proposed.strategy.hidden)[0] + u_proposed = self.ssm.stats.qoi(state_proposed.hidden)[0] + u_step_from = self.ssm.stats.qoi(state_proposed.hidden)[0] u = np.maximum(np.abs(u_proposed), np.abs(u_step_from)) error_power = _error_scale_and_normalize(error_estimate, u=u) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 7664c46c..6266fcba 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -628,16 +628,12 @@ class _Calibration: extract: Callable -class _StrategyState(containers.NamedTuple): - hidden: Any - aux_extra: Any - - class _SolverState(containers.NamedTuple): """Solver state.""" t: Any - strategy: Any + hidden: Any + aux_extra: Any output_scale: Any @@ -664,11 +660,10 @@ def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_sca rv, extra = self.extrapolation.init(posterior_t0) rv, corr = self.correction.init(rv) - state_t0 = _StrategyState(hidden=rv, aux_extra=extra) # TODO: Replace dt0, dt1, and prior with prior_dt0, and prior_dt1 interp = self.extrapolation.interpolate( - state_t0=(state_t0.hidden, state_t0.aux_extra), + state_t0=(rv, extra), marginal_t1=marginals_t1, dt0=dt0, dt1=dt1, @@ -701,10 +696,9 @@ def init(self, t, initial_condition) -> _SolverState: rv, extra = self.extrapolation.init(posterior) rv, corr = self.correction.init(rv) - state_strategy = _StrategyState(hidden=rv, aux_extra=extra) calib_state = self.calibration.init(output_scale) - return _SolverState(t=t, strategy=state_strategy, output_scale=calib_state) + return _SolverState(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: return self.step_implementation( @@ -712,8 +706,8 @@ def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: ) def extract(self, state: _SolverState, /): - hidden = self.correction.extract(state.strategy.hidden) - posterior = self.extrapolation.extract(hidden, state.strategy.aux_extra) + hidden = self.correction.extract(state.hidden) + posterior = self.extrapolation.extract(hidden, state.aux_extra) t = state.t _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) @@ -723,49 +717,42 @@ def interpolate( self, t, *, interp_from: _SolverState, interp_to: _SolverState ) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) - interp = self._case_interpolate( - t, - t0=interp_from.t, - t1=interp_to.t, - s0=interp_from.strategy, - s1=interp_to.strategy, - output_scale=output_scale, + return self._case_interpolate( + t, s0=interp_from, s1=interp_to, output_scale=output_scale ) - prev = _SolverState( - t=t, strategy=interp.interp_from, output_scale=interp_from.output_scale - ) - sol = _SolverState( - t=t, strategy=interp.interpolated, output_scale=interp_to.output_scale - ) - acc = _SolverState( - t=interp_to.t, - strategy=interp.step_from, - output_scale=interp_to.output_scale, - ) - return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) - - def _case_interpolate( - self, t, *, t0, t1, s0: _StrategyState, s1: _StrategyState, output_scale - ) -> _InterpRes: + # prev = _SolverState( + # t=t, strategy=interp.interp_from, output_scale=interp_from.output_scale + # ) + # sol = _SolverState( + # t=t, strategy=interp.interpolated, output_scale=interp_to.output_scale + # ) + # acc = _SolverState( + # t=interp_to.t, + # strategy=interp.step_from, + # output_scale=interp_to.output_scale, + # ) + # return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) + + def _case_interpolate(self, t, *, s0, s1, output_scale) -> _InterpRes: """Process the solution in case t>t_n.""" # Interpolate interp = self.extrapolation.interpolate( state_t0=(s0.hidden, s0.aux_extra), marginal_t1=s1.hidden, - dt0=t - t0, - dt1=t1 - t, + dt0=t - s0.t, + dt1=s1.t - t, output_scale=output_scale, prior=self.prior, ) # Turn outputs into valid states - def _state(x): - return _StrategyState(hidden=x[0], aux_extra=x[1]) + def _state(t_, x, scale): + return _SolverState(t=t_, hidden=x[0], aux_extra=x[1], output_scale=scale) - step_from = _state(interp.step_from) - interpolated = _state(interp.interpolated) - interp_from = _state(interp.interp_from) + step_from = _state(s1.t, interp.step_from, s1.output_scale) + interpolated = _state(t, interp.interpolated, s1.output_scale) + interp_from = _state(t, interp.interp_from, s0.output_scale) return _InterpRes( step_from=step_from, interpolated=interpolated, interp_from=interp_from ) @@ -773,7 +760,7 @@ def _state(x): def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: """Process the solution in case t=t_n.""" tmp = self.extrapolation.interpolate_at_t1( - interp_to.strategy.hidden, interp_to.strategy.aux_extra, prior=self.prior + interp_to.hidden, interp_to.aux_extra, prior=self.prior ) step_from_, solution_, interp_from_ = ( tmp.step_from, @@ -781,24 +768,25 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: tmp.interp_from, ) - def _state(s): - return _StrategyState(hidden=s[0], aux_extra=s[1]) - - step_from_ = _state(step_from_) - solution_ = _state(solution_) - interp_from_ = _state(interp_from_) - x = _InterpRes(step_from_, solution_, interp_from_) + def _state(t_, s, scale): + return _SolverState(t=t_, hidden=s[0], aux_extra=s[1], output_scale=scale) t = interp_to.t - prev = _SolverState( - t=t, strategy=x.interp_from, output_scale=interp_from.output_scale - ) - sol = _SolverState( - t=t, strategy=x.interpolated, output_scale=interp_to.output_scale - ) - acc = _SolverState( - t=t, strategy=x.step_from, output_scale=interp_to.output_scale - ) + prev = _state(t, interp_from_, interp_from.output_scale) + sol = _state(t, solution_, interp_to.output_scale) + acc = _state(t, step_from_, interp_to.output_scale) + # x = _InterpRes(step_from_, solution_, interp_from_) + # + # t = interp_to.t + # prev = _SolverState( + # t=t, strategy=x.interp_from, output_scale=interp_from.output_scale + # ) + # sol = _SolverState( + # t=t, strategy=x.interpolated, output_scale=interp_to.output_scale + # ) + # acc = _SolverState( + # t=t, strategy=x.step_from, output_scale=interp_to.output_scale + # ) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def initial_condition(self): @@ -820,9 +808,7 @@ def step_mle(state, /, *, dt, vector_field, calibration): prior_discretized = prior.discretize(dt) hidden, extra = extrapolation.begin( - state.strategy.hidden, - state.strategy.aux_extra, - prior_discretized=prior_discretized, + state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt error, _, corr = correction.estimate_error( @@ -833,13 +819,14 @@ def step_mle(state, /, *, dt, vector_field, calibration): hidden, extra, output_scale=output_scale_prior ) hidden, observed = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Calibrate output_scale = calibration.update(state.output_scale, observed=observed) # Return - state = _SolverState(t=t, strategy=state_strategy, output_scale=output_scale) + state = _SolverState( + t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale + ) return dt * error, state return _ProbabilisticSolver( @@ -898,10 +885,11 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): prior_, _calibrated = calibration.extract(output_scale) hidden, extra = extrapolation.complete(hidden, extra, output_scale=prior_) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Return solution - state = _SolverState(t=t, strategy=state_strategy, output_scale=output_scale) + state = _SolverState( + t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale + ) return dt * error, state return _ProbabilisticSolver( @@ -938,9 +926,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): prior_discretized = prior.discretize(dt) hidden, extra = extrapolation.begin( - state.strategy.hidden, - state.strategy.aux_extra, - prior_discretized=prior_discretized, + state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt error, _, corr = correction.estimate_error( @@ -951,11 +937,10 @@ def step(state: _SolverState, *, vector_field, dt, calibration): hidden, extra, output_scale=state.output_scale ) hidden, corr = correction.complete(hidden, corr) - state_strategy = _StrategyState(hidden=hidden, aux_extra=extra) # Extract and return solution state = _SolverState( - t=t, strategy=state_strategy, output_scale=state.output_scale + t=t, hidden=hidden, aux_extra=extra, output_scale=state.output_scale ) return dt * error, state From 559041160f8baaf33a57706feacce47145911acc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 11:52:16 +0200 Subject: [PATCH 18/24] Remove the _StrategyState --- probdiffeq/ivpsolvers.py | 57 +++++++++------------------------------- 1 file changed, 12 insertions(+), 45 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 6266fcba..f50dca60 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -628,7 +628,7 @@ class _Calibration: extract: Callable -class _SolverState(containers.NamedTuple): +class _State(containers.NamedTuple): """Solver state.""" t: Any @@ -691,21 +691,21 @@ def is_suitable_for_save_at(self): def is_suitable_for_save_every_step(self): return self.extrapolation.is_suitable_for_save_every_step - def init(self, t, initial_condition) -> _SolverState: + def init(self, t, initial_condition) -> _State: posterior, output_scale = initial_condition rv, extra = self.extrapolation.init(posterior) rv, corr = self.correction.init(rv) calib_state = self.calibration.init(output_scale) - return _SolverState(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) + return _State(t=t, hidden=rv, aux_extra=extra, output_scale=calib_state) - def step(self, state: _SolverState, *, vector_field, dt) -> _SolverState: + def step(self, state: _State, *, vector_field, dt) -> _State: return self.step_implementation( state, vector_field=vector_field, dt=dt, calibration=self.calibration ) - def extract(self, state: _SolverState, /): + def extract(self, state: _State, /): hidden = self.correction.extract(state.hidden) posterior = self.extrapolation.extract(hidden, state.aux_extra) t = state.t @@ -713,25 +713,11 @@ def extract(self, state: _SolverState, /): _output_scale_prior, output_scale = self.calibration.extract(state.output_scale) return t, (posterior, output_scale) - def interpolate( - self, t, *, interp_from: _SolverState, interp_to: _SolverState - ) -> _InterpRes: + def interpolate(self, t, *, interp_from: _State, interp_to: _State) -> _InterpRes: output_scale, _ = self.calibration.extract(interp_to.output_scale) return self._case_interpolate( t, s0=interp_from, s1=interp_to, output_scale=output_scale ) - # prev = _SolverState( - # t=t, strategy=interp.interp_from, output_scale=interp_from.output_scale - # ) - # sol = _SolverState( - # t=t, strategy=interp.interpolated, output_scale=interp_to.output_scale - # ) - # acc = _SolverState( - # t=interp_to.t, - # strategy=interp.step_from, - # output_scale=interp_to.output_scale, - # ) - # return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def _case_interpolate(self, t, *, s0, s1, output_scale) -> _InterpRes: """Process the solution in case t>t_n.""" @@ -748,7 +734,7 @@ def _case_interpolate(self, t, *, s0, s1, output_scale) -> _InterpRes: # Turn outputs into valid states def _state(t_, x, scale): - return _SolverState(t=t_, hidden=x[0], aux_extra=x[1], output_scale=scale) + return _State(t=t_, hidden=x[0], aux_extra=x[1], output_scale=scale) step_from = _state(s1.t, interp.step_from, s1.output_scale) interpolated = _state(t, interp.interpolated, s1.output_scale) @@ -769,24 +755,12 @@ def interpolate_at_t1(self, *, interp_from, interp_to) -> _InterpRes: ) def _state(t_, s, scale): - return _SolverState(t=t_, hidden=s[0], aux_extra=s[1], output_scale=scale) + return _State(t=t_, hidden=s[0], aux_extra=s[1], output_scale=scale) t = interp_to.t prev = _state(t, interp_from_, interp_from.output_scale) sol = _state(t, solution_, interp_to.output_scale) acc = _state(t, step_from_, interp_to.output_scale) - # x = _InterpRes(step_from_, solution_, interp_from_) - # - # t = interp_to.t - # prev = _SolverState( - # t=t, strategy=x.interp_from, output_scale=interp_from.output_scale - # ) - # sol = _SolverState( - # t=t, strategy=x.interpolated, output_scale=interp_to.output_scale - # ) - # acc = _SolverState( - # t=t, strategy=x.step_from, output_scale=interp_to.output_scale - # ) return _InterpRes(step_from=acc, interpolated=sol, interp_from=prev) def initial_condition(self): @@ -820,13 +794,8 @@ def step_mle(state, /, *, dt, vector_field, calibration): ) hidden, observed = correction.complete(hidden, corr) - # Calibrate output_scale = calibration.update(state.output_scale, observed=observed) - - # Return - state = _SolverState( - t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale - ) + state = _State(t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( @@ -887,9 +856,7 @@ def step_dynamic(state, /, *, dt, vector_field, calibration): hidden, corr = correction.complete(hidden, corr) # Return solution - state = _SolverState( - t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale - ) + state = _State(t=t, hidden=hidden, aux_extra=extra, output_scale=output_scale) return dt * error, state return _ProbabilisticSolver( @@ -921,7 +888,7 @@ def solver(inputs, /, *, ssm): """Create a solver that does not calibrate the output scale automatically.""" correction, extrapolation, prior = inputs - def step(state: _SolverState, *, vector_field, dt, calibration): + def step(state: _State, *, vector_field, dt, calibration): del calibration # unused prior_discretized = prior.discretize(dt) @@ -939,7 +906,7 @@ def step(state: _SolverState, *, vector_field, dt, calibration): hidden, corr = correction.complete(hidden, corr) # Extract and return solution - state = _SolverState( + state = _State( t=t, hidden=hidden, aux_extra=extra, output_scale=state.output_scale ) return dt * error, state From c7d678f4d263d5934747dc71105661792c3c223b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:16:16 +0200 Subject: [PATCH 19/24] Update the dynamic solver --- probdiffeq/ivpsolvers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index f50dca60..451548ca 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -577,7 +577,6 @@ def _estimate_error(observed, /, *, ssm): # TODO = ( # "Next up: " -# "Then slowly migrate the strategy-state attributes to the solver state. " # "Then update the signature of strategy_* functions " # "by removing correction, prior, and so on and fix all tests" # ) @@ -840,9 +839,7 @@ def solver_dynamic(inputs, *, ssm): def step_dynamic(state, /, *, dt, vector_field, calibration): prior_discretized = prior.discretize(dt) hidden, extra = extrapolation.begin( - state.strategy.hidden, - state.strategy.aux_extra, - prior_discretized=prior_discretized, + state.hidden, state.aux_extra, prior_discretized=prior_discretized ) t = state.t + dt error, observed, corr = correction.estimate_error( From b2b67b4bf48cd4dd0923bcccff6311906fb06657 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:27:45 +0200 Subject: [PATCH 20/24] Update test_ivpsolve --- probdiffeq/ivpsolvers.py | 23 +++++-------------- .../test_fixed_grid_vs_save_every_step.py | 4 ++-- .../test_save_at_vs_save_every_step.py | 5 ++-- tests/test_ivpsolve/test_save_every_step.py | 4 ++-- tests/test_ivpsolve/test_solution_object.py | 8 +++---- ...test_terminal_values_vs_save_every_step.py | 4 ++-- 6 files changed, 18 insertions(+), 30 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 451548ca..6ae321a1 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -575,23 +575,15 @@ def _estimate_error(observed, /, *, ssm): return output_scale * error_estimate_unscaled -# TODO = ( -# "Next up: " -# "Then update the signature of strategy_* functions " -# "by removing correction, prior, and so on and fix all tests" -# ) - - -def strategy_smoother(prior, correction: _Correction, /, ssm): +def strategy_smoother(*, ssm): """Construct a smoother.""" - extrapolation = _ExtraImplSmoother( + return _ExtraImplSmoother( name="Smoother", ssm=ssm, is_suitable_for_save_at=False, is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - return correction, extrapolation, prior def strategy_fixedpoint(prior, correction: _Correction, /, ssm): @@ -606,16 +598,15 @@ def strategy_fixedpoint(prior, correction: _Correction, /, ssm): return correction, extrapolation, prior -def strategy_filter(prior, correction: _Correction, /, *, ssm): +def strategy_filter(*, ssm): """Construct a filter.""" - extrapolation = _ExtraImplFilter( + return _ExtraImplFilter( name="Filter", ssm=ssm, is_suitable_for_save_at=True, is_suitable_for_save_every_step=True, is_suitable_for_offgrid_marginals=True, ) - return correction, extrapolation, prior @containers.dataclass @@ -768,13 +759,12 @@ def initial_condition(self): return posterior, self.prior.output_scale -def solver_mle(inputs, *, ssm): +def solver_mle(extrapolation, /, *, correction, prior, ssm): """Create a solver that calibrates the output scale via maximum-likelihood. Warning: needs to be combined with a call to stats.calibrate() after solving if the MLE-calibration shall be *used*. """ - correction, extrapolation, prior = inputs def step_mle(state, /, *, dt, vector_field, calibration): output_scale_prior, _calibrated = calibration.extract(state.output_scale) @@ -881,9 +871,8 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver(inputs, /, *, ssm): +def solver(extrapolation, /, *, correction, prior, ssm): """Create a solver that does not calibrate the output scale automatically.""" - correction, extrapolation, prior = inputs def step(state: _State, *, vector_field, dt, calibration): del calibration # unused diff --git a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py index 49863791..cfb4f390 100644 --- a/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_fixed_grid_vs_save_every_step.py @@ -18,8 +18,8 @@ class Taylor(containers.NamedTuple): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) control = ivpsolve.control_integral(clip=True) # Any clipped controller will do. asolver = ivpsolve.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2, control=control) diff --git a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py index ffc80402..4bf35b35 100644 --- a/tests/test_ivpsolve/test_save_at_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_save_at_vs_save_every_step.py @@ -13,10 +13,9 @@ def test_save_at_result_matches_interpolated_adaptive_result(fact): # Generate a solver tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), u0, num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) - ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolve/test_save_every_step.py b/tests/test_ivpsolve/test_save_every_step.py index 664a2b95..24d4e7e8 100644 --- a/tests/test_ivpsolve/test_save_every_step.py +++ b/tests/test_ivpsolve/test_save_every_step.py @@ -24,8 +24,8 @@ def python_loop_solution(ivp, *, fact, strategy_fun): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) # clip=False because we need to test adaptive-step-interpolation # for smoothers diff --git a/tests/test_ivpsolve/test_solution_object.py b/tests/test_ivpsolve/test_solution_object.py index a7831a5e..b9924349 100644 --- a/tests/test_ivpsolve/test_solution_object.py +++ b/tests/test_ivpsolve/test_solution_object.py @@ -24,8 +24,8 @@ def fixture_approximate_solution(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -88,8 +88,8 @@ def solve(init): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) initcond = solver.initial_condition() diff --git a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py index f3e274b5..98e36064 100644 --- a/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py +++ b/tests/test_ivpsolve/test_terminal_values_vs_save_every_step.py @@ -14,8 +14,8 @@ def test_terminal_values_identical(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) asolver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() From 279d4fa3a925f3fbc3bc43dbc4e34a8d920747f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:34:27 +0200 Subject: [PATCH 21/24] Update test_ivpsolvers --- probdiffeq/ivpsolvers.py | 8 +++----- .../test_calibration_dynamic_vs_mle.py | 4 ++-- .../test_calibration_mle_vs_none.py | 16 ++++++++-------- tests/test_ivpsolvers/test_corrections.py | 4 ++-- .../test_strategy_smoother_vs_filter.py | 8 ++++---- .../test_strategy_smoother_vs_fixedpoint.py | 16 ++++++++-------- ...est_strategy_warnings_for_wrong_strategies.py | 8 ++++---- 7 files changed, 31 insertions(+), 33 deletions(-) diff --git a/probdiffeq/ivpsolvers.py b/probdiffeq/ivpsolvers.py index 6ae321a1..a198d9ea 100644 --- a/probdiffeq/ivpsolvers.py +++ b/probdiffeq/ivpsolvers.py @@ -586,16 +586,15 @@ def strategy_smoother(*, ssm): ) -def strategy_fixedpoint(prior, correction: _Correction, /, ssm): +def strategy_fixedpoint(*, ssm): """Construct a fixedpoint-smoother.""" - extrapolation = _ExtraImplFixedPoint( + return _ExtraImplFixedPoint( name="Fixed-point smoother", ssm=ssm, is_suitable_for_save_at=True, is_suitable_for_save_every_step=False, is_suitable_for_offgrid_marginals=False, ) - return correction, extrapolation, prior def strategy_filter(*, ssm): @@ -822,9 +821,8 @@ def extract(state, /): return _Calibration(init=init, update=update, extract=extract) -def solver_dynamic(inputs, *, ssm): +def solver_dynamic(extrapolation, *, correction, prior, ssm): """Create a solver that calibrates the output scale dynamically.""" - correction, extrapolation, prior = inputs def step_dynamic(state, /, *, dt, vector_field, calibration): prior_discretized = prior.discretize(dt) diff --git a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py index 32167aa8..f03b727f 100644 --- a/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py +++ b/tests/test_ivpsolvers/test_calibration_dynamic_vs_mle.py @@ -16,8 +16,8 @@ def test_exponential_approximated_well(fact): ibm, ssm = ivpsolvers.prior_ibm((*u0, vf(*u0, t=t0)), ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py index ff746ef1..36526039 100644 --- a/tests/test_ivpsolvers/test_calibration_mle_vs_none.py +++ b/tests/test_ivpsolvers/test_calibration_mle_vs_none.py @@ -21,8 +21,8 @@ def case_solve_fixed_grid(fact): kwargs = {"grid": np.linspace(t0, t1, endpoint=True, num=5), "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid(vf, init, solver=solver, **kwargs) @@ -44,8 +44,8 @@ def case_solve_adaptive_save_at(fact): kwargs = {"save_at": save_at, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -69,8 +69,8 @@ def case_solve_adaptive_save_every_step(fact): kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) @@ -94,8 +94,8 @@ def case_simulate_terminal_values(fact): kwargs = {"t0": t0, "t1": t1, "dt0": dt0, "ssm": ssm} def solver_to_solution(solver_fun, strategy_fun): - strategy = strategy_fun(ibm, ts0, ssm=ssm) - solver = solver_fun(strategy, ssm=ssm) + strategy = strategy_fun(ssm=ssm) + solver = solver_fun(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm, atol=1e-2, rtol=1e-2) diff --git a/tests/test_ivpsolvers/test_corrections.py b/tests/test_ivpsolvers/test_corrections.py index ada4ae58..482a3ff7 100644 --- a/tests/test_ivpsolvers/test_corrections.py +++ b/tests/test_ivpsolvers/test_corrections.py @@ -46,8 +46,8 @@ def fixture_solution(correction_impl, fact): except NotImplementedError: testing.skip(reason="This type of linearisation has not been implemented.") - strategy = ivpsolvers.strategy_filter(ibm, corr, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) adaptive_kwargs = {"adaptive_solver": adaptive_solver, "dt0": 0.1, "ssm": ssm} diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py index 30277d50..b2d4dee5 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_filter.py @@ -20,8 +20,8 @@ def fixture_filter_solution(solver_setup): tcoeffs = solver_setup["tcoeffs"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( @@ -34,8 +34,8 @@ def fixture_smoother_solution(solver_setup): tcoeffs = solver_setup["tcoeffs"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=solver_setup["fact"]) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py index 3bcbcc8f..8f7e0274 100644 --- a/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py +++ b/tests/test_ivpsolvers/test_strategy_smoother_vs_fixedpoint.py @@ -22,8 +22,8 @@ def fixture_solution_smoother(solver_setup): tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() @@ -43,8 +43,8 @@ def test_fixedpoint_smoother_equivalent_same_grid(solver_setup, solution_smoothe tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) save_at = solution_smoother.t @@ -69,8 +69,8 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_smoother = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver_smoother = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) # Compute the offgrid-marginals ts = np.linspace(save_at[0], save_at[-1], num=7, endpoint=True) @@ -82,8 +82,8 @@ def test_fixedpoint_smoother_equivalent_different_grid(solver_setup, solution_sm tcoeffs, fact = solver_setup["tcoeffs"], solver_setup["fact"] ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-3, rtol=1e-3, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py index 354eea55..6a776906 100644 --- a/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py +++ b/tests/test_ivpsolvers/test_strategy_warnings_for_wrong_strategies.py @@ -13,8 +13,8 @@ def test_warning_for_fixedpoint_in_save_every_step_mode(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -32,8 +32,8 @@ def test_warning_for_smoother_in_save_at_mode(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() From bc088ee34f31f4d66688599d122e65cb565f2705 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:44:32 +0200 Subject: [PATCH 22/24] Update test_stats --- tests/test_stats/test_log_marginal_likelihood.py | 8 ++++---- .../test_log_marginal_likelihood_terminal_values.py | 4 ++-- tests/test_stats/test_offgrid_marginals.py | 8 ++++---- tests/test_stats/test_sample.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_stats/test_log_marginal_likelihood.py b/tests/test_stats/test_log_marginal_likelihood.py index d1e67c3a..79601a3e 100644 --- a/tests/test_stats/test_log_marginal_likelihood.py +++ b/tests/test_stats/test_log_marginal_likelihood.py @@ -14,8 +14,8 @@ def fixture_solution(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() @@ -96,8 +96,8 @@ def test_raises_error_for_filter(fact): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) grid = np.linspace(t0, t1, num=3) init = solver.initial_condition() diff --git a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py index 8776bb4a..c5cd651d 100644 --- a/tests/test_stats/test_log_marginal_likelihood_terminal_values.py +++ b/tests/test_stats/test_log_marginal_likelihood_terminal_values.py @@ -29,8 +29,8 @@ def fixture_solution(strategy_func, fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = strategy_func(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = strategy_func(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() diff --git a/tests/test_stats/test_offgrid_marginals.py b/tests/test_stats/test_offgrid_marginals.py index 2ebcaf58..baf8a8ad 100644 --- a/tests/test_stats/test_offgrid_marginals.py +++ b/tests/test_stats/test_offgrid_marginals.py @@ -13,8 +13,8 @@ def test_filter_marginals_close_only_to_left_boundary(fact): tcoeffs = (u0, vf(u0, t=t0)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) @@ -37,8 +37,8 @@ def test_smoother_marginals_close_to_both_boundaries(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() grid = np.linspace(t0, t1, endpoint=True, num=5) diff --git a/tests/test_stats/test_sample.py b/tests/test_stats/test_sample.py index fe6c75eb..bc4cab95 100644 --- a/tests/test_stats/test_sample.py +++ b/tests/test_stats/test_sample.py @@ -12,8 +12,8 @@ def fixture_approximation(fact): tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=2) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=fact) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) init = solver.initial_condition() From dd4120452051d211bf73bf4c0d69ea3a10567015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:49:40 +0200 Subject: [PATCH 23/24] Update the benchmarks --- docs/benchmarks/hires/run_hires.py | 4 ++-- docs/benchmarks/lotkavolterra/run_lotkavolterra.py | 5 +++-- docs/benchmarks/pleiades/run_pleiades.py | 6 ++++-- docs/benchmarks/vanderpol/run_vanderpol.py | 6 ++++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/docs/benchmarks/hires/run_hires.py b/docs/benchmarks/hires/run_hires.py index 32b53631..bef46b78 100644 --- a/docs/benchmarks/hires/run_hires.py +++ b/docs/benchmarks/hires/run_hires.py @@ -88,8 +88,8 @@ def param_to_solution(tol): tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") ts1 = ivpsolvers.correction_ts1(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) control = ivpsolve.control_proportional_integral(clip=True) adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py index b381ba4b..029d94e1 100644 --- a/docs/benchmarks/lotkavolterra/run_lotkavolterra.py +++ b/docs/benchmarks/lotkavolterra/run_lotkavolterra.py @@ -81,8 +81,9 @@ def param_to_solution(tol): tcoeffs = taylor.odejet_padded_scan(vf_auto, (u0,), num=num_derivatives) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact=implementation) - strategy = ivpsolvers.strategy_filter(ibm, correction(ssm=ssm), ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + corr = correction(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=corr, ssm=ssm) control = ivpsolve.control_proportional_integral() adaptive_solver = ivpsolve.adaptive( solver, atol=1e-2 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/pleiades/run_pleiades.py b/docs/benchmarks/pleiades/run_pleiades.py index 8d61fd8b..b1b49e25 100644 --- a/docs/benchmarks/pleiades/run_pleiades.py +++ b/docs/benchmarks/pleiades/run_pleiades.py @@ -101,8 +101,10 @@ def param_to_solution(tol): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0_or_ts1 = correction_fun(ssm=ssm, ode_order=2) - strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver_dynamic( + strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm + ) control = ivpsolve.control_proportional_integral() adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm diff --git a/docs/benchmarks/vanderpol/run_vanderpol.py b/docs/benchmarks/vanderpol/run_vanderpol.py index fa5010b2..5390cae0 100644 --- a/docs/benchmarks/vanderpol/run_vanderpol.py +++ b/docs/benchmarks/vanderpol/run_vanderpol.py @@ -80,9 +80,11 @@ def param_to_solution(tol): ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") ts0_or_ts1 = ivpsolvers.correction_ts1(ode_order=2, ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0_or_ts1, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) - solver = ivpsolvers.solver_dynamic(strategy, ssm=ssm) + solver = ivpsolvers.solver_dynamic( + strategy, prior=ibm, correction=ts0_or_ts1, ssm=ssm + ) control = ivpsolve.control_proportional_integral(clip=True) adaptive_solver = ivpsolve.adaptive( solver, atol=1e-3 * tol, rtol=tol, control=control, ssm=ssm From bdef16c09ea70693f56a1af4b3eb48d8412611d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Sat, 26 Oct 2024 12:57:29 +0200 Subject: [PATCH 24/24] Update the examples --- .../use_equinox_bounded_while_loop.py | 4 ++-- docs/examples_parameter_estimation/neural_ode.py | 16 ++++++++-------- .../physics_enhanced_regression_1.py | 4 ++-- .../physics_enhanced_regression_2.py | 8 ++++---- docs/examples_quickstart/easy_example.py | 6 +++--- .../conditioning-on-zero-residual.py | 6 +++--- .../dynamic_output_scales.py | 6 +++--- .../posterior_uncertainties.py | 8 ++++---- .../second_order_problems.py | 8 ++++---- .../taylor_coefficients.py | 4 ++-- 10 files changed, 35 insertions(+), 35 deletions(-) diff --git a/docs/examples_misc/use_equinox_bounded_while_loop.py b/docs/examples_misc/use_equinox_bounded_while_loop.py index e8b4847e..c9f543b8 100644 --- a/docs/examples_misc/use_equinox_bounded_while_loop.py +++ b/docs/examples_misc/use_equinox_bounded_while_loop.py @@ -64,8 +64,8 @@ def vf(y, *, t): # noqa: ARG001 ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ode_order=1, ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_parameter_estimation/neural_ode.py b/docs/examples_parameter_estimation/neural_ode.py index 57e77e85..9de1cc24 100644 --- a/docs/examples_parameter_estimation/neural_ode.py +++ b/docs/examples_parameter_estimation/neural_ode.py @@ -70,8 +70,8 @@ def loss_fn(parameters): tcoeffs = (*initial_values, vf(*initial_values, t=t0, p=parameters)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -128,8 +128,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() # + @@ -168,8 +168,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( @@ -182,8 +182,8 @@ def vf(y, *, t, p): tcoeffs = (u0, vf(u0, t=t0, p=f_args)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver_ts0 = ivpsolvers.solver(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) +solver_ts0 = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver_ts0.initial_condition() sol = ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py index 3002971f..90668dfc 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_1.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_1.py @@ -72,8 +72,8 @@ def solve(p): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_smoother(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() return ivpsolve.solve_fixed_grid( diff --git a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py index 936d36c2..f78541ae 100644 --- a/docs/examples_parameter_estimation/physics_enhanced_regression_2.py +++ b/docs/examples_parameter_estimation/physics_enhanced_regression_2.py @@ -191,8 +191,8 @@ def solve_fixed(theta, *, ts): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) init = solver.initial_condition() sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver, ssm=ssm) return sol[-1] @@ -208,8 +208,8 @@ def solve_adaptive(theta, *, save_at): tcoeffs, output_scale=output_scale, ssm_fact="isotropic" ) ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) - solver = ivpsolvers.solver(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_filter(ssm=ssm) + solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) init = solver.initial_condition() diff --git a/docs/examples_quickstart/easy_example.py b/docs/examples_quickstart/easy_example.py index 658fb797..65dfeb0a 100644 --- a/docs/examples_quickstart/easy_example.py +++ b/docs/examples_quickstart/easy_example.py @@ -55,11 +55,11 @@ def vf(y, *, t): # noqa: ARG001 # Set up a state-space model tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, ssm_fact="dense") +ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm) +strategy = ivpsolvers.strategy_smoother(ssm=ssm) # Build a solver -ts0 = ivpsolvers.correction_ts1(ode_order=1, ssm=ssm) -strategy = ivpsolvers.strategy_smoother(ibm, ts0, ssm=ssm) -solver = ivpsolvers.solver_mle(strategy, ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, ssm=ssm) # - diff --git a/docs/examples_solver_config/conditioning-on-zero-residual.py b/docs/examples_solver_config/conditioning-on-zero-residual.py index 3e00646a..86299d29 100644 --- a/docs/examples_solver_config/conditioning-on-zero-residual.py +++ b/docs/examples_solver_config/conditioning-on-zero-residual.py @@ -78,9 +78,9 @@ def vector_field(y, t): # noqa: ARG001 # Compute the posterior ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense") -slr1 = ivpsolvers.correction_ts1(ssm=ssm) -strategy = ivpsolvers.strategy_fixedpoint(ibm, slr1, ssm=ssm) -solver = ivpsolvers.solver(strategy, ssm=ssm) +ts1 = ivpsolvers.correction_ts1(ssm=ssm) +strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) +solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts1, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-1, rtol=1e-2, ssm=ssm) dt0 = ivpsolve.dt0(lambda y: vector_field(y, t=t0), (u0,)) diff --git a/docs/examples_solver_config/dynamic_output_scales.py b/docs/examples_solver_config/dynamic_output_scales.py index 2cce9d52..a709bab6 100644 --- a/docs/examples_solver_config/dynamic_output_scales.py +++ b/docs/examples_solver_config/dynamic_output_scales.py @@ -72,9 +72,9 @@ def vf(*ys, t): # noqa: ARG001 tcoeffs = (u0, vf(u0, t=t0)) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="dense") ts1 = ivpsolvers.correction_ts1(ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts1, ssm=ssm) -dynamic = ivpsolvers.solver_dynamic(strategy, ssm=ssm) -mle = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +dynamic = ivpsolvers.solver_dynamic(strategy, prior=ibm, correction=ts1, ssm=ssm) +mle = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts1, ssm=ssm) # + t0, t1 = 0.0, 3.0 diff --git a/docs/examples_solver_config/posterior_uncertainties.py b/docs/examples_solver_config/posterior_uncertainties.py index 18706fce..5b02742e 100644 --- a/docs/examples_solver_config/posterior_uncertainties.py +++ b/docs/examples_solver_config/posterior_uncertainties.py @@ -61,7 +61,8 @@ def vf(*ys, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -solver = ivpsolvers.solver_mle(ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm), ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) @@ -115,9 +116,8 @@ def vf(*ys, t): # noqa: ARG001 # + ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -solver = ivpsolvers.solver_mle( - ivpsolvers.strategy_fixedpoint(ibm, ts0, ssm=ssm), ssm=ssm -) +strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) +solver = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver = ivpsolve.adaptive(solver, atol=1e-2, rtol=1e-2, ssm=ssm) ts = jnp.linspace(t0, t0 + 2.0, endpoint=True, num=500) diff --git a/docs/examples_solver_config/second_order_problems.py b/docs/examples_solver_config/second_order_problems.py index 447d76c8..c67c8e89 100644 --- a/docs/examples_solver_config/second_order_problems.py +++ b/docs/examples_solver_config/second_order_problems.py @@ -52,8 +52,8 @@ def vf_1(y, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda y: vf_1(y, t=t0), (u0,), num=4) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) -solver_1st = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver_1st = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_1st = ivpsolve.adaptive(solver_1st, atol=1e-5, rtol=1e-5, ssm=ssm) @@ -86,8 +86,8 @@ def vf_2(y, dy, t): # noqa: ARG001 tcoeffs = taylor.odejet_padded_scan(lambda *ys: vf_2(*ys, t=t0), (u0, du0), num=3) ibm, ssm = ivpsolvers.prior_ibm(tcoeffs, output_scale=1.0, ssm_fact="isotropic") ts0 = ivpsolvers.correction_ts0(ode_order=2, ssm=ssm) -strategy = ivpsolvers.strategy_filter(ibm, ts0, ssm=ssm) -solver_2nd = ivpsolvers.solver_mle(strategy, ssm=ssm) +strategy = ivpsolvers.strategy_filter(ssm=ssm) +solver_2nd = ivpsolvers.solver_mle(strategy, prior=ibm, correction=ts0, ssm=ssm) adaptive_solver_2nd = ivpsolve.adaptive(solver_2nd, atol=1e-5, rtol=1e-5, ssm=ssm) diff --git a/docs/examples_solver_config/taylor_coefficients.py b/docs/examples_solver_config/taylor_coefficients.py index d38befcd..1a323be7 100644 --- a/docs/examples_solver_config/taylor_coefficients.py +++ b/docs/examples_solver_config/taylor_coefficients.py @@ -64,8 +64,8 @@ def solve(tc): """Solve the ODE.""" prior, ssm = ivpsolvers.prior_ibm(tc, ssm_fact="dense") ts0 = ivpsolvers.correction_ts0(ssm=ssm) - strategy = ivpsolvers.strategy_fixedpoint(prior, ts0, ssm=ssm) - solver = ivpsolvers.solver_mle(strategy, ssm=ssm) + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) + solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=ts0, ssm=ssm) init = solver.initial_condition() ts = jnp.linspace(t0, t1, endpoint=True, num=10)