6 changes: 6 additions & 0 deletions doc/index.rst
@@ -28,6 +28,12 @@ PEtab SciML - Scientific Machine Learning Format and Tooling
layers
tutorial

.. toctree::
:caption: Tool developer guide
:maxdepth: 3

training_approaches

.. toctree::
:caption: Python package
:maxdepth: 3
212 changes: 212 additions & 0 deletions doc/training_approaches.rst
@@ -0,0 +1,212 @@
SciML training strategies at the PEtab level
============================================

Training (i.e., estimating the parameters of) SciML models can be challenging:
standard ML training workflows (e.g., training with Adam for a fixed number of
epochs) often fail to find a good minimum or require many training epochs.

Several training strategies have been developed to address this. These include
curriculum learning, multiple shooting, and combined curriculum multiple
shooting, all of which can be implemented at the PEtab abstraction level for
ODE models as well as hybrid PEtab SciML problems. This page describes these
PEtab-level abstractions for tool developers. The PEtab SciML library also
provides reference implementations.

Curriculum learning
-------------------

Curriculum learning is a training strategy where the training problem is
made progressively harder over successive curriculum stages. For PEtab
problems, a curriculum can be defined by gradually increasing the number of
measurement time points (and typically the simulation end time) over a fixed
number of stages. This can be implemented at the PEtab level as follows:

Inputs:

- A PEtab problem (PEtab v1 or v2).
- The number of curriculum stages, ``nStages``.
- A schedule ``n_i`` specifying how many measurements are included in stage
``i``.

1. Sort the measurement table in the input PEtab problem by the ``time``
column.
2. Create ``nStages`` PEtab sub-problems by copying the input problem. For
stage ``i``, filter the time-sorted measurement table to keep the first
``n_i`` measurements.
3. Optionally, filter the condition, observable, and experiment tables to
   include only entries required by each sub-problem's measurement table.

A practical consideration for tools implementing and/or importing curriculum
problems is to keep parameter ordering consistent across stages, which
simplifies transferring parameters between stages.
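The two-step construction above can be sketched with pandas; the
``build_stages`` helper and the example schedule are illustrative, not part of
the PEtab SciML API:

```python
import pandas as pd

def build_stages(measurement_df: pd.DataFrame, schedule: list[int]) -> list[pd.DataFrame]:
    """Create one measurement table per curriculum stage.

    ``schedule[i]`` is the number of measurements (``n_i``) kept in stage ``i``.
    """
    # Step 1: sort measurements by the `time` column (a stable sort keeps
    # the original order of measurements sharing a time point).
    sorted_df = measurement_df.sort_values("time", kind="stable").reset_index(drop=True)
    # Step 2: for each stage, keep the first n_i time-sorted measurements.
    return [sorted_df.head(n_i).copy() for n_i in schedule]

measurements = pd.DataFrame({
    "observableId": ["obs_a"] * 6,
    "time": [0.0, 1.0, 2.0, 4.0, 8.0, 16.0],
    "measurement": [0.1, 0.4, 0.9, 1.5, 1.9, 2.0],
})
stages = build_stages(measurements, schedule=[2, 4, 6])
print([len(df) for df in stages])  # → [2, 4, 6]
```

The final stage includes all measurements, so the last sub-problem coincides
with the original problem.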

.. _multiple_shooting:

Multiple shooting
-----------------

In multiple shooting, the simulation time span of each PEtab experiment is
split into windows that are fitted jointly. Each window has its own estimated
initial state values, and a continuity penalty encourages a continuous
trajectory between adjacent windows. In PEtab terms, each shooting window
corresponds to a separate PEtab experiment. This can be implemented at the
PEtab level as follows:

Inputs:

- A PEtab problem (PEtab v2).
- The number of multiple-shooting windows, ``nWindows``.
- A window partition ``[t0_i, tf_i]`` for each window ``i = 1..nWindows`` such
  that the union of windows covers the full measurement time range,
  ``t0_i != tf_i`` for all windows, and each window contains at least one
  measurement.
- A continuity penalty parameter ``lambda``.

1. Copy the input PEtab problem to create a multiple shooting (MS) PEtab
problem.
2. In the MS PEtab problem, add the penalty weight parameter ``lambda`` to the
parameter table as a non-estimated parameter and set an appropriate nominal
value.
3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem:

1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}``
and set the initial time to ``t0_i`` for window ``i = 1..nWindows``.
2. In the experiment table, remove the original experiment IDs and keep
only the windowed experiments. Assign each PEtab condition to the
corresponding window experiment. If a PEtab condition occurs at a
time point that lies in the overlap of windows ``i-1`` and ``i``, assign
the condition to experiment ``WINDOW{i-1}_{expId}``.
3. In the measurement table, assign all measurements in the time interval
``[t0_i, tf_i]`` for experiment ``expId`` to experiment
``WINDOW{i}_{expId}``. If MS windows overlap at time points that contain
measurements, duplicate those measurements so they appear in each
relevant window.
4. For each window ``i > 1`` such that there exists at least one
measurement for ``expId`` at time ``t >= t0_i`` in the original problem
(i.e., at least one subsequent window contains measurements), assign
initial window values and a continuity penalty:

1. In the parameter table, create parameters
``WINDOW{i}_{expId}_init_stateId{j}`` for each model state
``stateId{j}``. Mark them as estimated and choose appropriate bounds.
2. In the condition table, create a condition with ID
``WINDOW{i}_{expId}_condition0`` that assigns each ``stateId{j}`` to
``WINDOW{i}_{expId}_init_stateId{j}``.
3. Assign condition ``WINDOW{i}_{expId}_condition0`` as the initial
condition for experiment ``WINDOW{i}_{expId}`` at time ``t0_i``.
4. In the observable table, create an observable with ID
``WINDOW{i}_{expId}_penalty_stateId{j}`` for each model state
``stateId{j}`` and set

- ``observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})``
- ``noiseFormula = 1.0``
- ``noiseDistribution = normal``

5. In the measurement table, add a row for experiment
``WINDOW{i}_{expId}`` and observable
``WINDOW{i}_{expId}_penalty_stateId{j}`` at time ``t0_i`` with
``measurement = 0.0``. This yields an L2 (quadratic) penalty.
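For a single window, steps 4.1–4.5 can be sketched as follows. The column
names follow PEtab v2 conventions, but the helper, its bounds, and the
shortened ``_init_``/``_penalty_`` ID pattern are illustrative assumptions,
not a released API:

```python
import pandas as pd

def continuity_penalty_rows(window: int, exp_id: str, state_ids: list[str], t0: float):
    """Build the PEtab rows implementing the continuity penalty for one window."""
    prefix = f"WINDOW{window}_{exp_id}"
    # 4.1: one estimated initial-value parameter per model state
    # (bounds here are placeholders; choose problem-appropriate ones).
    parameters = pd.DataFrame({
        "parameterId": [f"{prefix}_init_{s}" for s in state_ids],
        "estimate": True, "lowerBound": 1e-6, "upperBound": 1e6,
    })
    # 4.2: a condition assigning each state to its window-initial parameter.
    condition = pd.DataFrame({
        "conditionId": f"{prefix}_condition0",
        "targetId": state_ids,
        "targetValue": [f"{prefix}_init_{s}" for s in state_ids],
    })
    # 4.4: penalty observables; sqrt(lambda) scales the residual so the
    # squared residual in the normal log-likelihood is weighted by lambda.
    observables = pd.DataFrame({
        "observableId": [f"{prefix}_penalty_{s}" for s in state_ids],
        "observableFormula": [f"sqrt(lambda) * ({s} - {prefix}_init_{s})" for s in state_ids],
        "noiseFormula": 1.0, "noiseDistribution": "normal",
    })
    # 4.5: zero-valued measurements at t0 yield the quadratic (L2) penalty.
    measurements = pd.DataFrame({
        "observableId": [f"{prefix}_penalty_{s}" for s in state_ids],
        "experimentId": prefix, "time": t0, "measurement": 0.0,
    })
    return parameters, condition, observables, measurements
```

The returned frames would be concatenated onto the corresponding tables of the
MS PEtab problem (step 4.3, registering the condition as the experiment's
initial condition, is left to the caller).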

Naive multiple shooting can perform poorly when states have different scales,
since a single penalty weight may then be difficult or impossible to tune. In
this case, a log-scale penalty such as

``sqrt(lambda) * (log(abs(stateId{j})) - log(WINDOW{i}_{expId}_init_stateId{j}))``
can be effective, where ``abs`` avoids potential problems with states going
slightly below zero due to numerical errors.

From a runtime performance perspective, the number of initial-window
parameters scales with the number of windows, states, and PEtab experiments,
which can be impractical for larger problems. Moreover, since initial-window
parameters must be estimated, this approach typically performs poorly for
partially observed systems; this is addressed by the curriculum multiple
shooting approach.

Curriculum multiple shooting
----------------------------

Curriculum multiple shooting (CL+MS) combines multiple shooting with a
curriculum schedule. The idea is to start from a multiple-shooting formulation,
which is often easier to train, and then progressively reduce the number of
windows until the original (single-window) problem is recovered. This makes the
approach less sensitive to continuity-penalty tuning and ensures the final
parameters optimize the objective of the original PEtab problem.

Practically, CL+MS defines ``nStages`` curriculum stages. Stage 1 corresponds
to a multiple-shooting problem with ``nWindows = nStages`` windows. In each
subsequent stage, the first ``nWindows-1`` windows are expanded to cover the
union of two adjacent windows, and the last window is dropped. This reduces
the number of windows by one per stage while increasing the time span covered
by each remaining window. The final stage has a single window and corresponds
to the original problem. This can be implemented at the PEtab level as follows:

Inputs:

- A PEtab problem (PEtab v2).
- The number of curriculum stages, ``nStages``.
- An initial window partition ``[t0_i, tf_i]`` for stage 1 with
  ``i = 1..nStages``, such that the union of windows covers the full
  measurement time range, ``t0_i != tf_i`` for all windows, and each window
  contains at least one measurement.
- A continuity penalty parameter ``lambda`` (used in the multiple-shooting
stages).

1. Construct stage 1 as a multiple-shooting (MS) PEtab problem with
``nWindows = nStages`` using the procedure in
:ref:`Multiple shooting <multiple_shooting>`.
2. For curriculum stage ``k = 2..(nStages-1)``:

1. Set the number of windows to ``nWindows = nStages - k + 1``.
2. Define the MS window time spans for stage ``k`` by merging adjacent
windows from the previous stage:

- For ``i = 1..nWindows`` set ``t0_i^{(k)} = t0_i^{(k-1)}`` and
``tf_i^{(k)} = tf_{i+1}^{(k-1)}``.
- Drop the last window of stage ``k-1``.

3. Create the PEtab problem for stage ``k`` by applying the
:ref:`Multiple shooting <multiple_shooting>` construction with the
updated window partition. In particular:

- Update the experiment table to contain only experiments
``WINDOW{i}_{expId}`` for ``i = 1..nWindows``.
   - Reassign and/or duplicate measurements to match
     ``[t0_i^{(k)}, tf_i^{(k)}]``. Measurements from the original problem
     that now fall into multiple windows must be duplicated so they appear
     in each window.
   - Include window-initial parameters and continuity-penalty observables
     for windows ``i > 1`` as in multiple shooting. Note that the penalty
     is applied only at the initial time point of each window, not over the
     full overlap interval; in our experience this is already sufficient in
     practice.

3. The final stage corresponds to the original PEtab problem. Use the parameter
estimate from stage ``nStages-1`` to initialize optimization for the final
stage.
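The window-merging rule in step 2 can be sketched in plain Python
(``merge_windows`` is an illustrative helper):

```python
def merge_windows(windows: list[tuple[float, float]]) -> list[tuple[float, float]]:
    """One CL+MS curriculum step: expand each window to cover the union with
    its right neighbour and drop the last window, reducing the count by one."""
    # t0_i stays, tf_i is taken from the next window: [t0_i, tf_{i+1}]
    return [(windows[i][0], windows[i + 1][1]) for i in range(len(windows) - 1)]

stage1 = [(0.0, 2.0), (2.0, 4.0), (4.0, 8.0)]
stage2 = merge_windows(stage1)  # → [(0.0, 4.0), (2.0, 8.0)]
stage3 = merge_windows(stage2)  # → [(0.0, 8.0)], i.e. the original problem
```

Note that after the first merge the windows overlap, which is why measurements
in the overlap must be duplicated across windows.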

A practical consideration for tools implementing and/or importing CL+MS is that
the number of window-initial parameters to estimate changes between stages. To
support transferring parameter values between stages, it can be beneficial to
provide a utility function for mapping parameters between stage problems.
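Such a mapping utility can be sketched as follows (illustrative; a production
helper would likely also derive sensible start values for the window-initial
parameters that change between stages):

```python
def transfer_parameters(
    prev_values: dict[str, float],
    next_param_ids: list[str],
    default: float = 1.0,
) -> dict[str, float]:
    """Carry estimated values over to the next stage's parameter vector.

    Parameters shared between stages (mechanistic and NN parameters, and
    window-initial parameters whose IDs persist) keep their estimates;
    parameters only present in the next stage get a default start value,
    and parameters dropped with the last window are simply left out.
    """
    return {pid: prev_values.get(pid, default) for pid in next_param_ids}

prev = {"k1": 0.3, "WINDOW2_exp1_init_A": 1.7, "WINDOW3_exp1_init_A": 0.9}
# next stage has one window fewer, so WINDOW3_* parameters disappear
nxt = transfer_parameters(prev, ["k1", "WINDOW2_exp1_init_A"])
print(nxt)  # → {'k1': 0.3, 'WINDOW2_exp1_init_A': 1.7}
```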

Partitioning measurements and time windows
------------------------------------------

The above training approaches require either splitting measurements into
curriculum stages (curriculum learning) or partitioning the simulation time
span into windows (multiple shooting and curriculum multiple shooting). We
recommend that tools supporting these methods provide the splitting schemes
outlined below.

For curriculum learning, the number of measurements per stage, ``n_i``, can be
chosen in two ways: (i) split by unique measurement time points and allocate
``n_i`` accordingly, or (ii) split by the total number of measurements, which
can be effective when there are few unique time points but many repeated
measurements. We recommend supporting both modes, as well as automatic
splitting (e.g., given ``nStages``, compute ``n_i`` for the user) and
user-defined schedules (e.g., explicit ``n_i`` per stage or a maximum time
point per stage).
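Both schedule modes can be sketched as follows (illustrative helpers; the
function names and even allocation rules are assumptions):

```python
def schedule_by_unique_times(times: list[float], n_stages: int) -> list[int]:
    """Mode (i): stage i keeps all measurements up to a cutoff chosen so that
    the unique time points are spread roughly evenly over the stages."""
    unique = sorted(set(times))
    per_stage = -(-len(unique) // n_stages)  # ceiling division
    n_i = []
    for stage in range(1, n_stages + 1):
        cutoff = unique[min(stage * per_stage, len(unique)) - 1]
        # n_i counts every measurement (including repeats) up to the cutoff
        n_i.append(sum(t <= cutoff for t in times))
    return n_i

def schedule_by_total(n_measurements: int, n_stages: int) -> list[int]:
    """Mode (ii): split by the total number of measurements."""
    return [round(stage * n_measurements / n_stages) for stage in range(1, n_stages + 1)]

# two repeats at each of four unique time points
print(schedule_by_unique_times([0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 4.0, 4.0], 2))  # → [4, 8]
print(schedule_by_total(7, 3))  # → [2, 5, 7]
```

In both modes the last stage includes every measurement, so the final
curriculum stage coincides with the original problem.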

For multiple shooting, window intervals ``[t0_i, tf_i]`` must be defined. We
recommend supporting automatic window construction (e.g., take ``nWindows`` as
input and allocate windows based on unique measurement time points) as well as
user-specified intervals. As a basic sense check, tools should ensure that
each window contains at least one measurement.
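Automatic window construction based on unique measurement time points can be
sketched as follows (illustrative; placing boundaries at evenly spaced unique
time points is one of several reasonable choices):

```python
def windows_from_times(times: list[float], n_windows: int) -> list[tuple[float, float]]:
    """Partition the measurement time range into contiguous windows whose
    boundaries sit on unique measurement time points, so every window
    contains at least one (non-initial) measurement at its end time."""
    unique = sorted(set(times))
    if len(unique) < n_windows + 1:
        raise ValueError("too few unique time points for the requested number of windows")
    # boundary indices spread roughly evenly over the unique time points;
    # idx is strictly increasing because there are >= n_windows + 1 points
    idx = [round(i * (len(unique) - 1) / n_windows) for i in range(n_windows + 1)]
    return [(unique[idx[i]], unique[idx[i + 1]]) for i in range(n_windows)]

windows = windows_from_times([0.0, 1.0, 2.0, 4.0, 8.0, 16.0], 3)
print(windows)  # → [(0.0, 2.0), (2.0, 4.0), (4.0, 16.0)]
```

By construction the union of windows covers the full measurement time range,
satisfying the multiple-shooting input requirements above.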