Sophiex/dev/ssl losses 1043 #1205
base: develop
Conversation
Implemented Identity class. TODO: implement EMATeacher.
The big question on the EMA teacher side, to me, is how to allow for flexible teacher and student architectures that can differ. We updated some APIs of the abstract base class to allow the ema_model forward; this is subject to change given the loss calculator, which is imho the second big question mark.
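For concreteness, a minimal sketch of how an EMA teacher update could look, assuming teacher and student expose compatible parameter lists. The class name `EMATeacher` matches the TODO above, but the `momentum` argument and the overall API are assumptions, not the final design:

```python
# Illustrative sketch only, not the PR's implementation.
import copy

import torch


class EMATeacher(torch.nn.Module):
    def __init__(self, student: torch.nn.Module, momentum: float = 0.996):
        super().__init__()
        self.momentum = momentum
        # Start the teacher as a frozen copy of the student.
        self.teacher = copy.deepcopy(student)
        for p in self.teacher.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, student: torch.nn.Module) -> None:
        # teacher = m * teacher + (1 - m) * student, parameter by parameter.
        for p_t, p_s in zip(self.teacher.parameters(), student.parameters()):
            p_t.mul_(self.momentum).add_(p_s.detach(), alpha=1.0 - self.momentum)

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        return self.teacher(*args, **kwargs)
```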
Easier to read, and as the batch handling gets more complicated in SSL this will be a useful abstraction.
It runs so far. Next steps:
- Route all the config options
- Start writing the loss functions to understand the state requirements
This involves creating stateful classes for each of the losses and the EMATeacher being able to run additional neural network heads for these losses.
TODO: create the various teacher head modules and run them. TODO: merge the abstract loss calculator and create the SSL one.
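As a rough illustration of what a stateful loss class could look like (a DINO-style centering buffer carried across batches); the name `LatentCenterLoss` and all hyperparameters are hypothetical:

```python
# Illustrative sketch only: one possible shape for a stateful SSL loss.
import torch
import torch.nn.functional as F


class LatentCenterLoss(torch.nn.Module):
    def __init__(self, out_dim: int, student_temp: float = 0.1,
                 teacher_temp: float = 0.04, center_momentum: float = 0.9):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # Persistent state carried across batches.
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
        t = F.softmax((teacher_out - self.center) / self.teacher_temp, dim=-1)
        log_s = F.log_softmax(student_out / self.student_temp, dim=-1)
        loss = -(t * log_s).sum(dim=-1).mean()
        self.update_center(teacher_out)
        return loss

    @torch.no_grad()
    def update_center(self, teacher_out: torch.Tensor) -> None:
        # Running average of teacher outputs, used to avoid collapse.
        batch_center = teacher_out.mean(dim=0, keepdim=True)
        self.center.mul_(self.center_momentum).add_(
            batch_center, alpha=1 - self.center_momentum
        )
```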
clessig left a comment
Didn't look through the actual computations line by line, since it seems this is copy-pasted from the reference code?
@@ -0,0 +1,304 @@
# (C) Copyright 2025 WeatherGenerator contributors.
This file should go to . They need to be torch.nn.Module subclasses because these are NNs, even if they are not necessarily themselves trained. I think ssl_target_processing.py (since you probably still don't like ssl_target_predictors.py).
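For illustration, a minimal sketch of what a module in `ssl_target_processing.py` might look like as a `torch.nn.Module` that holds (possibly untrained) weights; the class name and head layout are assumptions:

```python
# Illustrative sketch only: a target-processing module as a torch.nn.Module.
import torch


class LatentTargetProcessing(torch.nn.Module):
    def __init__(self, embed_dim: int, out_dim: int):
        super().__init__()
        # Projection head applied to teacher latents; not necessarily trained
        # directly, e.g. it could only be updated via EMA from the student head.
        self.head = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.GELU(),
            torch.nn.Linear(embed_dim, out_dim),
        )

    @torch.no_grad()
    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        return self.head(latents)
```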
import torch.nn.functional as F

def lossfunc(t, s, temp):
The name is not very descriptive :) Maybe latent_logit_loss.py? JEPA uses MAE (and one could conceivably replace it with MSE), which is already implemented in loss.py. Ideally we could reuse what is there.
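For reference, a hedged sketch of the kind of cross-entropy on temperature-scaled latent logits that `lossfunc(t, s, temp)` appears to compute, written under the more descriptive name suggested here; `t` is assumed to already be a teacher probability distribution:

```python
# Illustrative sketch only; names and the exact reduction are assumptions.
import torch
import torch.nn.functional as F


def latent_logit_loss(t: torch.Tensor, s: torch.Tensor, temp: float) -> torch.Tensor:
    """Per-token cross-entropy between teacher probabilities t and
    temperature-scaled student logits s."""
    return -torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
```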
Q *= B  # the columns must sum to 1 so that Q is an assignment
return Q.t()

# def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
Can we remove the stale code? What does it implement?
The stale code is there for reference because it needs to go to the loss calculator later.
I will do all the clean-up once we are much closer to actually merging :)
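For context, the `Q *= B` / `return Q.t()` lines above look like the tail of a Sinkhorn-Knopp normalization that turns teacher scores into a soft assignment whose columns sum to 1. A minimal single-process sketch (function name, `eps`, and iteration count are assumptions):

```python
# Illustrative sketch only, not the PR's reference code.
import torch


@torch.no_grad()
def sinkhorn_knopp(scores: torch.Tensor, eps: float = 0.05, n_iters: int = 3) -> torch.Tensor:
    # scores: (batch B, prototypes K)
    Q = torch.exp(scores / eps).t()  # (K, B)
    K, B = Q.shape
    Q /= Q.sum()
    for _ in range(n_iters):
        # Rows: each prototype receives an equal share overall.
        Q /= Q.sum(dim=1, keepdim=True)
        Q /= K
        # Columns: each sample's assignment sums to 1/B.
        Q /= Q.sum(dim=0, keepdim=True)
        Q /= B
    Q *= B  # the columns now sum to 1, so Q is an assignment
    return Q.t()  # (B, K)
```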
def __init__(
    self,
    patch_out_dim,
Would it be better to take a dict as an argument, if we potentially want to implement *TargetProcessing variants that require different args?
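A small sketch of the dict-as-argument pattern suggested here, so different `*TargetProcessing` variants can pull different settings from a shared config without changing the constructor signature; the class name and keys are assumptions:

```python
# Illustrative sketch only.
import torch


class DINOTargetProcessing(torch.nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        # Each variant reads only the keys it needs.
        self.patch_out_dim = cfg["patch_out_dim"]
        self.temperature = cfg.get("temperature", 0.04)


# Usage: the same signature works for variants with different settings.
processing = DINOTargetProcessing({"patch_out_dim": 8192, "temperature": 0.07})
```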
After much consideration I decided to add the latent prediction heads to the Model, because they also need to benefit from the exponential moving average of the weights, and this gets unnecessarily cumbersome if they are outside the Model. TODO: make JEPA differ between student and teacher. TODO: use this new structure in EMATeacher.
To prevent crazy nesting of model output values, we created a ModelOutput dataclass (akin to how it is done in Hugging Face), and we run all the latent_prediction heads.
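As an illustration of the `ModelOutput` idea (akin to Hugging Face's `ModelOutput`), a minimal dataclass sketch; the field names are assumptions, not the PR's actual schema:

```python
# Illustrative sketch only; field names are hypothetical.
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import torch


@dataclass
class ModelOutput:
    # Primary reconstruction / forecast outputs.
    preds: List[torch.Tensor] = field(default_factory=list)
    # Latents produced by the backbone, consumed by the SSL losses.
    latents: Optional[torch.Tensor] = None
    # Outputs of the latent-prediction heads, keyed by head name.
    latent_predictions: Dict[str, torch.Tensor] = field(default_factory=dict)
```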
Will need adapting based on the abstract loss calculator. Currently awaiting the streams data branch to check the piping and configuration of the data.
Description
[DRAFT] PR for introducing the losses for SSL student-teacher latent losses. This PR will rely on both the abstract loss calculator #1178 as well as the abstract target/aux class #1179
The idea is to get early feedback and notice issues by making the code more concrete.
Issue Number
Closes #1043
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`