Abstract loss calculator #1210
base: develop
Conversation
```yaml
loss_fcts:
  -
  -
```
Let's have a discussion about how the config should be structured for this.
```python
loss_lfct = loss_lfct + loss
losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs
ctr_substeps += 1 if loss > 0.0 else 0
loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"]
```
I don't think we should do this based on the names. It is better to explicitly have sub-keys for latent and physical losses.
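For illustration, explicit sub-keys could look like this (a sketch only; the key names and the agreed format are assumptions, extending the nested-list style already in the config):

```yaml
loss_fcts:
  physical:
    - - "mse"
      - 1.0
  latent:
    - - "mse"
      - 1.0
```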
config/default_config.yml (outdated)

```yaml
- "mse"
- 1.0
# -
# - "latent:mse"
```
@sophie-xhonneux: the latent loss functions are largely determined by the SSL strategies (with some flexibility, e.g. whether to use MAE or MSE for JEPA). The latents returned by the Teacher are a dict with entries like 'DINO': torch.Tensor and 'iBOT': torch.Tensor. Shouldn't the loss function somehow come from the SSLTargetProcessors?
That would be an option, definitely. My plan was simply for the loss calculator to have an SSLLossCalculator that has a loss function for each of DINO, iBOT, JEPA-L1, and JEPA-L2. Because at the end of the day we have to specify it somewhere, and there is some tensor reshaping and such to do.
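One shape this could take, as a structural sketch only: the class name comes from the comment above, but the objective-to-loss pairing is an assumption, and plain floats stand in for tensors so the dict-dispatch logic is visible without torch.

```python
def mse(pred, target):
    # mean squared error on plain floats, standing in for a tensor loss
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred, target):
    # mean absolute error
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


class SSLLossCalculator:
    # one loss function per SSL objective, keyed like the dict the Teacher
    # returns; which loss fits which objective is assumed here
    loss_fcts = {"DINO": mse, "iBOT": mse, "JEPA-L1": l1, "JEPA-L2": mse}

    def __call__(self, preds, targets):
        # accumulate over the objectives actually present in the targets dict
        return sum(self.loss_fcts[k](preds[k], targets[k]) for k in targets)
```

In a real implementation each entry would also carry the per-objective tensor reshaping mentioned above.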
```python
stddev_all: dict[str, Tensor]

...

class LossCalculatorBase:
```
We need a different name here. The LossCalculator is now the "management" class that calls/executes the individual loss terms. LossTerms is one option although it's not ideal.
We could call it LossFunction
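A minimal sketch of that naming split (hypothetical; the MSE body on plain floats is a placeholder): LossFunction for an individual term, LossCalculator as the management class that executes the terms.

```python
class LossFunction:
    """One individual loss term (placeholder MSE on plain floats)."""

    def __init__(self, name, weight=1.0):
        self.name = name
        self.weight = weight

    def compute(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


class LossCalculator:
    """The "management" class: calls/executes the individual loss terms."""

    def __init__(self, terms):
        self.terms = terms

    def __call__(self, pred, target):
        # weighted sum over all configured terms
        return sum(t.weight * t.compute(pred, target) for t in self.terms)
```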
New file (+377 lines):

```python
# ruff: noqa: T201
```
I would put it into the same file as the base class. Having "classes" in the file name is non-descriptive.
One concern is the dependencies between files and avoiding import cycles, which is why a separate base class is sometimes useful. I had this experience with the abstract target computation.
```python
return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)

...

class LossCalculatorLatent(LossCalculatorBase):
```
I think this should come from the SSLTargetProcessor. That seems to be the only way to ensure the predictor and the loss are compatible and consistent.
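A hypothetical sketch of that wiring, simplified (no base class, plain floats for tensors, interface names assumed): each SSLTargetProcessor also owns the matching loss, so predictor and loss stay consistent by construction.

```python
class SSLTargetProcessor:
    # hypothetical processor: builds the targets and owns the matching loss
    name = "JEPA"

    def loss_fct(self, pred, target):
        # L2 on plain floats, standing in for the JEPA regression loss
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


class LossCalculatorLatent:
    def __init__(self, processors):
        # collect each objective's loss from its target processor
        self.loss_fcts = {p.name: p.loss_fct for p in processors}

    def __call__(self, preds, targets):
        return sum(self.loss_fcts[k](preds[k], targets[k]) for k in targets)
```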
```yaml
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts:
```
loss_fcts and loss_fcts_val will be removed from the config, but currently training fails when removing them because they are used in log_trainer.
Description
Enables generic loss calculation for a given set of prediction-target pairs, which can be in latent and/or physical space and part of student-teacher training or diffusion models.
Proposed structure:
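The structure sketch itself did not survive extraction; purely as a hypothetical illustration of the idea in the description, a generic prediction-target pair that can live in latent or physical space and is reduced by one weighted sum regardless of space (all names are made up here):

```python
from dataclasses import dataclass


@dataclass
class PredTargetPair:
    space: str        # "latent" or "physical"
    pred: list
    target: list
    weight: float = 1.0


def total_loss(pairs, loss_fct):
    # one generic reduction over all pairs, independent of their space
    return sum(p.weight * loss_fct(p.pred, p.target) for p in pairs)
```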
Issue Number
Closes #1178
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`