Pass the current epoch to the aggregate function of Weighting Interface #617
GiovanniCanali started this conversation in Ideas
Currently, the aggregate method of any class inheriting from WeightingInterface accepts only losses, a dictionary storing the loss for each condition.

In principle, this should be enough. However, some weighting schemes need to update their weights, and these updates can be computationally expensive. One example is self-adaptive weighting as described in "Simulating Three-dimensional Turbulence with Physics-informed Neural Networks" (see PirateNetwork). In such cases, it may be preferable to perform these computations only every k epochs, using a simple modulus operation.

Unfortunately, this is not currently possible: there is no straightforward way for the WeightingInterface to be aware of the current epoch. Simple internal counters partially address the problem, but they fail when batching is involved. Because aggregate is invoked for every batch, such a counter advances several times per epoch and no longer tracks epochs.

Therefore, I propose adding the current epoch as an argument to the aggregate method of WeightingInterface, so that the solver can pass it directly.
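To make the batching problem and the proposal concrete, here is a minimal, self-contained sketch in plain Python. The classes CounterWeighting and EpochAwareWeighting, the helper simulate, and the condition names are all hypothetical and are not the library's actual API. The inverse-loss update is only a cheap placeholder for an expensive scheme such as self-adaptive weighting. The sketch compares an internal call counter with an aggregate(losses, epoch) signature when aggregate is called once per batch.

```python
# Minimal sketch with hypothetical classes; this is NOT the library's actual API.


class CounterWeighting:
    """Naive approach: an internal counter incremented on every aggregate call."""

    def __init__(self, k):
        self.k = k
        self.calls = 0
        self.num_updates = 0

    def aggregate(self, losses):
        self.calls += 1  # advances once per batch, not once per epoch
        if self.calls % self.k == 0:
            self.num_updates += 1  # stand-in for the expensive weight update
        return sum(losses.values())


class EpochAwareWeighting:
    """Proposed approach: the solver passes the current epoch to aggregate."""

    def __init__(self, k):
        self.k = k
        self.weights = {}
        self.last_update_epoch = None
        self.num_updates = 0

    def _update_weights(self, losses):
        # Placeholder for an expensive scheme (e.g. self-adaptive weighting):
        # weights proportional to the inverse of each condition's loss.
        inverse = {name: 1.0 / max(value, 1e-12) for name, value in losses.items()}
        total = sum(inverse.values())
        self.weights = {name: value / total for name, value in inverse.items()}
        self.num_updates += 1

    def aggregate(self, losses, epoch):
        # Update at most once per epoch, and only when epoch is a multiple of k,
        # regardless of how many batches call aggregate during that epoch.
        if epoch % self.k == 0 and epoch != self.last_update_epoch:
            self._update_weights(losses)
            self.last_update_epoch = epoch
        # Conditions without a computed weight default to 1.0.
        return sum(self.weights.get(name, 1.0) * value for name, value in losses.items())


def simulate(num_epochs=10, batches_per_epoch=4, k=5):
    """Mimic a training loop in which aggregate is called once per batch."""
    counter, epoch_aware = CounterWeighting(k), EpochAwareWeighting(k)
    losses = {"physics": 0.5, "boundary": 0.1}  # hypothetical condition names
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            counter.aggregate(losses)
            epoch_aware.aggregate(losses, epoch)
    return counter.num_updates, epoch_aware.num_updates


print(simulate())  # prints (8, 2)
```

With 10 epochs, 4 batches per epoch and k = 5, the counter triggers 8 updates, while the epoch-aware version triggers the intended 2 (at epochs 0 and 5). Note that epoch % k == 0 on its own would still fire on every batch of a qualifying epoch, which is why the sketch also remembers the last epoch it updated at. If the solver is a PyTorch Lightning LightningModule, its current_epoch property would be a natural value to pass, though how the solver wires this in is left open here.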