Add documentation page describing PEtab-level training abstractions#64

Open
sebapersson wants to merge 4 commits into main from training_approaches

Conversation

@sebapersson
Collaborator

Training SciML problems can be challenging, and we are currently evaluating several
strategies (curriculum learning, multiple shooting, and curriculum multiple shooting).
This PR adds a docs page that describes how these strategies can be represented at the
PEtab level, which can also be useful for non-SciML PEtab problems.

Longer-term, the goal is for the PEtab SciML library to provide a reference implementation.
This page will help with that, but it is also meant to support tool developers across
ecosystems (e.g., Julia) without requiring a Python dependency.

Member

@dilpath dilpath left a comment

Looks good! Some list formatting issues.

Also some questions out of curiosity but no need to address them now.

Comment on lines +33 to +35
2. Create ``nStages`` PEtab sub-problems by copying the input problem. For
stage ``i``, filter the time-sorted measurement table to keep the first
``n_i`` measurements.
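A minimal sketch of how step 2 could look as a helper (the name ``make_stage_problems`` and the plain-dict rows are illustrative only; real PEtab measurement tables are pandas DataFrames):

```python
# Hypothetical sketch, not part of the PEtab SciML API: build curriculum
# stages by keeping the first n_i rows of the time-sorted measurement table.

def make_stage_problems(measurements, stage_sizes):
    """Return one measurement subset per stage; stage i keeps the first
    stage_sizes[i] measurements after sorting by time."""
    time_sorted = sorted(measurements, key=lambda row: row["time"])
    return [time_sorted[:n_i] for n_i in stage_sizes]

measurements = [
    {"observableId": "obs_a", "time": 10.0, "measurement": 0.7},
    {"observableId": "obs_a", "time": 0.0, "measurement": 0.1},
    {"observableId": "obs_a", "time": 5.0, "measurement": 0.4},
]
stages = make_stage_problems(measurements, stage_sizes=[1, 2, 3])
```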
Member

Sounds good to me, but I guess the difficulty comes from simulation time rather than the number of measurements, due to e.g. blow-up? If so, this could be changed to:

  1. Group measurements by time
  2. Sort all groups by their time
  3. First train with the first group, then with the first and second groups, then with groups 1-3, and so on. Only simulate until the last included group.

Anyway, I guess the current example is better here just to show how a curriculum learning variant can be implemented by adjusting the measurement table.

Should we make helper methods in the Python library for this?
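As a sketch, a helper for the grouped variant (steps 1-3) might look like this (hypothetical name, plain-dict rows instead of a pandas measurement table):

```python
from itertools import groupby

def make_time_group_stages(measurements):
    """Group measurements by time, sort the groups by time, and return
    cumulative stages: stage i contains groups 1..i. During training, each
    stage only needs to be simulated until its last included time point."""
    time_sorted = sorted(measurements, key=lambda row: row["time"])
    groups = [list(g) for _, g in groupby(time_sorted, key=lambda row: row["time"])]
    return [
        [row for group in groups[:i] for row in group]
        for i in range(1, len(groups) + 1)
    ]
```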

Collaborator Author

Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements

Actually, for models with few measurements but many repeats, I have seen that splitting on measurements does well!

Should we make helper methods in the Python library for this?

Yes, but that is a bit down the line :)

Member

Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements

Actually, for models with few measurements but many repeats, I have seen that splitting on measurements does well!

Hm, either the benefit is from (1) the reduced simulation time, or from (2) the subsampling of measurements at each time point/window (or both (1) and (2)).

(1) would be what I described above.
(2) would be that all time points are represented but are subsampled (and increasing the subsample size over iterations).
(1)+(2) would be both of these together.

But the algorithm described here doesn't exactly do (1), (2), or (1)+(2), so it's difficult to understand...

Anyway, good enough to merge for now. Maybe later you could add your future paper DOI as a justification/explanation.

Collaborator Author

I think we are looking at something like (1)+(2), but the number of measurements, as well as the simulation time, likely increases the complexity of the loss-function landscape

Comment on lines +58 to +60
- A window partition ``[t0_i, tf_i]`` for each window ``i = 1..nWindows`` such
that the union of windows covers the full measurement time range, and
``t0_i != tf_i`` for all windows.
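A sketch of validating such a partition, including the stricter condition that comes up in this thread that every window should contain at least one measurement (illustrative helper, not part of the PEtab SciML API):

```python
def check_window_partition(windows, measurement_times):
    """Check a multiple-shooting partition [(t0_1, tf_1), ..., (t0_n, tf_n)]:
    windows must be non-degenerate (t0_i != tf_i), their union must cover the
    full measurement time range without gaps, and each window must contain at
    least one measurement. Hypothetical validator for illustration only."""
    t_min, t_max = min(measurement_times), max(measurement_times)
    ordered = sorted(windows)
    if ordered[0][0] > t_min or ordered[-1][1] < t_max:
        return False  # union does not cover the measurement time range
    for (_, prev_tf), (next_t0, _) in zip(ordered, ordered[1:]):
        if next_t0 > prev_tf:
            return False  # gap between adjacent windows
    for t0, tf in ordered:
        if t0 == tf:
            return False  # degenerate window
        if not any(t0 <= t <= tf for t in measurement_times):
            return False  # window without any measurement
    return True
```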
Member

I guess the condition (currently t0_i != tf_i) should actually be that all measurements are non-initial and there is at least one measurement at tf_i.

Currently it's not clear that each window must contain non-initial measurements. I would do that here in the choice of the windows rather than the generation of the PEtab problem below.

Collaborator Author

Rather, would it not be t0_i != tf_i and at least one measurement in [t0_i, tf_i]?

Member

Hm, theoretically I guess you only need one measurement in total, and that measurement should be the final time of the last time window. Then the multiple shooting still works by propagating information backwards via the continuity penalty -- no measurements needed in the other time windows. In this case you still get the benefit of reduced simulation time in each time window (but probably less efficient optimization).

I guess what is missing from this page is the motivation. It's written that training SciML problems is difficult, but what's missing is what problem each of the methods solves, or the motivation for thinking that a given method could make training easier. If the motivation here is to reduce simulation times (in each time window), then the description of the multiple shooting method can be simplified to reflect this -- no need to write about measurements.

Anyway, maybe I am actually misunderstanding the continuity penalty. The observable formula for the penalty is

observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})

but above that is written

create a condition [...] that assigns each stateId{j} to WINDOW{i}_{expId}_init_stateId{j}

So the continuity penalty is actually w.r.t. to the estimated initial condition, and penalizes the state from moving away from its initial condition? I am not sure how this is a continuity penalty...

value.
3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem:
1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}``
and set the initial time to ``t0_i`` for window ``i = 1..nWindows``.
Member

This is rendered as a new paragraph, probably fixed by adding the newline for the previous comment

Collaborator Author

I really need to improve my rst-skills, thanks for catching the formatting issues!

(i.e., at least one subsequent window contains measurements), assign
initial window values and a continuity penalty:
1. In the parameter table, create parameters
``WINDOW{i}_{expId}_init_stateId{j}`` for each model state
Member

Also rendered as a new paragraph

value.
3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem:
1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}``
and set the initial time to ``t0_i`` for window ``i = 1..nWindows``.
Member

Maybe using LaTeX here instead of code is clearer, i.e. $t_{0,i}$ instead of t_0_i. Or use array syntax like t_0[i].

Member

Overall though, I think using LaTeX instead of code is better here when it makes sense to -- it's a little hard to read otherwise.

Collaborator Author

I agree; for IDs I would go with code, otherwise LaTeX

Comment on lines +171 to +174
for windows ``i > 1`` as in multiple shooting. Note that the penalty
is applied at the initial time point of each window; in PEtab it is
not possible to define a continuity penalty over the full overlap
interval between two windows.
Member

It seems odd to have this note but then take the most trivial choice of completely ignoring the overlap... If penalizing over the full overlap is desirable, at least do more than the initial point?

Also it's technically possible to implement the continuity penalty over the full overlap if done inside the mathematical model. But no need to include that here.

Collaborator Author

It seems odd to have this note but then take the most trivial choice of completely ignoring the overlap... If penalizing over the full overlap is desirable, at least do more than the initial point?

I cannot see how penalizing over the entire interval is possible (except by altering the model structure)? I mainly wanted to put the note here because the overlap penalty is not possible without altering the model, and thus this is the approach I think we should go with (benchmarks show that penalizing only the first point is quite sufficient)

Member

Agreed, I meant it's technically possible by altering the mathematical model, and therefore it's possible in PEtab since the mathematical model is a part of PEtab. But I see your point, outside the mathematical model then there's no nice way to do it. Practically you can do it by introducing a huge number of dummy measurement points.

benchmarks show that only penalizing the first point is quite sufficient

Ah right, then I think it's better to remove

in PEtab it is not possible to define a continuity penalty over the full overlap interval between two windows.

and replace with e.g.

, which in our experience is already very helpful

Comment on lines +176 to +178
3. The final stage corresponds to the original PEtab problem. Use the parameter
estimate from stage ``nStages-1`` to initialize optimization for the final
stage.
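The staged warm-starting described here could be sketched as a simple loop (all names are hypothetical; `fit` stands in for whatever optimizer a tool uses):

```python
def run_staged_training(stage_problems, fit, x0):
    """Train the stage problems in order, initializing each stage's
    optimization with the estimate from the previous stage; the last
    problem in the list is the original PEtab problem."""
    x = x0
    for problem in stage_problems:
        x = fit(problem, x)  # fit(problem, x_init) -> x_hat (hypothetical)
    return x
```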
Member

Does it provide nice training properties to go from duplicating most measurements in the second-last stage, to having no duplicates in the last stage? This seems like it could make training worse, rather than constructing the windows at each stage such that there is no overlap (or weighting the duplicate measurements such that there is effectively no overlap from the perspective of the objective function...).

Collaborator Author

This really does improve training, so even though it is a bit strange, it does help!

Member

Alright now I'm very excited for your paper. I would be surprised if it can be justified, but I won't argue with results 😀

Comment on lines +182 to +183
support transferring parameter values between stages, it can be beneficial to
provide a utility function for mapping parameters between stage problems.
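One way such a mapping utility could look, assuming parameter vectors keyed by parameter ID (names are hypothetical, not an existing API):

```python
def transfer_parameters(source_estimate, target_parameter_ids, nominal_values):
    """Map an estimate from one stage onto the parameters of the next stage
    by shared ID; parameters that do not exist in the source stage (e.g.
    window-initial parameters present only in the MS problem) fall back to
    their nominal values. Illustrative helper only."""
    return {
        pid: source_estimate.get(pid, nominal_values[pid])
        for pid in target_parameter_ids
    }
```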
Member

TODO python package?

Collaborator Author

Indeed :)

@sebapersson sebapersson added the draft The PR is work in progress label Feb 19, 2026
Collaborator

@BSnelling BSnelling left a comment

This looks really helpful!

On the whole I'm not clear yet on where the line is between helper functions we would implement in this repo and helpers that importers should implement to support these training strategies. The last section implies importers would implement splitting for curriculum learning and automatic window construction but could those not also make sense as petab_sciml helpers?

stage ``i``, filter the time-sorted measurement table to keep the first
``n_i`` measurements.
3. Optionally filter the condition, observable and experiment tables to only
include entries required by the measurement table for each sub-problem.
Collaborator

I'm questioning whether this step should be optional. Would we expect all importers to handle a case where, for example, a condition specifies a time beyond what is in the measurements table?

Collaborator Author

I think all importers should handle it, but the filtering is nice because I imagine some importers (like PEtab.jl) will throw warnings


- ``observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})``
- ``noiseFormula = 1.0``
- ``noiseDistribution = normal``
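These rows could also be generated programmatically; a sketch using plain dicts (real observable tables are pandas DataFrames, and the helper name is hypothetical):

```python
def continuity_observable_rows(window, exp_id, state_ids):
    """Build one continuity-penalty observable row per model state for a
    given window, following the formulas quoted above."""
    rows = []
    for state_id in state_ids:
        init_par = f"WINDOW{window}_{exp_id}_init_{state_id}"
        rows.append({
            "observableId": f"penalty_{init_par}",
            "observableFormula": f"sqrt(lambda) * ({state_id} - {init_par})",
            "noiseFormula": "1.0",
            "noiseDistribution": "normal",
        })
    return rows
```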
Collaborator

You could render this as columns in a table, it might be clearer.

split into windows that are fitted jointly. Each window has its own estimated
initial state values, and a continuity penalty is introduced to encourage a
continuous trajectory between adjacent windows. This can be implemented at the
PEtab level as follows:
Collaborator

The explanation is good, but it's a hard concept to convey in words. Is there a schematic you could add here to explain multiple shooting?

Also, before going into the details of implementing this in PEtab, you could summarise how the general concept maps onto PEtab. Something like: "each shooting window will correspond to an experiment in PEtab".

Collaborator Author

Also, before going into the details of implementing this in PEtab, you could summarise how the general concept maps onto PEtab. Something like: "each shooting window will correspond to an experiment in PEtab".

Good idea!

The explanation is good, but it's a hard concept to convey in words. Is there a schematic you could add here to explain multiple shooting?

I plan to also add an image later, as I fully agree

since a single penalty weight may be impossible to tune. In this case, a
log-scale penalty such as

``sqrt(lambda) * (log(abs(stateId{j})) - log(WINDOW{i}_{expId}_init_stateId{j}))``
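For illustration, the log-scale variant as a formula builder (hypothetical helper; note the quoted formula does not wrap the window-initial parameter in abs, so it assumes that parameter is positive):

```python
def log_scale_penalty_formula(window, exp_id, state_id):
    """Return the log-scale continuity-penalty formula quoted above for one
    state; assumes the estimated WINDOW{i}_{expId}_init_* value is positive
    so that its log is defined."""
    init_par = f"WINDOW{window}_{exp_id}_init_{state_id}"
    return f"sqrt(lambda) * (log(abs({state_id})) - log({init_par}))"
```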
Collaborator

Does this definition go somewhere in the PEtab problem or is it up to the implementation?

Collaborator Author

This is something Python library util functions could support

@sebapersson
Collaborator Author

On the whole I'm not clear yet on where the line is between helper functions we would implement in this repo and helpers that importers should implement to support these training strategies. The last section implies importers would implement splitting for curriculum learning and automatic window construction but could those not also make sense as petab_sciml helpers?

It would definitely make sense; exactly which helpers would be needed is something to discuss a bit down the line :)

Co-authored-by: Dilan Pathirana <59329744+dilpath@users.noreply.github.com>
Co-authored-by: BSnelling <branwen.snelling@crick.ac.uk>