Contributor

@Jubeku Jubeku commented Nov 5, 2025

Description

Enables generic loss calculation for a given set of prediction-target pairs, which can be in latent and/or physical space and can be part of student-teacher training or diffusion models.

Proposed structure:

  • Classes:
    • LossCalculator class: iterates over all specific loss calculator classes and returns a combined loss object.
    • LossCalculatorBase class: generic loss calculator structure.
    • LossCalculatorPhysical, LossCalculatorLatent, etc.: specific subclasses of LossCalculatorBase.
  • DataClasses:
    • LossValues: predefines the items that are returned by the loss calculator classes.
    • InputOutput: predefines the items/structure of model predictions and targets (can include forecast step logic).
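
A minimal sketch of how the proposed structure could fit together, not the code in this PR. The class names and the LossValues fields (loss, losses_all, stddev_all) follow the description and the diff; everything else is an assumption made for illustration: the method names `compute_loss` and `_loss`, the `name` and `weight` arguments, plain floats and lists standing in for tensors, and the choice of MSE and MAE as the concrete losses.

```python
from dataclasses import dataclass, field


@dataclass
class LossValues:
    """Items returned by every loss calculator (floats stand in for tensors)."""

    loss: float
    losses_all: dict[str, float] = field(default_factory=dict)
    stddev_all: dict[str, float] = field(default_factory=dict)


class LossCalculatorBase:
    """Generic structure: a named, weighted loss over prediction-target pairs."""

    def __init__(self, name: str, weight: float = 1.0):
        self.name = name  # hypothetical: key used to look up predictions/targets
        self.weight = weight

    def _loss(self, preds: list[float], targets: list[float]) -> float:
        raise NotImplementedError

    def compute_loss(self, preds: list[float], targets: list[float]) -> LossValues:
        value = self.weight * self._loss(preds, targets)
        return LossValues(loss=value, losses_all={self.name: value})


class LossCalculatorPhysical(LossCalculatorBase):
    def _loss(self, preds, targets):
        # Mean squared error, chosen only as an example of a physical-space loss.
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class LossCalculatorLatent(LossCalculatorBase):
    def _loss(self, preds, targets):
        # Mean absolute error, chosen only as an example of a latent-space loss.
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


class LossCalculator:
    """Manager: iterates over the specific calculators and combines their results."""

    def __init__(self, calculators: list[LossCalculatorBase]):
        self.calculators = calculators

    def compute_loss(self, preds: dict, targets: dict) -> LossValues:
        total, losses_all, stddev_all = 0.0, {}, {}
        for calc in self.calculators:
            values = calc.compute_loss(preds[calc.name], targets[calc.name])
            total += values.loss
            losses_all.update(values.losses_all)
            stddev_all.update(values.stddev_all)
        return LossValues(loss=total, losses_all=losses_all, stddev_all=stddev_all)


calc = LossCalculator(
    [LossCalculatorPhysical("physical"), LossCalculatorLatent("latent", weight=0.5)]
)
out = calc.compute_loss(
    preds={"physical": [1.0, 2.0], "latent": [0.0, 1.0]},
    targets={"physical": [1.0, 4.0], "latent": [1.0, 1.0]},
)
```

In this sketch the manager only knows the base-class interface, so adding a new loss type means adding one subclass and leaving the combination logic unchanged.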

Issue Number

Closes #1178

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@Jubeku Jubeku self-assigned this Nov 5, 2025
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Nov 5, 2025
@clessig clessig self-requested a review November 6, 2025 20:54

loss_fcts:
-
-
Collaborator

Let's discuss how the config should be structured for this.

loss_lfct = loss_lfct + loss
losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs
ctr_substeps += 1 if loss > 0.0 else 0
loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"]
Collaborator

I don't think we should do this based on the names. It is better to have explicit sub-keys for latent and physical losses.
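
One way the sub-key idea could look, as a sketch only. The layout below (a `loss_fcts` entry with `physical` and `latent` sub-keys) and the helper `get_loss_fcts` are hypothetical and not taken from the PR; the config is written as a Python dict here, while the project config itself is YAML.

```python
# Hypothetical config layout: explicit sub-keys instead of "latent:" name prefixes.
config = {
    "loss_fcts": {
        "physical": [["mse", 1.0]],
        "latent": [["mse", 0.5]],
    }
}


def get_loss_fcts(cfg: dict, space: str) -> list[tuple[str, float]]:
    """Return (name, weight) pairs for one space; empty if the sub-key is absent."""
    entries = cfg.get("loss_fcts", {}).get(space, [])
    return [(name, float(weight)) for name, weight in entries]


loss_fcts_physical = get_loss_fcts(config, "physical")
loss_fcts_latent = get_loss_fcts(config, "latent")
```

With this layout no string parsing of the loss name is needed, and a missing `latent` sub-key simply yields an empty list.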

- "mse"
- 1.0
# -
# - "latent:mse"
Collaborator

@sophie-xhonneux: the latent loss functions are largely determined by the SSL strategies (with some flexibility, e.g. MAE or MSE for JEPA) and they are also . The latents returned by the Teacher are a dict with entries like 'DINO': torch.Tensor and 'iBOT': torch.Tensor. Shouldn't the loss function somehow come from the SSLTargetProcessors?

Contributor

That would be an option, definitely. My plan was simply for the loss calculator to have an SSLLossCalculator that has a loss function for each of DINO, iBOT, JEPA-L1, and JEPA-L2. At the end of the day we have to specify it somewhere, and there is some tensor reshaping and similar work to do.
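
A rough sketch of that plan, under several assumptions. The class name SSLLossCalculator and the keys DINO, iBOT, JEPA-L1 and JEPA-L2 come from the comment above; the `LOSS_FCTS` registry, the `compute_loss` signature, and the use of plain Python lists instead of tensors are illustrative. The DINO and iBOT entries use a plain softmax cross-entropy as a simplified stand-in (no temperature, centering, or masking), and the reshaping step is omitted.

```python
import math


def _mae(pred, target):
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def _mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def _softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _cross_entropy(student_logits, teacher_logits):
    """Cross-entropy of the student distribution under the teacher distribution."""
    teacher = _softmax(teacher_logits)
    student = _softmax(student_logits)
    return -sum(t * math.log(s) for t, s in zip(teacher, student))


class SSLLossCalculator:
    # One loss function per SSL strategy; keys match the teacher's latent dict.
    LOSS_FCTS = {
        "DINO": _cross_entropy,
        "iBOT": _cross_entropy,
        "JEPA-L1": _mae,
        "JEPA-L2": _mse,
    }

    def compute_loss(self, preds: dict, targets: dict) -> dict:
        """Apply the matching loss to every entry present in the targets dict."""
        return {
            name: self.LOSS_FCTS[name](preds[name], targets[name]) for name in targets
        }


losses = SSLLossCalculator().compute_loss(
    preds={"JEPA-L1": [0.0, 2.0], "JEPA-L2": [0.0, 2.0], "DINO": [0.0, 0.0]},
    targets={"JEPA-L1": [1.0, 1.0], "JEPA-L2": [1.0, 1.0], "DINO": [0.0, 0.0]},
)
```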

stddev_all: dict[str, Tensor]


class LossCalculatorBase:
Collaborator

We need a different name here. The LossCalculator is now the "management" class that calls/executes the individual loss terms. LossTerms is one option although it's not ideal.

Contributor

We could call it LossFunction

@@ -0,0 +1,377 @@
# ruff: noqa: T201
Collaborator

I would put this into the same file as the base class. Having "classes" in the file name is non-descriptive.

Contributor

One concern is the dependencies between files: you want to avoid import cycles, which is why keeping the base class in its own file is sometimes useful. I had this experience with the abstract target computation.

return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)


class LossCalculatorLatent(LossCalculatorBase):
Collaborator

I think this should come from the SSLTargetProcessor. That seems to be the only way to ensure the predictor and the loss are compatible and consistent.

@sophie-xhonneux @shmh40 ?
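
To make this alternative concrete, here is a sketch in which each target processor supplies its own loss, so targets and loss cannot drift apart. Only the idea of an SSLTargetProcessor comes from this thread; the class `JEPATargetProcessor`, its methods `compute_targets` and `loss`, the `norm` argument, and the helper `latent_loss` are hypothetical, and lists of floats stand in for tensors.

```python
class JEPATargetProcessor:
    """Hypothetical processor that owns both the target computation and its loss."""

    def __init__(self, norm: str = "l2"):
        if norm not in ("l1", "l2"):
            raise ValueError(f"unsupported norm: {norm}")
        self.norm = norm
        self.name = f"JEPA-{norm.upper()}"

    def compute_targets(self, teacher_latents):
        # Identity here; a real processor might normalise or reshape the latents.
        return list(teacher_latents)

    def loss(self, pred, target):
        if self.norm == "l1":
            return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def latent_loss(processors, student_preds: dict, teacher_latents: dict) -> dict:
    """Ask each processor for its targets and evaluate that processor's own loss."""
    losses = {}
    for proc in processors:
        target = proc.compute_targets(teacher_latents[proc.name])
        losses[proc.name] = proc.loss(student_preds[proc.name], target)
    return losses


proc = JEPATargetProcessor("l2")
latent_losses = latent_loss(
    [proc], {"JEPA-L2": [0.0, 2.0]}, {"JEPA-L2": [1.0, 1.0]}
)
```

In this variant a latent loss calculator would only loop over the processors and would not need its own table of loss functions.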

latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True

loss_fcts:
Contributor Author

loss_fcts and loss_fcts_val will be removed from the config, but currently training fails when they are removed because they are used in log_trainer.

Development

Successfully merging this pull request may close these issues.

Abstract Loss Calculators

4 participants