Add documentation page describing PEtab-level training abstractions #64
sebapersson wants to merge 4 commits into main
Conversation
dilpath
left a comment
Looks good! Some list formatting issues.
Also some questions out of curiosity but no need to address them now.
2. Create ``nStages`` PEtab sub-problems by copying the input problem. For
   stage ``i``, filter the time-sorted measurement table to keep the first
   ``n_i`` measurements.
Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements, due to e.g. blow-up? If so then could change this to:
- Group measurements by time
- Sort all groups by their time
- First train with the first group. Then train with the first+second group. Then groups 1-3, etc. Only simulate until the last included group.
Anyway, I guess the current example is better here just to show how a curriculum learning variant can be implemented by adjusting the measurement table.
Should we make helper methods in the Python library for this?
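For concreteness, a minimal sketch of the split as currently described (keep the first ``n_i`` time-sorted measurements per stage) could look like this; `curriculum_stages` and the linearly growing ``n_i`` are hypothetical choices for illustration, only the `time` column is taken from the PEtab measurement-table format:

```python
import pandas as pd

def curriculum_stages(measurement_df: pd.DataFrame, n_stages: int) -> list[pd.DataFrame]:
    """Split a PEtab measurement table into cumulative curriculum stages.

    Stage i keeps the first n_i measurements after sorting by time; here the
    n_i grow linearly up to the full table (an arbitrary choice for illustration).
    """
    sorted_df = measurement_df.sort_values("time", kind="stable")
    n_total = len(sorted_df)
    stages = []
    for i in range(1, n_stages + 1):
        n_i = round(i * n_total / n_stages)
        stages.append(sorted_df.iloc[:n_i].reset_index(drop=True))
    return stages
```

Each returned table would then replace the measurement table of a copied sub-problem, as in step 2 of the quoted algorithm.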
Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements
Actually, I have seen that for models with few measurements but many repeats, splitting on measurements works well!
Should we make helper methods in the Python library for this?
Yes, but that is a bit down the line :)
Sounds good to me but I guess the difficulty comes from simulation time rather than number of measurements
Actually, I have seen that for models with few measurements but many repeats, splitting on measurements works well!
Hm, either the benefit is from (1) the reduced simulation time, or from (2) the subsampling of measurements at each time point/window (or both (1) and (2)).
(1) would be what I described above.
(2) would be that all time points are represented but are subsampled (and increasing the subsample size over iterations).
(1)+(2) would be both of these together.
But the algorithm described here doesn't exactly do (1), (2), or (1)+(2), so it's difficult to understand...
Anyway, good enough to merge for now. Maybe later you could add your future paper DOI as a justification/explanation.
I think we are looking at something like (1)+(2), but both the number of measurements and the simulation time likely increase the complexity of the loss function landscape
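To make the distinction concrete, variant (2) could be sketched roughly as follows (hypothetical helper; assumes a pandas measurement table with a `time` column and an arbitrary, linearly growing subsample fraction):

```python
import pandas as pd

def subsample_stages(measurement_df: pd.DataFrame, n_stages: int, seed: int = 0) -> list[pd.DataFrame]:
    """Variant (2): every time point stays represented, but the measurements at
    each time point are subsampled, with the kept fraction growing per stage."""
    stages = []
    for i in range(1, n_stages + 1):
        frac = i / n_stages
        stage_df = (
            measurement_df.groupby("time", group_keys=False)
            .apply(lambda g: g.sample(n=max(1, round(frac * len(g))), random_state=seed))
            .sort_values("time")
            .reset_index(drop=True)
        )
        stages.append(stage_df)
    return stages
```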
- A window partition ``[t0_i, tf_i]`` for each window ``i = 1..nWindows`` such
  that the union of windows covers the full measurement time range, and
  ``t0_i != tf_i`` for all windows.
I guess the condition (currently t0_i != tf_i) should actually be that all measurements are non-initial and there is at least one measurement at tf_i.
Currently it's not clear that each window must contain non-initial measurements. I would do that here in the choice of the windows rather than the generation of the PEtab problem below.
Rather, would it not be t0_i != tf_i and at least one measurement in [t0_i, tf_i]?
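For what it's worth, the different proposals boil down to slightly different validity checks on the windows; a minimal sketch of the check as phrased in this comment (hypothetical `validate_windows` helper; assumes windows given as `(t0, tf)` tuples and a pandas measurement table with a `time` column):

```python
import pandas as pd

def validate_windows(windows: list[tuple[float, float]], measurement_df: pd.DataFrame) -> None:
    """Check that windows are non-degenerate, cover the measurement time range,
    and that each window [t0_i, tf_i] contains at least one measurement."""
    times = measurement_df["time"]
    assert min(t0 for t0, _ in windows) <= times.min()
    assert max(tf for _, tf in windows) >= times.max()
    for t0, tf in windows:
        assert t0 != tf, f"degenerate window [{t0}, {tf}]"
        assert ((times >= t0) & (times <= tf)).any(), f"no measurements in [{t0}, {tf}]"
```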
Hm, theoretically I guess you only need one measurement in total, and that measurement should be at the final time of the last time window. Then the multiple shooting still works by propagating information backwards via the continuity penalty -- no measurements needed in the other time windows. In this case you still get the benefit of reduced simulation time in each time window (but probably less efficient optimization).
I guess what is missing from this page is the motivation. It's written that training the SciML problems is difficult, but what's missing is what problem each of the methods solves, or the motivation for thinking that this method could make training easier. If the motivation here is to reduce simulation times (in each time window), then the description of the multiple shooting method can be simplified to reflect this -- no need to write about measurements.
Anyway, maybe I am actually misunderstanding the continuity penalty. The observable formula for the penalty is
``observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})``
but above that is written
create a condition [...] that assigns each
``stateId{j}`` to ``WINDOW{i}_{expId}_init_stateId{j}``
So the continuity penalty is actually w.r.t. the estimated initial condition, and penalizes the state from moving away from its initial condition? I am not sure how this is a continuity penalty...
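For reference (not taken from the page under review), the continuity term in textbook multiple shooting compares the state simulated across window i-1 with the estimated initial state s_i of window i, evaluated at the start time t0_i of window i:

```math
\lambda \sum_{i=2}^{n_{\mathrm{windows}}} \sum_{j} \left( x_j(t_{0,i};\, s_{i-1}) - s_{i,j} \right)^2
```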
   value.
3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem:
   1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}``
      and set the initial time to ``t0_i`` for window ``i = 1..nWindows``.
This is rendered as a new paragraph, probably fixed by adding the newline for the previous comment
I really need to improve my rst-skills, thanks for catching the formatting issues!
   (i.e., at least one subsequent window contains measurements), assign
   initial window values and a continuity penalty:
   1. In the parameter table, create parameters
      ``WINDOW{i}_{expId}_init_stateId{j}`` for each model state
Also rendered as a new paragraph
   value.
3. For each PEtab experiment with ID ``expId`` in the MS PEtab problem:
   1. Create ``nWindows`` new PEtab experiments with IDs ``WINDOW{i}_{expId}``
      and set the initial time to ``t0_i`` for window ``i = 1..nWindows``.
Maybe using LaTeX here instead of code is clearer, i.e. ``t_0_i``. Or use array syntax like ``t_0[i]``.
Overall though, I think using LaTeX instead of code is better here when it makes sense to -- it's a little hard to read otherwise.
I agree; for IDs I would go with code, otherwise LaTeX
   for windows ``i > 1`` as in multiple shooting. Note that the penalty
   is applied at the initial time point of each window; in PEtab it is
   not possible to define a continuity penalty over the full overlap
   interval between two windows.
It seems odd to have this note but then take the most trivial choice of completely ignoring the overlap... If penalizing over the full overlap is desirable, at least do more than the initial point?
Also it's technically possible to implement the continuity penalty over the full overlap if done inside the mathematical model. But no need to include that here.
It seems odd to have this note but then take the most trivial choice of completely ignoring the overlap... If penalizing over the full overlap is desirable, at least do more than the initial point?
I cannot see how penalizing over the entire interval is possible (except by altering the model structure)? I wanted to put the note here mainly because the overlap penalty is not possible without altering the model, and thus this is the approach I think we should go with (benchmarks show that only penalizing the first point is quite sufficient)
Agreed, I meant it's technically possible by altering the mathematical model, and therefore it's possible in PEtab since the mathematical model is a part of PEtab. But I see your point, outside the mathematical model there's no nice way to do it. Practically you can do it by introducing a huge number of dummy measurement points.
benchmarks show that only penalizing the first point is quite sufficient
Ah right, then I think it's better to remove
in PEtab it is not possible to define a continuity penalty over the full overlap interval between two windows.
and replace with e.g.
, which in our experience is already very helpful
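For clarity, the full-overlap penalty discussed here would, in continuous form, be something like the following (a sketch, not from the page); it needs both window trajectories x^(i-1) and x^(i) at the same times, hence the modified mathematical model, and would in practice be evaluated at many (dummy) measurement time points:

```math
\lambda \int_{t_{0,i}}^{t_{f,i-1}} \left\lVert x^{(i-1)}(t) - x^{(i)}(t) \right\rVert^2 \, dt
```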
3. The final stage corresponds to the original PEtab problem. Use the parameter
   estimate from stage ``nStages-1`` to initialize optimization for the final
   stage.
Does it provide nice training properties to go from duplicating most measurements in the second-last stage, to having no duplicates in the last stage? This seems like it could make training worse, compared to constructing the windows at each stage such that there is no overlap (or weighting the duplicate measurements such that there is effectively no overlap from the perspective of the objective function...).
In practice this really does improve training, so even though it's a bit strange, it does help!
Alright now I'm very excited for your paper. I would be surprised if it can be justified, but I won't argue with results 😀
   support transferring parameter values between stages, it can be beneficial to
   provide a utility function for mapping parameters between stage problems.
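The utility function mentioned in the quoted text could, as a rough sketch, simply copy shared parameter values into the next stage's parameter table (hypothetical `transfer_estimate` helper; assumes the PEtab parameter table is indexed by `parameterId` and that its `nominalValue` column is used as the start point):

```python
import pandas as pd

def transfer_estimate(estimate: dict[str, float], next_parameter_df: pd.DataFrame) -> pd.DataFrame:
    """Map a stage estimate onto the next stage's parameter table via shared
    parameter IDs; parameters only present in the next stage (e.g. the
    WINDOW*_init_* parameters) keep their existing values."""
    next_parameter_df = next_parameter_df.copy()
    shared = next_parameter_df.index.intersection(list(estimate))
    next_parameter_df.loc[shared, "nominalValue"] = [estimate[p] for p in shared]
    return next_parameter_df
```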
BSnelling
left a comment
This looks really helpful!
On the whole I'm not clear yet on where the line is between helper functions we would implement in this repo and helpers that importers should implement to support these training strategies. The last section implies importers would implement splitting for curriculum learning and automatic window construction, but could those not also make sense as petab_sciml helpers?
   stage ``i``, filter the time-sorted measurement table to keep the first
   ``n_i`` measurements.
3. Optionally filter the condition, observable and experiment tables to only
   include entries required by the measurement table for each sub-problem.
I'm questioning whether this step should be optional. Would we expect all importers to handle a case where, for example, a condition specifies a time beyond what is in the measurements table?
I think all importers should handle it, but the filtering is nice because I imagine some importers (like PEtab.jl) will throw warnings
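A rough sketch of that optional filtering step (hypothetical helper; assumes v1-style `observableId`/`simulationConditionId` columns in the measurement table and ID-indexed observable/condition tables; an experiments table would be filtered analogously):

```python
import pandas as pd

def filter_tables(measurement_df: pd.DataFrame, observable_df: pd.DataFrame, condition_df: pd.DataFrame):
    """Drop observables and conditions not referenced by the (stage) measurement table."""
    used_observables = set(measurement_df["observableId"])
    observable_df = observable_df[observable_df.index.isin(used_observables)]
    used_conditions = set(measurement_df.get("simulationConditionId", pd.Series(dtype=str)).dropna())
    condition_df = condition_df[condition_df.index.isin(used_conditions)]
    return observable_df, condition_df
```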
- ``observableFormula = sqrt(lambda) * (stateId{j} - WINDOW{i}_{expId}_init_stateId{j})``
- ``noiseFormula = 1.0``
- ``noiseDistribution = normal``
You could render this as columns in a table, it might be clearer.
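For example, a small helper could emit exactly those columns as observable-table rows, which would also make a table rendering straightforward (hypothetical helper; only the column names are taken from the PEtab observables table, and `lam` may be a parameter ID such as ``lambda`` or a numeric value):

```python
import pandas as pd

def continuity_penalty_observables(exp_id: str, window: int, state_ids: list[str], lam: str = "lambda") -> pd.DataFrame:
    """Build one penalty observable row per state for window `window` of
    experiment `exp_id`, following the quoted formulas."""
    rows = []
    for state_id in state_ids:
        init_par = f"WINDOW{window}_{exp_id}_init_{state_id}"
        rows.append({
            "observableId": f"penalty_WINDOW{window}_{exp_id}_{state_id}",
            "observableFormula": f"sqrt({lam}) * ({state_id} - {init_par})",
            "noiseFormula": 1.0,
            "noiseDistribution": "normal",
        })
    return pd.DataFrame(rows)
```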
   split into windows that are fitted jointly. Each window has its own estimated
   initial state values, and a continuity penalty is introduced to encourage a
   continuous trajectory between adjacent windows. This can be implemented at the
   PEtab level as follows:
The explanation is good, but it's a hard concept to convey in words. Is there a schematic you could add here to explain multiple shooting?
Also before going into the details of implementing this in PEtab you could summarise how the general concept maps onto PEtab. Something like: "each shooting window will correspond to an experiment in PEtab".
Also before going into the details of implementing this in PEtab you could summarise how the general concept maps onto PEtab. Something like: "each shooting window will correspond to an experiment in PEtab".
Good idea!
The explanation is good, but it's a hard concept to convey in words. Is there a schematic you could add here to explain multiple shooting?
I plan to also add an image later, as I fully agree
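Building on that, the window-to-experiment mapping could look roughly like this (hypothetical helper; assumes an experiments table with `experimentId`/`time`/`conditionId` columns along the lines of PEtab v2):

```python
import pandas as pd

def window_experiments(exp_id: str, windows: list[tuple[float, float]], condition_id: str) -> pd.DataFrame:
    """One PEtab experiment per shooting window: experiment WINDOW{i}_{exp_id}
    starts at the window's initial time t0_i."""
    rows = []
    for i, (t0, _tf) in enumerate(windows, start=1):
        rows.append({
            "experimentId": f"WINDOW{i}_{exp_id}",
            "time": t0,
            "conditionId": condition_id,
        })
    return pd.DataFrame(rows)
```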
   since a single penalty weight may be impossible to tune. In this case, a
   log-scale penalty such as

   ``sqrt(lambda) * (log(abs(stateId{j})) - log(WINDOW{i}_{expId}_init_stateId{j}))``
Does this definition go somewhere in the PEtab problem or is it up to the implementation?
This is something Python library util functions could support
Would definitely make sense. Which helpers would be needed is something to discuss a bit down the line :)
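If such a util ends up in the Python library, it could be as small as a formula builder for the two penalty variants (hypothetical function, not part of any existing API):

```python
def penalty_formula(state_id: str, init_par: str, lam: str = "lambda", log_scale: bool = False) -> str:
    """Return the continuity-penalty observableFormula string, on linear or
    log scale, following the formulas quoted above."""
    if log_scale:
        return f"sqrt({lam}) * (log(abs({state_id})) - log({init_par}))"
    return f"sqrt({lam}) * ({state_id} - {init_par})"
```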
Co-authored-by: Dilan Pathirana <59329744+dilpath@users.noreply.github.com>
Co-authored-by: BSnelling <branwen.snelling@crick.ac.uk>
Training SciML problems can be challenging, and we are currently evaluating several
strategies (curriculum learning, multiple shooting, and curriculum multiple shooting).
This PR adds a docs page that describes how these strategies can be represented at the
PEtab level, which can also be useful for non-SciML PEtab problems.
Longer-term, the goal is for the PEtab SciML library to provide a reference implementation.
This page will help with that, and will also support tool developers across
ecosystems (e.g., Julia) without requiring a Python dependency.