
Conversation

@sophie-xhonneux
Contributor

Description

[DRAFT] PR introducing the student-teacher latent losses for SSL. This PR relies on both the abstract loss calculator (#1178) and the abstract target/aux class (#1179).

The idea is to get early feedback and surface issues by making the code more concrete.

Issue Number

Closes #1043

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Implemented Identity class

TODO: implement EMATeacher
The big open question on the EMA teacher side is how to allow for flexible teacher and student architectures that can differ from each other.

We updated some APIs of the abstract base class to allow the ema_model forward pass. This is subject to change given the loss calculator, which is IMHO the second big question mark.
This is easier to read, and as batch handling gets more complicated in SSL it will be a useful abstraction.
It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements
This involves creating stateful classes for each of the losses, and the EMATeacher being able to run additional neural network heads for these losses.
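For the simple case where the teacher mirrors the student's parameter structure, an EMA teacher update could look like the minimal sketch below (class name, default momentum, and API are assumptions rather than this PR's design; the open question above is exactly how to relax the requirement that the two architectures match):

import copy

import torch


class EMATeacher(torch.nn.Module):
    """Hypothetical sketch: holds an exponential moving average of the student's weights."""

    def __init__(self, student: torch.nn.Module, momentum: float = 0.996):
        super().__init__()
        self.momentum = momentum
        self.teacher = copy.deepcopy(student)
        for p in self.teacher.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, student: torch.nn.Module):
        # teacher <- m * teacher + (1 - m) * student, parameter by parameter
        for p_t, p_s in zip(self.teacher.parameters(), student.parameters()):
            p_t.mul_(self.momentum).add_(p_s.detach(), alpha=1.0 - self.momentum)

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        return self.teacher(*args, **kwargs)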
@github-actions bot added the initiative (Large piece of work covering multiple sprints) and model (Related to model training or definition, not generic infra) labels on Nov 5, 2025
TODO: create the various teacher head modules and run them.
TODO: merge the abstract loss calculator and create the SSL one
Collaborator

@clessig left a comment


Didn't look through the actual computations line by line, since it seems this is copy-pasted from the reference code?

@@ -0,0 +1,304 @@
# (C) Copyright 2025 WeatherGenerator contributors.
Collaborator

This file should go to . They need to be torch.nn.Module subclasses because these are NNs, even if they are not necessarily trained themselves. I think ssl_target_processing.py (since you probably still don't like ssl_target_predictors.py).
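For illustration, the Identity case mentioned in the commits could be as small as the following torch.nn.Module sketch (class and file names are assumptions):

import torch


class IdentityTargetProcessing(torch.nn.Module):
    """Hypothetical sketch: passes teacher latents through unchanged, so the SSL
    loss is computed directly on the raw latent targets."""

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        return latents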

import torch.nn.functional as F


def lossfunc(t, s, temp):
Collaborator

The name is not very descriptive :) Maybe latent_logit_loss.py? JEPA uses MAE (and one could conceivably replace it with MSE), which are already implemented in loss.py. Ideally we could reuse what is there.
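For context, a latent logit loss with this kind of signature typically reduces to a cross-entropy between the teacher distribution and the temperature-scaled student log-probabilities, roughly as in the sketch below (the function name and the assumption that the teacher input is already softmaxed and detached are mine, not taken from the reference code):

import torch
import torch.nn.functional as F


def latent_logit_loss(teacher_probs: torch.Tensor, student_logits: torch.Tensor, temp: float) -> torch.Tensor:
    # Cross-entropy between the teacher distribution and the
    # temperature-scaled student log-probabilities, averaged over the batch.
    return -(teacher_probs.detach() * F.log_softmax(student_logits / temp, dim=-1)).sum(dim=-1).mean()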

Q *= B # the columns must sum to 1 so that Q is an assignment
return Q.t()

# def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
Collaborator

Can we remove the stale code? What does it implement?

Contributor Author

The stale code is there for reference because it needs to go to the loss calculator later.

I will do all the clean-up once we are much closer to actually merging :)
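For readers wondering what the commented-out forward does: in DINO/iBOT-style reference code, a forward with this signature is usually a masked patch-token cross-entropy, roughly as sketched here (illustrative only; the function name, temperature default, and exact normalisation are assumptions):

import torch
import torch.nn.functional as F


def masked_patch_loss(student_patch_tokens, teacher_patch_probs, student_masks_flat, student_temp=0.1):
    # Per-token cross-entropy between teacher probabilities and student log-probabilities.
    loss = torch.sum(teacher_patch_probs * F.log_softmax(student_patch_tokens / student_temp, dim=-1), dim=-1)
    # Keep only masked positions, average per sample (guarding against empty masks), then over the batch.
    loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
    return -loss.mean()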


    def __init__(
        self,
        patch_out_dim,
Collaborator

Would it be better to take a dict as an argument, given that we may want to implement *TargetProcessing variants that require different args?
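A possible shape for that, sketched under the assumption of a single config dict (class name and keys are hypothetical):

import torch


class DINOTargetProcessing(torch.nn.Module):
    """Hypothetical sketch of the dict-based constructor pattern: every *TargetProcessing
    variant accepts one config dict, so they share a constructor signature even when
    their hyperparameters differ."""

    def __init__(self, cfg: dict):
        super().__init__()
        patch_out_dim = cfg["patch_out_dim"]
        self.center_momentum = cfg.get("center_momentum", 0.9)
        self.register_buffer("center", torch.zeros(1, patch_out_dim))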

After much consideration I decided to add the latent prediction heads to the Model, because they also need to benefit from the exponential moving average of the weights, and this gets unnecessarily cumbersome if they live outside the Model.

TODO: make JEPA differ between student and teacher
TODO: use this new structure in EMATeacher
To prevent crazy nesting of model output values, we created a ModelOutput dataclass (akin to how it is done in Hugging Face), and we run all the latent_prediction heads there.
This will need adapting based on the abstract loss calculator.
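For illustration, a flat output container in the Hugging Face style could look like this sketch (field names are assumptions, not the PR's final API):

from dataclasses import dataclass, field

import torch


@dataclass
class ModelOutput:
    # Hypothetical fields: the regular forecast output plus the per-head SSL latents.
    prediction: torch.Tensor | None = None
    student_latents: dict[str, torch.Tensor] = field(default_factory=dict)
    teacher_latents: dict[str, torch.Tensor] = field(default_factory=dict)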

Currently awaiting the streams data branch to check the data piping and configuration.