Add documentation page describing PEtab-level training abstractions #64
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| SciML training strategies at the PEtab level | ||
| ============================================ | ||
|
|
||
| Training SciML models (i.e., estimating their parameters) can be challenging: | ||
| standard ML training workflows (e.g., training with Adam for a fixed number of | ||
| epochs) often fail to find a good minimum or require many training epochs. | ||
|
|
||
| Several training strategies have been developed to address this. These include | ||
| curriculum learning, multiple shooting, and combined curriculum multiple | ||
| shooting, all of which can be implemented at the PEtab abstraction level for | ||
| ODE models as well as hybrid PEtab SciML problems. This page describes these | ||
| PEtab-level abstractions for tool developers. The PEtab SciML library also | ||
| provides reference implementations. | ||
|
|
||
| Curriculum learning | ||
| ------------------- | ||
|
|
||
| Curriculum learning is a training strategy where the training problem is | ||
| made progressively harder over successive curriculum stages. For PEtab | ||
| problems, a curriculum can be defined by gradually increasing the number of | ||
| measurement time points (and typically the simulation end time) over a fixed | ||
| number of stages. This can be implemented at the PEtab level as follows: | ||
|
|
||
| Inputs: | ||
|
|
||
| - A PEtab problem (PEtab v1 or v2). | ||
| - The number of curriculum stages, ``nStages``. | ||
| - A schedule ``n_i`` specifying how many measurements are included in stage | ||
| ``i``. | ||
|
|
||
| 1. Sort the measurement table in the input PEtab problem by the ``time`` | ||
| column. | ||
| 2. Create ``nStages`` PEtab sub-problems by copying the input problem. For | ||
| stage ``i``, filter the time-sorted measurement table to keep the first | ||
| ``n_i`` measurements. | ||
| 3. Optionally filter the condition, observable and experiment tables to only | ||
| include entries required by the measurement table for each sub-problem. | ||
|
Collaborator
I'm questioning whether this step should be optional. Would we expect all importers to handle a case where, for example, a condition specifies a time beyond what is in the measurements table?
Collaborator
Author
I think all importers should handle it, but the filtering is nice because I imagine some importers (like PEtab.jl) will throw warnings |
||
|
|
||
| A practical consideration for tools implementing and/or importing curriculum | ||
| problems is to keep parameter ordering consistent across stages, which | ||
| simplifies transferring parameters between stages. | ||
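
The curriculum steps above can be sketched as follows. This is a minimal illustration, assuming the measurement table is held as a list of dict rows with a numeric ``time`` entry; the function name ``curriculum_measurement_tables`` and the example rows are hypothetical and not part of any PEtab library. The optional filtering of the other tables (step 3) is not shown.

```python
from operator import itemgetter


def curriculum_measurement_tables(measurements, schedule):
    """Return one measurement table per curriculum stage.

    measurements: list of dict rows, each with a numeric "time" entry.
    schedule: n_i per stage, i.e. how many time-sorted rows stage i keeps.
    """
    # Step 1: sort by time; sorted() is stable, so ties keep their input order.
    ordered = sorted(measurements, key=itemgetter("time"))
    # Step 2: stage i keeps the first n_i rows (copied so stages are independent).
    return [[dict(row) for row in ordered[:n_i]] for n_i in schedule]


# Hypothetical example data with two curriculum stages.
rows = [
    {"observableId": "obs_a", "time": 2.0, "measurement": 0.7},
    {"observableId": "obs_a", "time": 0.0, "measurement": 0.1},
    {"observableId": "obs_a", "time": 1.0, "measurement": 0.4},
    {"observableId": "obs_a", "time": 3.0, "measurement": 0.9},
]
stages = curriculum_measurement_tables(rows, schedule=[2, 4])
```

Because every stage is cut from the same time-sorted table, each stage's measurements are a prefix of the next stage's.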
|
|
||
| .. _multiple_shooting: | ||
|
|
||
| Multiple shooting | ||
| ----------------- | ||
|
|
||
| In multiple shooting, the simulation time span of each PEtab experiment is | ||
| split into windows that are fitted jointly. Each window has its own estimated | ||
| initial state values, and a continuity penalty is introduced to encourage a | ||
| continuous trajectory between adjacent windows. This can be implemented at the | ||
| PEtab level as follows: | ||
|
Collaborator
The explanation is good, but it's a hard concept to convey in words. Is there a schematic you could add here to explain multiple shooting? Also, before going into the details of implementing this in PEtab, you could summarise how the general concept maps onto PEtab. Something like: "each shooting window will correspond to an experiment in PEtab".
Collaborator
Author
Good idea!
I plan to also add an image later, as I fully agree |
||
|
|
||
| Inputs: | ||
|
|
||
| - A PEtab problem (PEtab v2). | ||
| - The number of multiple-shooting windows, ``nWindows``. | ||
| - A window partition ``[t0_i, tf_i]`` for each window ``i = 1..nWindows`` such | ||
| that the union of windows covers the full measurement time range, and | ||
| ``t0_i != tf_i`` for all windows. | ||
|
Comment on lines +58 to +60
Member
I guess the condition (currently Currently it's not clear that each window must contain non-initial measurements. I would do that here in the choice of the windows rather than the generation of the PEtab problem below.
Collaborator
Author
Rather, would it not be
Member
Hm, theoretically I guess you only need one measurement in total, and that measurement should be the final time of the last time window. Then the multiple shooting still works by propagating information backwards via the continuity penalty -- no measurements needed in the other time windows. In this case you still get the benefit of reduced simulation time in each time window (but probably less efficient optimization). I guess what is missing from this page is the motivation. It's written that training SciML problems is difficult, but what's missing is what problem each of the methods solves, or the motivation for thinking that this method could make training easier. If the motivation here is to reduce simulation times (in each time window), then the description of the multiple shooting method can be simplified to reflect this -- no need to write about measurements. Anyway, maybe I am actually misunderstanding the continuity penalty. The observable formula for the penalty is
but above that is written
So the continuity penalty is actually w.r.t. the estimated initial condition, and penalizes the state from moving away from its initial condition? I am not sure how this is a continuity penalty... |
||
| - A continuity penalty parameter ``lambda``. | ||
|
|
||
| 1. Copy the input PEtab problem to create a multiple shooting (MS) PEtab | ||
| problem. | ||
| 2. In the MS PEtab problem, add the penalty weight parameter ``lambda`` to the | ||
| parameter table as a non-estimated parameter and set an appropriate nominal | ||
| value. | ||
| 3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem: | ||
|
|
||
| 1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}`` | ||
|
||
| and set the initial time to ``t0_i`` for window ``i = 1..nWindows``. | ||
|
Member
This is rendered as a new paragraph, probably fixed by adding the newline for the previous comment
Collaborator
Author
I really need to improve my rst-skills, thanks for catching the formatting issues!
Member
Maybe using latex here instead of code is clearer. i.e.
Member
Overall though, I think using latex instead of code is better here when it makes sense to -- it's a little hard to read otherwise.
Collaborator
Author
I agree, for ids I would go with code, otherwise LaTeX |
||
| 2. In the experiment table, remove the original experiment IDs and keep | ||
| only the windowed experiments. Assign each PEtab condition to the | ||
| corresponding window experiment. If a PEtab condition occurs at a | ||
| time point that lies in the overlap of windows ``i-1`` and ``i``, assign | ||
| the condition to experiment ``WINDOW{i-1}_{expId}``. | ||
| 3. In the measurement table, assign all measurements in the time interval | ||
| ``[t0_i, tf_i]`` for experiment ``expId`` to experiment | ||
| ``WINDOW{i}_{expId}``. If MS windows overlap at time points that contain | ||
| measurements, duplicate those measurements so they appear in each | ||
| relevant window. | ||
| 4. For each window ``i > 1`` such that there exists at least one | ||
| measurement for ``expId`` at time ``t >= t0_i`` in the original problem | ||
| (i.e., window ``i`` or a later window contains measurements), assign | ||
| initial window values and a continuity penalty: | ||
|
|
||
| 1. In the parameter table, create parameters | ||
|
||
| ``WINDOW{i}_{expId}_init_stateId{j}`` for each model state | ||
|
Member
Also rendered as a new paragraph |
||
| ``stateId{j}``. Mark them as estimated and choose appropriate bounds. | ||
| 2. In the condition table, create a condition with ID | ||
| ``WINDOW{i}_{expId}_condition0`` that assigns each ``stateId{j}`` to | ||
| ``WINDOW{i}_{expId}_init_stateId{j}``. | ||
| 3. Assign condition ``WINDOW{i}_{expId}_condition0`` as the initial | ||
| condition for experiment ``WINDOW{i}_{expId}`` at time ``t0_i``. | ||
| 4. In the observable table, create an observable with ID | ||
| ``WINDOW{i}_{expId}_penalty_stateId{j}`` for each model state | ||
| ``stateId{j}`` and set | ||
|
|
||
| - ``observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})`` | ||
| - ``noiseFormula = 1.0`` | ||
| - ``noiseDistribution = normal`` | ||
|
Collaborator
You could render this as columns in a table, it might be clearer. |
||
|
|
||
| 5. In the measurement table, add a row for experiment | ||
| ``WINDOW{i}_{expId}`` and observable | ||
| ``WINDOW{i}_{expId}_penalty_stateId{j}`` at time ``t0_i`` with | ||
| ``measurement = 0.0``. This yields an L2 (quadratic) penalty. | ||
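
The measurement assignment and the penalty rows described above can be sketched as follows. This is a simplified illustration, not the reference implementation: tables are plain lists of dicts, the helper names ``assign_to_windows`` and ``penalty_rows`` are hypothetical, and the check that window ``i`` or a later window contains measurements is omitted for brevity.

```python
def assign_to_windows(measurements, windows, exp_id):
    """Copy each measurement of `exp_id` into every window [t0_i, tf_i] covering it."""
    out = []
    for i, (t0, tf) in enumerate(windows, start=1):
        for row in measurements:
            if row["experimentId"] == exp_id and t0 <= row["time"] <= tf:
                # Rows inside an overlap match several windows and are duplicated.
                out.append({**row, "experimentId": f"WINDOW{i}_{exp_id}"})
    return out


def penalty_rows(windows, exp_id, state_ids):
    """Build penalty observables and their zero-valued measurements for windows i > 1."""
    observables, measurements = [], []
    for i, (t0, _tf) in enumerate(windows, start=1):
        if i == 1:
            continue  # the first window keeps the original initial condition
        prefix = f"WINDOW{i}_{exp_id}"
        for state in state_ids:
            obs_id = f"{prefix}_penalty_{state}"
            observables.append({
                "observableId": obs_id,
                "observableFormula": f"sqrt(lambda) * ({state} - {prefix}_init_{state})",
                "noiseFormula": "1.0",
                "noiseDistribution": "normal",
            })
            measurements.append({
                "observableId": obs_id,
                "experimentId": prefix,
                "time": t0,
                "measurement": 0.0,
            })
    return observables, measurements


# Hypothetical example: two overlapping windows, one experiment, one state "A".
windows = [(0.0, 5.0), (4.0, 10.0)]
rows = [
    {"observableId": "obs_a", "experimentId": "e1", "time": 1.0, "measurement": 0.2},
    {"observableId": "obs_a", "experimentId": "e1", "time": 4.5, "measurement": 0.6},
    {"observableId": "obs_a", "experimentId": "e1", "time": 8.0, "measurement": 0.9},
]
windowed = assign_to_windows(rows, windows, "e1")
penalty_obs, penalty_meas = penalty_rows(windows, "e1", ["A"])
```

In this example the measurement at ``t = 4.5`` lies in the overlap and therefore appears in both window experiments.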
|
|
||
| Naive multiple shooting can perform poorly when states have different scales, | ||
| since a single penalty weight may be impossible to tune. In this case, a | ||
| log-scale penalty such as | ||
|
|
||
| ``sqrt(lambda) * (log(abs(stateId{j})) - log(WINDOW{i}_{expId}_init_stateId{j}))`` | ||
|
Collaborator
Does this definition go somewhere in the PEtab problem or is it up to the implementation?
Collaborator
Author
This is something Python library util functions could support |
||
|
|
||
| can be effective, where ``abs`` avoids potential problems with states going | ||
| below zero due to numerical errors. | ||
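
To see why the two penalty observables behave differently across scales, it may help to write out their contribution to the negative log-likelihood. The notation is introduced here for illustration only: :math:`x_j` is the simulated value of state ``stateId{j}`` at the penalty time point and :math:`p_{i,j}` is the window-initial parameter ``WINDOW{i}_{expId}_init_stateId{j}``. With a normal noise model, noise parameter 1 and measurement 0:

```latex
% Linear penalty: y = sqrt(lambda) * (x_j - p_{i,j})
-\log \mathcal{N}\!\left(0 \mid y, 1\right)
  = \tfrac{1}{2}\log(2\pi) + \tfrac{\lambda}{2}\left(x_j - p_{i,j}\right)^2

% Log-scale penalty: y = sqrt(lambda) * (log|x_j| - log p_{i,j})
-\log \mathcal{N}\!\left(0 \mid y, 1\right)
  = \tfrac{1}{2}\log(2\pi) + \tfrac{\lambda}{2}\left(\log\frac{|x_j|}{p_{i,j}}\right)^2
```

The log-scale term depends only on the ratio :math:`|x_j| / p_{i,j}`, so the same ``lambda`` penalizes relative rather than absolute deviations. It requires :math:`p_{i,j} > 0`, which suggests a strictly positive lower bound for the window-initial parameters.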
|
Comment on lines +108 to +115
Member
If the state is going negative due to numerical issues, then it will probably be at/close to zero and
Collaborator
Author
Hitting exactly 0 is numerically almost impossible, so the
Member
Personally I would always do |
||
|
|
||
| From a runtime performance perspective, the number of initial-window | ||
| parameters scales with the number of windows, states, and PEtab experiments, | ||
| which can be impractical for larger problems. Moreover, since initial-window | ||
| parameters must be estimated, this approach typically performs poorly for | ||
| partially observed systems; this is addressed by the curriculum multiple | ||
| shooting approach. | ||
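
As a rough illustration of this scaling, the count can be written down directly. The sketch assumes that every window ``i > 1`` of every experiment receives window-initial parameters, so it is an upper bound; the function name is hypothetical.

```python
def n_window_init_parameters(n_windows, n_states, n_experiments):
    """Upper bound on the number of estimated window-initial parameters."""
    # Windows i > 1 each get one estimated initial value per state and experiment.
    return (n_windows - 1) * n_states * n_experiments


# Hypothetical problem size: 10 windows, 20 states, 5 experiments.
n_extra = n_window_init_parameters(n_windows=10, n_states=20, n_experiments=5)
```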
|
|
||
| Curriculum multiple shooting | ||
| ---------------------------- | ||
|
|
||
| Curriculum multiple shooting (CL+MS) combines multiple shooting with a | ||
| curriculum schedule. The idea is to start from a multiple-shooting formulation, | ||
| which is often easier to train, and then progressively reduce the number of | ||
| windows until the original (single-window) problem is recovered. This makes the | ||
| approach less sensitive to continuity-penalty tuning and ensures the final | ||
| parameters optimize the objective of the original PEtab problem. | ||
|
|
||
| Practically, CL+MS defines ``nStages`` curriculum stages. Stage 1 corresponds | ||
| to a multiple-shooting problem with ``nWindows = nStages`` windows. In each | ||
| subsequent stage, the first ``nWindows-1`` windows are expanded to cover the | ||
| union of two adjacent windows, and the last window is dropped. This reduces | ||
| the number of windows by one per stage while increasing the time span covered | ||
| by each remaining window. The final stage has a single window and corresponds | ||
| to the original problem. This can be implemented at the PEtab level as follows: | ||
|
|
||
| Inputs: | ||
|
|
||
| - A PEtab problem (PEtab v2). | ||
| - The number of curriculum stages, ``nStages``. | ||
| - An initial window partition ``[t0_i, tf_i]`` for stage 1 with | ||
| ``i = 1..nStages``, such that the union of windows covers the full | ||
| measurement time range and ``t0_i != tf_i`` for all windows. | ||
| - A continuity penalty parameter ``lambda`` (used in the multiple-shooting | ||
| stages). | ||
|
|
||
| 1. Construct stage 1 as a multiple-shooting (MS) PEtab problem with | ||
| ``nWindows = nStages`` using the procedure in | ||
| :ref:`Multiple shooting <multiple_shooting>`. | ||
| 2. For curriculum stage ``k = 2..(nStages-1)``: | ||
|
|
||
| 1. Set the number of windows to ``nWindows = nStages - k + 1``. | ||
|
||
| 2. Define the MS window time spans for stage ``k`` by merging adjacent | ||
| windows from the previous stage: | ||
|
|
||
| - For ``i = 1..nWindows`` set ``t0_i^{(k)} = t0_i^{(k-1)}`` and | ||
| ``tf_i^{(k)} = tf_{i+1}^{(k-1)}``. | ||
| - Drop the last window of stage ``k-1``. | ||
|
||
|
|
||
| 3. Create the PEtab problem for stage ``k`` by applying the | ||
| :ref:`Multiple shooting <multiple_shooting>` construction with the | ||
| updated window partition. In particular: | ||
|
|
||
| - Update the experiment table to contain only experiments | ||
| ``WINDOW{i}_{expId}`` for ``i = 1..nWindows``. | ||
| - Reassign and/or duplicate measurements to match | ||
| ``[t0_i^{(k)}, tf_i^{(k)}]``. | ||
| Measurements from the original problem that now fall in multiple | ||
| windows must be duplicated so they appear in each window. | ||
| - Include window-initial parameters and continuity-penalty observables | ||
| for windows ``i > 1`` as in multiple shooting. Note that the penalty | ||
| is applied at the initial time point of each window; in PEtab it is | ||
| not possible to define a continuity penalty over the full overlap | ||
| interval between two windows. | ||
|
Comment on lines +176 to +179
Member
It seems odd to have this note but then take the most trivial choice of completely ignoring the overlap... If penalizing over the full overlap is desirable, at least do more than the initial point? Also it's technically possible to implement the continuity penalty over the full overlap if done inside the mathematical model. But no need to include that here.
Collaborator
Author
I cannot see how penalizing over the entire interval is possible, except by altering the model structure. I wanted to put the note here mainly because an overlap penalty is not possible without altering the model, so in my view this is the approach we should go with (benchmarks show that penalizing only the first point is quite sufficient)
Member
Agreed, I meant it's technically possible by altering the mathematical model, and therefore it's possible in PEtab since the mathematical model is a part of PEtab. But I see your point, outside the mathematical model then there's no nice way to do it. Practically you can do it by introducing a huge number of dummy measurement points.
Ah right, then I think it's better to remove
and replace with e.g.
|
||
|
|
||
| 3. The final stage corresponds to the original PEtab problem. Use the parameter | ||
| estimate from stage ``nStages-1`` to initialize optimization for the final | ||
| stage. | ||
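
The window-merging rule of step 2 can be sketched as a small helper that derives the partition of every stage from the stage 1 partition. The function name ``clms_window_schedule`` and the example windows are hypothetical; the sketch only computes time spans and does not build the PEtab tables.

```python
def clms_window_schedule(stage1_windows):
    """Return the window partition of every curriculum stage, starting at stage 1."""
    stages = [list(stage1_windows)]
    while len(stages[-1]) > 1:
        prev = stages[-1]
        # Window i keeps its start time and takes the end time of window i+1;
        # the last window of the previous stage is dropped.
        stages.append([(prev[i][0], prev[i + 1][1]) for i in range(len(prev) - 1)])
    return stages


# Hypothetical stage 1 partition with nStages = 3 windows.
schedule = clms_window_schedule([(0.0, 2.0), (2.0, 4.0), (4.0, 6.0)])
```

The last entry of the returned list is a single window spanning the full time range, which corresponds to the final stage where the original PEtab problem is used as is.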
|
Comment on lines +181 to +183
Member
Does it provide nice training properties to go from duplicating most measurements in the second-last stage, to having no duplicates in the last stage? This seems like it could make training worse, rather than constructing the windows at each stage such that there is no overlap (or weighting the duplicate measurements such that there is effectively no overlap from the perspective of the objective function...).
Collaborator
Author
This does really improve training, so even though a bit strange, it does help!
Member
Alright now I'm very excited for your paper. I would be surprised if it can be justified, but I won't argue with results 😀 |
||
|
|
||
| A practical consideration for tools implementing and/or importing CL+MS is that | ||
| the number of window-initial parameters to estimate changes between stages. To | ||
| support transferring parameter values between stages, it can be beneficial to | ||
| provide a utility function for mapping parameters between stage problems. | ||
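
One possible shape for such a utility is sketched below. It relies on the construction above, in which window ``i`` keeps its start time ``t0_i`` from one stage to the next, so a parameter ID such as ``WINDOW{i}_{expId}_init_stateId{j}`` refers to the same time point in both stages. The function name ``transfer_parameters`` and the example values are hypothetical.

```python
def transfer_parameters(prev_estimates, next_parameter_ids, defaults=None):
    """Map estimates from stage k-1 onto the parameter IDs of stage k.

    Parameters of the dropped window are absent from `next_parameter_ids`
    and are therefore discarded.
    """
    defaults = defaults or {}
    mapped = {}
    for pid in next_parameter_ids:
        if pid in prev_estimates:
            mapped[pid] = prev_estimates[pid]
        else:
            # A KeyError here flags a parameter without any start value.
            mapped[pid] = defaults[pid]
    return mapped


# Hypothetical estimates from a 3-window stage, mapped onto a 2-window stage.
stage_prev = {"k1": 0.3, "WINDOW2_e1_init_A": 1.2, "WINDOW3_e1_init_A": 0.8}
stage_next = transfer_parameters(stage_prev, ["k1", "WINDOW2_e1_init_A"])
```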
|
Comment on lines +187 to +188
Member
TODO python package?
Collaborator
Author
Indeed :) |
||
|
|
||
| Partitioning measurements and time windows | ||
| ------------------------------------------ | ||
|
|
||
| The above training approaches require either splitting measurements into | ||
| curriculum stages (curriculum learning) or partitioning the simulation time | ||
| span into windows (multiple shooting and curriculum multiple shooting). We | ||
| recommend that tools supporting these methods provide the splitting schemes | ||
| outlined below. | ||
|
|
||
| For curriculum learning, the number of measurements per stage, ``n_i``, can be | ||
| chosen in two ways: (i) split by unique measurement time points and allocate | ||
| ``n_i`` accordingly, or (ii) split by the total number of measurements, which | ||
| can be effective when there are few unique time points but many repeated | ||
| measurements. We recommend supporting both modes, as well as automatic | ||
| splitting (e.g., given ``nStages``, compute ``n_i`` for the user) and | ||
| user-defined schedules (e.g., explicit ``n_i`` per stage or a maximum time | ||
| point per stage). | ||
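
The two automatic splitting modes could look roughly as follows. This is a sketch under the assumption that stages should grow by equal shares; the function names and the even-split rule are illustrative choices, not a prescribed scheme.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive b."""
    return -(-a // b)


def curriculum_schedule(times, n_stages, by_unique_time=True):
    """Return n_i, the number of time-sorted measurements kept in each stage."""
    ordered = sorted(times)
    if not by_unique_time:
        # Mode (ii): split the total number of measurements evenly.
        return [ceil_div(len(ordered) * k, n_stages) for k in range(1, n_stages + 1)]
    unique = sorted(set(ordered))
    schedule = []
    for k in range(1, n_stages + 1):
        # Mode (i): stage k keeps all measurements up to its last unique time point.
        t_max = unique[ceil_div(len(unique) * k, n_stages) - 1]
        schedule.append(sum(t <= t_max for t in ordered))
    return schedule


# Hypothetical data: four repeats at t = 0, single measurements at t = 1 and t = 2.
times = [0.0, 0.0, 0.0, 0.0, 1.0, 2.0]
by_time = curriculum_schedule(times, n_stages=2)
by_count = curriculum_schedule(times, n_stages=2, by_unique_time=False)
```

Splitting by unique time points never separates repeats taken at the same time, whereas splitting by count can place only some of the repeats in an early stage.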
|
|
||
| For multiple shooting, window intervals ``[t0_i, tf_i]`` must be defined. We | ||
| recommend supporting automatic window construction (e.g., take ``nWindows`` as | ||
| input and allocate windows based on unique measurement time points) as well as | ||
| user-specified intervals. As a basic sanity check, tools should ensure that | ||
| each window contains at least one measurement. | ||
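
Automatic window construction from unique measurement time points might be sketched as below. The even split and the choice to let adjacent windows share a boundary time point are assumptions made for this illustration; the function name is hypothetical.

```python
def windows_from_time_points(times, n_windows):
    """Split the unique measurement time points into n_windows consecutive windows."""
    unique = sorted(set(times))
    if n_windows < 1 or len(unique) < n_windows + 1:
        raise ValueError("need at least n_windows + 1 unique time points")
    # Boundaries are indices into `unique`, spread as evenly as possible.
    last = len(unique) - 1
    cuts = [k * last // n_windows for k in range(n_windows + 1)]
    # Adjacent windows share a boundary, so every window starts and ends
    # on a measurement time point and t0_i != tf_i holds.
    return [(unique[a], unique[b]) for a, b in zip(cuts, cuts[1:])]


# Hypothetical measurement times with repeats at t = 0.
auto_windows = windows_from_time_points([0, 0, 1, 2, 3, 4, 5, 6], n_windows=3)
```

Since window boundaries coincide with measurement time points, the check that each window contains at least one measurement is satisfied by construction.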
Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements, due to e.g. blow-up? If so then could change this to:
Anyway, I guess the current example is better here just to show how a curriculum learning variant can be implemented by adjusting the measurement table.
Should we make helper methods in the Python library for this?
Actually, I have seen that for models with few measurements but many repeats, splitting on measurements does well!
Yes, but that is a bit down the line :)
Hm, either the benefit is from (1) the reduced simulation time, or from (2) the subsampling of measurements at each time point/window (or both (1) and (2)).
(1) would be what I described above.
(2) would be that all time points are represented but are subsampled (and increasing the subsample size over iterations).
(1)+(2) would be both of these together.
But the algorithm described here doesn't exactly do (1), (2), or (1)+(2), so it's difficult to understand...
Anyway, good enough to merge for now. Maybe later you could add your future paper DOI as a justification/explanation.
I think we are looking at something like (1)+(2), but the number of measurements as well as the simulation time likely increase the complexity of the loss function landscape