diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..34bf42bfa --- /dev/null +++ b/NOTICE @@ -0,0 +1,10 @@ +This project includes code derived from project "DINOv2: Learning Robust Visual Features without Supervision", +originally developed by Meta Platforms, Inc. and affiliates, +licensed under the Apache License, Version 2.0. + +Original NOTICE from project DINOv2 +-------------------------------------- + +N/A + + diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..871e0129c 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -1,5 +1,25 @@ +################# +### Data ### +################# streams_directory: "./config/streams/era5_1deg/" +start_date: 197901010000 +end_date: 202012310000 +start_date_val: 202101010000 +end_date_val: 202201010000 +len_hrs: 6 +step_hrs: 6 +input_window_steps: 1 + +val_initial: False + +loader_num_workers: 8 +log_validation: 0 +analysis_streams_output: ["ERA5"] + +################# +### Model ### +################# embed_orientation: "channels" embed_local_coords: True embed_centroids_local_coords: False @@ -40,6 +60,17 @@ pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True +healpix_level: 5 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +################# +### Forecast ### +################# # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder forecast_offset : 0 @@ -53,24 +84,11 @@ fe_dropout_rate: 0.1 fe_with_qk_lnorm: True impute_latent_noise_std: 0.0 # 1e-4 -healpix_level: 5 - -with_mixed_precision: True -with_flash_attention: True -compile_model: False -with_fsdp: True -attention_dtype: bf16 -mlp_norm_eps: 1e-5 -norm_eps: 1e-4 - -latent_noise_kl_weight: 0.0 # 1e-5 -latent_noise_gamma: 2.0 -latent_noise_saturate_encodings: 5 -latent_noise_use_additive_noise: False -latent_noise_deterministic_latents: True - +################# +### Training ### +################# loss_fcts: - - + - - "mse" - 1.0 loss_fcts_val: @@ -91,9 +109,30 @@ validate_with_ema: True ema_ramp_up_ratio: 0.09 ema_halflife_in_thousands: 1e-3 -# training mode: "forecast" or "masking" (masked token modeling) +# training mode: "forecast" or "masking" (masked token modeling) or "student-teacher" # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "student-teacher" +training_mode_config: { + "losses" : [ "iBOT", "DINO", "JEPA" ], + "shared_heads": False, + "student_temp": 0.1, + "teacher_temp": 0.1, + "dino_out_dim": 65536, # 2**16 + "ibot_patch_out_dim": 65536, # 2**16 + "teacher_style": "softmax_center", + "center_momentum": 0.9, + "target_and_aux_calc": "EMATeacher", + "teacher_model": {} +} +# training_mode: "masking" +# training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} # } # # training_mode_config: {"losses": {LossPhysical: [['mse', 0.7]], # # LossLatent: [['mse', 0.3]], # # LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} # # } # validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} # } # masking rate when training mode is "masking"; ignored in forecast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -113,6 +152,17 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false }
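Review note (aside, not part of the patch): the mapping-style configs sketched in the comments above are what the reworked LossCalculator consumes. A simplified sketch of the dispatch, mirroring loss_calculator.py later in this PR (stage handling reduced to a plain string for brevity):

import weathergen.train.loss_module as loss_module

def build_calculators(cf, stage, device):
    # each key must name a class in loss_module.py (e.g. LossPhysical, LossLatent);
    # each value is the list of (loss_fct_name, weight) pairs handed to that class
    cfgs = cf.training_mode_config.losses if stage == "train" else cf.validation_mode_config.losses
    return [
        getattr(loss_module, cls_name)(cf=cf, loss_fcts=fcts, stage=stage, device=device)
        for cls_name, fcts in cfgs.items()
    ]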
+################# +### Trainer ### +################# +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + num_epochs: 32 samples_per_epoch: 4096 samples_per_validation: 512 @@ -135,20 +185,6 @@ norm_type: "LayerNorm" nn_module: "te" log_grad_norms: False -start_date: 197901010000 -end_date: 202012310000 -start_date_val: 202101010000 -end_date_val: 202201010000 -len_hrs: 6 -step_hrs: 6 -input_window_steps: 1 - -val_initial: False - -loader_num_workers: 8 -log_validation: 0 -analysis_streams_output: ["ERA5"] - istep: 0 run_history: [] @@ -161,3 +197,4 @@ train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 + diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py index 2c0cb4243..4e0bd7d6e 100755 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py @@ -61,6 +61,7 @@ def detect_grid_type(input_data_array: xr.DataArray) -> str: # Otherwise it's Gaussian (irregular spacing or reduced grid) return "gaussian" + def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: """ Find all the pressure levels for each variable using regex and returns a dictionary @@ -90,6 +91,7 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: pl = list(set(pl)) return var_dict, pl + def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset: """ Reshape dataset while preserving grid structure (regular or Gaussian). @@ -176,8 +178,6 @@ def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) -> return ds - - def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: """ Add CF conventions to the dataset attributes. @@ -201,6 +201,7 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: ds.attrs["Conventions"] = "CF-1.12" return ds + def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: """ Modified CF parser that handles both regular and Gaussian grids. 
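Review note (aside, not part of the patch): find_pl() above groups flattened variable names by pressure level, but its regex is not visible in this hunk. For reviewers, a minimal sketch of the idea, assuming names like "t_850" (the actual pattern in export_inference.py may differ):

import re

def find_pl_sketch(all_variables: list[str]) -> tuple[dict[str, list[str]], list[int]]:
    var_dict: dict[str, list[str]] = {}
    pls: set[int] = set()
    for name in all_variables:
        m = re.fullmatch(r"(?P<var>[a-zA-Z0-9]+)_(?P<pl>\d+)", name)
        if m:  # pressure-level variable, e.g. "t_850"
            var_dict.setdefault(m["var"], []).append(name)
            pls.add(int(m["pl"]))
    return var_dict, sorted(pls)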
@@ -323,6 +324,7 @@ def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: return dataset + def output_filename( prefix: str, run_id: str, diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index ca5ee6601..844a0dfb1 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -226,7 +226,9 @@ def __init__( if cf.training_mode == "forecast": self.tokenizer = TokenizerForecast(cf.healpix_level) - elif cf.training_mode == "masking": + elif ( + cf.training_mode == "masking" or cf.training_mode == "student-teacher" + ): # TODO student-teacher data masker = Masker(cf) self.tokenizer = TokenizerMasking(cf.healpix_level, masker) assert self.forecast_offset == 0, "masked token modeling requires auto-encoder training" diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py index 7acbbf9f0..207362b4f 100644 --- a/src/weathergen/model/ema.py +++ b/src/weathergen/model/ema.py @@ -44,7 +44,7 @@ def reset(self): self.ema_model.to_empty(device="cuda") maybe_sharded_sd = self.original_model.state_dict() # this copies correctly tested in pdb - mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=True, assign=False) + mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=False, assign=False) @torch.no_grad() def update(self, cur_step, batch_size): @@ -53,7 +53,7 @@ def update(self, cur_step, batch_size): halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio) beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6)) for p_net, p_ema in zip( - self.original_model.parameters(), self.ema_model.parameters(), strict=True + self.original_model.parameters(), self.ema_model.parameters(), strict=False ): p_ema.lerp_(p_net, 1 - beta) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 3351dabc4..3ece50a03 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -732,3 +732,15 @@ def forward(self, latent, output, latent_lens, output_lens, coordinates): else output ) return output + + +class LatentPredictionHead(nn.Module): + def __init__(self, name, in_dim, out_dim): + super().__init__() + + self.name = name + # For now this is a Linear Layer TBD what this architecture should be + self.layer = nn.Linear(in_dim, out_dim, bias=False) + + def forward(self, x): + return self.layer(x) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 000f36735..a9cbf1c94 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -9,10 +9,12 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import copy import logging import math import warnings from pathlib import Path +import dataclasses import astropy_healpix as hp import astropy_healpix.healpy @@ -32,6 +34,7 @@ LocalAssimilationEngine, TargetPredictionEngine, TargetPredictionEngineClassic, + LatentPredictionHead, ) from weathergen.model.layers import MLP, NamedLinear from weathergen.model.parametrised_prob_dist import LatentInterpolator @@ -42,6 +45,16 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class ModelOutput: + """ + A dataclass to encapsulate the model output and give a clear API. 
+ """ + + physical: dict[str, torch.Tensor] + latent: dict[str, torch.Tensor] + + class ModelParams(torch.nn.Module): """Creation of query and embedding parameters of the model.""" @@ -450,6 +463,21 @@ def create(self) -> "Model": ) ) + # Latent heads for losses + # TODO write the forward function for this, has to wait until other Model PRs are done + target_losses = cf.get("target_losses", []) + shared_heads = cf.get("shared_heads", False) + self.latent_heads = nn.ModuleDict() + if ("iBOT" in target_losses and "DINO" in target_losses) and shared_heads: + self.latent_heads["iBOT-and-DINO-head"] = LatentPredictionHead( + "iBOT-and-DINO-head", cf.ae_global_dim_embed, cf.latent_pred_K + ) + elif "JEPA" in target_losses or "iBOT" in target_losses or "DINO" in target_losses: + for loss in target_losses: + self.latent_heads[loss] = LatentPredictionHead( + f"{loss}-head", cf.ae_global_dim_embed, cf.latent_pred_K + ) + return self def reset_parameters(self): @@ -653,7 +681,12 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - return preds_all, posteriors + latents = {} + latents["posteriors"] = posteriors + for name, head in self.latent_heads: + latents[name] = head(posteriors.mode()) + + return ModelOutput(physical=preds_all, latent=latents) ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: @@ -905,3 +938,26 @@ def predict( preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)] return preds_tokens + + +def get_model( + student_or_teacher, + cf: Config, + sources_size, + targets_num_channels, + targets_coords_size, + **kwargs, +): + if student_or_teacher == "student": + return Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + else: + if cf["training_mode"] == "student-teacher": # implement mode "student-teacher": + teacher_cf = copy.deepcopy(cf) + for key, val in teacher_cf.training_mode_config["teacher_model"].items(): + teacher_cf[key] = val + teacher = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + return teacher + else: + raise NotImplementedError( + f"The training mode {cf['training_mode']} is not implemented." + ) diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss.py index 406cd051c..c494df375 100644 --- a/src/weathergen/train/loss.py +++ b/src/weathergen/train/loss.py @@ -10,6 +10,7 @@ import numpy as np import torch +import torch.nn.functional as F stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed @@ -195,3 +196,66 @@ def gamma_decay(forecast_steps, gamma): fsteps = np.arange(forecast_steps) weights = gamma**fsteps return weights * (len(fsteps) / np.sum(weights)) + + +def student_teacher_patch_softmax( + student_patches, teacher_patches, student_masks_flat, student_temp +): + """ + Cross-entropy between softmax outputs of the teacher and student networks. 
diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss.py index 406cd051c..c494df375 100644 --- a/src/weathergen/train/loss.py +++ b/src/weathergen/train/loss.py @@ -10,6 +10,7 @@ import numpy as np import torch +import torch.nn.functional as F stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed @@ -195,3 +196,66 @@ def gamma_decay(forecast_steps, gamma): fsteps = np.arange(forecast_steps) weights = gamma**fsteps return weights * (len(fsteps) / np.sum(weights)) + + +def student_teacher_patch_softmax( + student_patches, teacher_patches, student_masks_flat, student_temp +): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + student_patches: (B, N, D) tensor + teacher_patches: (B, N, D) tensor + student_masks_flat: (B, N) tensor + student_temp: float + """ + loss = torch.sum( + teacher_patches * F.log_softmax(student_patches / student_temp, dim=-1), dim=-1 + ) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum( + dim=-1 + ).clamp(min=1.0) + return -loss.mean() + + +def lossfunc(t, s, temp): + # cross-entropy term between teacher distribution t and student logits s + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +def masked_student_teacher_patch_softmax( + student_patches_masked, + teacher_patches_masked, + student_masks_flat, + student_temp, + n_masked_patches, + masks_weight, +): + """ + Cross-entropy between softmax outputs of the teacher and student networks, + restricted to the masked patches. + student_patches_masked: student outputs gathered at masked patches + teacher_patches_masked: teacher outputs gathered at masked patches + student_masks_flat: (B, N) mask tensor + student_temp: float temperature applied to the student logits + n_masked_patches: optional cap on the number of masked patches used + masks_weight: optional per-patch weights; derived from the mask if None + """ + loss = lossfunc(teacher_patches_masked, student_patches_masked, student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + +def student_teacher_global_softmax(student_outputs, student_temp, teacher_outputs): + """ + DINO-style global loss: cross-entropy between each (already centered and + sharpened) teacher output and the log-softmax of each student output. + """ + total_loss = 0 + for s in student_outputs: + lsm = F.log_softmax(s / student_temp, dim=-1) + for t in teacher_outputs: + loss = torch.sum(t * lsm, dim=-1) + total_loss -= loss.mean() + return total_loss + diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index f457d6454..dfd582ec8 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -1,3 +1,5 @@ +# ruff: noqa: T201 + # (C) Copyright 2025 WeatherGenerator contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,48 +9,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import dataclasses import logging -import numpy as np -import torch from omegaconf import DictConfig -from torch import Tensor -import weathergen.train.loss as losses -from weathergen.train.loss import stat_loss_fcts -from weathergen.utils.train_logger import TRAIN, VAL, Stage +import weathergen.train.loss_module as LossModule +from weathergen.train.loss_module_base import LossValues +from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) -@dataclasses.dataclass -class LossValues: - """ - A dataclass to encapsulate the various loss components computed by the LossCalculator. - - This provides a structured way to return the primary loss used for optimization, - along with detailed per-stream/per-channel/per-loss-function losses for logging, - and standard deviations for ensemble scenarios. - """ - - # The primary scalar loss value for optimization. - loss: Tensor - # Dictionaries containing detailed loss values for each stream, channel, and loss function, as - # well as standard deviations when operating with ensembles (e.g., when training with CRPS). - losses_all: dict[str, Tensor] - stddev_all: dict[str, Tensor] - - class LossCalculator: """ Manages and computes the overall loss for a WeatherGenerator model during training and validation stages.
- - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. """ def __init__( @@ -75,246 +50,35 @@ def __init__( self.stage = stage self.device = device - # Dynamically load loss functions based on configuration and stage - loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _get_weights(self, stream_info): - """ - Get weights for current stream - """ - - device = self.device - - # Determine stream and channel loss weights based on the current stage - if self.stage == TRAIN: - # set loss_weights to 1. when not specified - stream_info_loss_weight = stream_info.get("loss_weight", 1.0) - weights_channels = ( - torch.tensor(stream_info["target_channel_weights"]).to( - device=device, non_blocking=True - ) - if "target_channel_weights" in stream_info - else None - ) - elif self.stage == VAL: - # in validation mode, always unweighted loss - stream_info_loss_weight = 1.0 - weights_channels = None - - return stream_info_loss_weight, weights_channels - - def _get_fstep_weights(self, forecast_steps): - timestep_weight_config = self.cf.get("timestep_weight") - if timestep_weight_config is None: - return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) - return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) - - def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): - location_weight_type = stream_info.get("location_weight", None) - if location_weight_type is None: - return None - weights_locations_fct = getattr(losses, location_weight_type) - weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) - weights_locations = weights_locations.to(device=self.device, non_blocking=True) - - return weights_locations - - def _get_substep_masks(self, stream_info, fstep, stream_data): - """ - Find substeps and create corresponding masks (reused across loss functions) - """ - - tok_spacetime = stream_info.get("tokenize_spacetime", None) - target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] - target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] - substep_masks = [] - for t in target_times_unique: - # find substep - mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) - substep_masks.append(mask_t) - - return substep_masks - - @staticmethod - def _loss_per_loss_function( - loss_fct, - stream_info, - target: torch.Tensor, - pred: torch.Tensor, - substep_masks: list[torch.Tensor], - weights_channels: torch.Tensor, - weights_locations: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) - losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) - - ctr_substeps = 0 - for mask_t in substep_masks: - assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True - - loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations - ) - - # accumulate loss - loss_lfct = loss_lfct + loss - losses_chs = losses_chs + 
loss_chs.detach() if len(loss_chs) > 0 else losses_chs - ctr_substeps += 1 if loss > 0.0 else 0 - - # normalize over forecast steps in window - losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + calculator_configs = ( + cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses + ) - # TODO: substep weight - loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + calculator_configs = [ + (getattr(LossModule, Cls), losses) for (Cls, losses) in calculator_configs.items() + ] - return loss_lfct, losses_chs + self.loss_calculators = [ + Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device) + for (Cls, losses) in calculator_configs + ] def compute_loss( self, - preds: list[list[Tensor]], - streams_data: list[list[any]], - ) -> LossValues: - """ - Computes the total loss for a given batch of predictions and corresponding - stream data. - - The computed loss is: - - Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) - - This method orchestrates the calculation of the overall loss by iterating through - different data streams, forecast steps, channels, and configured loss functions. - It applies weighting, handles NaN values through masking, and accumulates - detailed loss metrics for logging. - - Args: - preds: A nested list of prediction tensors. The outer list represents forecast steps, - the inner list represents streams. Each tensor contains predictions for that - step and stream. - streams_data: A nested list representing the input batch data. The outer list is for - batch items, the inner list for streams. Each element provides an object - (e.g., dataclass instance) containing target data and metadata. - - Returns: - A ModelLoss dataclass instance containing: - - loss: The loss for back-propagation. - - losses_all: A dictionary mapping stream names to a tensor of per-channel and - per-loss-function losses, normalized by non-empty targets/forecast steps. - - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations - of predictions for channels with statistical loss functions, normalized. 
- """ - - # gradient loss - loss = torch.tensor(0.0, device=self.device, requires_grad=True) - # counter for non-empty targets - ctr_streams = 0 - - # initialize dictionaries for detailed loss tracking and standard deviation statistics - # create tensor for each stream - losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), - device=self.device, - ) - for st in self.cf.streams - } - stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams - } - - # TODO: iterate over batch dimension - i_batch = 0 - for i_stream_info, stream_info in enumerate(self.cf.streams): - # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] - - stream_data = streams_data[i_batch][i_stream_info] - - fstep_loss_weights = self._get_fstep_weights(len(targets)) - - loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps = 0 - - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() - if stream_is_spoof: - spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) - else: - spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) - - for fstep, (target, fstep_weight) in enumerate( - zip(targets, fstep_loss_weights, strict=False) - ): - # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # reshape prediction tensor to match target's dimensions: extract data/coords and - # remove token dimension if it exists. - # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
- pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - # get weigths for current streams - stream_loss_weight, weights_channels = self._get_weights(stream_info) - - # get weights for locations - weights_locations = self._get_location_weights( - stream_info, stream_data, self.cf.forecast_offset, fstep - ) - - # get masks for sub-time steps - substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) - - # accumulate loss from different loss functions - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): - # loss for current loss function - loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( - loss_fct, - stream_info, - target, - pred, - substep_masks, - weights_channels, - weights_locations, - ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + ( - loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight - ) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) - ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) - ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 - - # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - - # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan - - # normalize by all targets and forecast steps that were non-empty - # (with each having an expected loss of 1 for an uninitalized neural net) - loss = loss / ctr_streams - - # Return all computed loss components encapsulated in a ModelLoss dataclass + preds: dict, + targets: dict, + ): + loss_values = {} + loss = 0 + for calculator in self.loss_calculators: + loss_values[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) + loss += loss_values[calculator.name].loss + + # Bring all loss values together + # TODO: make sure keys are explicit, e.g loss_mse.latent.loss_2t + losses_all = {} + stddev_all = {} + for _, v in loss_values.items(): + losses_all.update(v.losses_all) + stddev_all.update(v.stddev_all) return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) diff --git a/src/weathergen/train/loss_module.py b/src/weathergen/train/loss_module.py new file mode 100644 index 000000000..2b345c3fe --- /dev/null +++ b/src/weathergen/train/loss_module.py @@ -0,0 +1,398 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
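Review note (aside, not part of the patch): the calculators in loss_module.py below are looked up by class name from the config (see loss_calculator.py above), so adding one only requires subclassing LossModuleBase and returning a LossValues. A minimal, hypothetical sketch for wiring tests (LossConstant is not part of this PR):

import torch
from weathergen.train.loss_module_base import LossModuleBase, LossValues

class LossConstant(LossModuleBase):
    """Toy calculator: fixed scalar loss, useful to smoke-test the dispatch."""

    def __init__(self, cf, loss_fcts, stage, device):
        LossModuleBase.__init__(self)
        self.cf, self.stage, self.device = cf, stage, device
        self.name = "LossConstant"

    def compute_loss(self, preds: dict, targets: dict) -> LossValues:
        loss = torch.tensor(1.0, device=self.device, requires_grad=True)
        return LossValues(loss=loss, losses_all={}, stddev_all={})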
+ +import logging + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss as losses +from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import TRAIN, VAL, Stage + +_logger = logging.getLogger(__name__) + + +class LossPhysical(LossModuleBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossPhysical" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + 
weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: + """ + Computes the total loss for a given batch of predictions and corresponding + stream data. + + The computed loss is: + + Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weights) ))) + + This method orchestrates the calculation of the overall loss by iterating through + different data streams, forecast steps, channels, and configured loss functions. + It applies weighting, handles NaN values through masking, and accumulates + detailed loss metrics for logging. + + Args: + preds: Dict whose "physical" entry holds a nested list of prediction tensors. The + outer list represents forecast steps, the inner list represents streams. Each + tensor contains predictions for that step and stream. + targets: Dict whose "physical" entry holds a nested list representing the input batch + data. The outer list is for batch items, the inner list for streams. Each element + provides an object (e.g., dataclass instance) containing target data and metadata. + + Returns: + A LossValues dataclass instance containing: + - loss: The loss for back-propagation. + - losses_all: A dictionary mapping stream names to a tensor of per-channel and + per-loss-function losses, normalized by non-empty targets/forecast steps. + - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations + of predictions for channels with statistical loss functions, normalized.
+ """ + + preds = preds["physical"] + streams_data = targets["physical"] + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
+ pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # get weights for current streams + stream_loss_weight, weights_channels = self._get_weights(stream_info) + + # get weights for locations + weights_locations = self._get_location_weights( + stream_info, stream_data, self.cf.forecast_offset, fstep + ) + + # get masks for sub-time steps + substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) + + # accumulate loss from different loss functions + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + # loss for current loss function + loss_lfct, loss_lfct_chs = self._loss_per_loss_function( + loss_fct, + target, + pred, + substep_masks, + weights_channels, + weights_locations, + ) + losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + ( + loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight + ) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) + ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 + + # normalize by forecast step + losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + + # replace channels without information by nan to exclude from further computations + losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan + stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + # normalize by all targets and forecast steps that were non-empty + # (with each having an expected loss of 1 for an uninitialized neural net) + loss = loss / ctr_streams + + # Return all computed loss components encapsulated in a LossValues dataclass + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)
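Review note: as a worked example of the aggregation above, with a single loss function (weight 1.0) and one real stream whose two non-empty forecast steps yield losses 0.7 and 0.9, loss_fsteps = 0.7 + 0.9 and ctr_fsteps = 2, so the stream contributes 0.8; a spoofed stream has its contribution multiplied by spoof_weight = 0 and is excluded from ctr_streams, leaving the final mean over real streams only.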
+ + +class LossLatent(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatent" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, ens=None, mu=pred) + + return loss_val + + def compute_loss( + self, + preds: list[list[Tensor]], + targets: list[list[any]], + ) -> LossValues: + losses_all: Tensor = torch.zeros( + len(self.loss_fcts), + device=self.device, + ) + + loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps_lat = 0 + # TODO: KCT, do we need the below per fstep? + for fstep in range( + 1, len(preds) + ): # the first entry in tokens_all is the source itself, so skip it + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + # if forecast_offset==0, then the timepoints correspond. Otherwise targets don't encode the source timestep, so we don't need to skip + fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + loss_lfct = self._loss_per_loss_function( + loss_fct, + target=targets[fstep_targs], + pred=preds[fstep], + ) + + losses_all[i_lfct] += loss_lfct # TODO: break into fsteps + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps_lat = loss_fsteps_lat + ( + loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + ) + ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) + + losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 + losses_all[losses_all == 0.0] = torch.nan + + # wrap the tensor in a dict so LossCalculator can merge results across calculators + return LossValues(loss=loss, losses_all={self.name: losses_all}, stddev_all={}) + + +class LossStudentTeacher(LossModuleBase): + """ + Placeholder for a student-teacher loss calculator; not implemented yet + (see LossLatentSSLStudentTeacher in loss_module_ssl.py). + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + self.name = "LossStudentTeacher" + raise NotImplementedError() + + def compute_loss(self, preds, targets): + return super().compute_loss(preds, targets) diff --git a/src/weathergen/train/loss_module_base.py b/src/weathergen/train/loss_module_base.py new file mode 100644 index 000000000..de66bda28 --- /dev/null +++ b/src/weathergen/train/loss_module_base.py @@ -0,0 +1,65 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import dataclasses +from abc import abstractmethod + +from torch import Tensor + +from weathergen.common.config import Config +from weathergen.utils.train_logger import Stage + + +@dataclasses.dataclass +class LossValues: + """ + A dataclass to encapsulate the various loss components computed by the LossCalculator. + + This provides a structured way to return the primary loss used for optimization, + along with detailed per-stream/per-channel/per-loss-function losses for logging, + and standard deviations for ensemble scenarios. + """ + + # The primary scalar loss value for optimization. + loss: Tensor + # Dictionaries containing detailed loss values for each stream, channel, and loss function, as + # well as standard deviations when operating with ensembles (e.g., when training with CRPS). + losses_all: dict[str, Tensor] + stddev_all: dict[str, Tensor] + + +class LossModuleBase: + def __init__(self): + """ + Base class for loss calculators. + + Subclasses are expected to accept and store the following in their own __init__: + cf: The OmegaConf DictConfig object containing model and training configurations. + It should specify 'loss_fcts' for training and 'loss_fcts_val' for validation. + stage: The current operational stage, either TRAIN or VAL. + This dictates which set of loss functions (training or validation) will be used. + device: The computation device, such as 'cpu' or 'cuda:0', where tensors will reside.
+ """ + self.cf: Config | None = None + self.stage: Stage + self.loss_fcts = [] + + @abstractmethod + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: + """ + Computes loss given predictions and targets and returns values of LossValues dataclass. + """ + + raise NotImplementedError() diff --git a/src/weathergen/train/loss_module_ssl.py b/src/weathergen/train/loss_module_ssl.py new file mode 100644 index 000000000..0fb2e5683 --- /dev/null +++ b/src/weathergen/train/loss_module_ssl.py @@ -0,0 +1,93 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import numpy as np +from omegaconf import DictConfig + +import torch +from torch import Tensor +import torch.nn.functional as F + +import weathergen.train.loss as losses +from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import TRAIN, VAL, Stage + +_logger = logging.getLogger(__name__) + + +class LossLatentSSLStudentTeacher(LossModuleBase): + """ + Manages and computes the overall loss for a WeatherGenerator model pretraining using + DINO/iBOT/JEPA/BYOL style losses. + + This class handles the initialization and application of various loss functions, + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + valid_loss_names = set("DINO", "iBOT", "JEPA") + + def __init__( + self, + cf: DictConfig, + losses: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatentSSLStudentTeacher" + + # Dynamically load loss functions based on configuration and stage + self.losses = { + name: get_loss_function_ssl(name) for name in losses if name in self.valid_loss_names + } + + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { loss : 0.0 + for loss in self.losses + } + + for name, loss_fn in losses: + loss_value = loss_fn(preds.latent[name], targets[name]).mean() + loss += loss_value + losses_all[name] = loss_value.item() + + return loss + + + + + +def get_loss_function_ssl(name): + if name == "iBOT": + return losses.masked_student_teacher_patch_softmax + elif name == "DINO": + return losses.student_teacher_global_softmax + elif name == "JEPA": + return F.l1_loss + else: + raise NotImplementedError( + f"{name} is not an implemented loss for the LossLatentSSLStudentTeacher" + ) diff --git a/src/weathergen/train/ssl_losses_utils.py b/src/weathergen/train/ssl_losses_utils.py new file mode 100644 index 000000000..060af460f --- /dev/null +++ b/src/weathergen/train/ssl_losses_utils.py @@ -0,0 +1,309 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + + +def lossfunc(t, s, temp): + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class iBOTPatchTargetProcessing(nn.Module): + """ + Code taken and adapted from the official DINOv2 implementation + https://github.com/facebookresearch/dinov2/tree/main + + Needs to be nn.Module because of the registered_buffer, it means we should have a forward + function, previously was the softmax computation, maybe we can make it the + softmax_center_teacher, etc + """ + + def __init__( + self, + patch_out_dim, + student_temp=0.1, + teacher_temp=0.1, + center_momentum=0.9, + teacher_style="softmax_center", + ): + super().__init__() + self.student_temp = student_temp + self.teacher_temp = teacher_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + self.teacher_style = teacher_style + assert teacher_style in ["softmax_center", "sinkhorn_knopp"], f"{teacher_style} is unknown" + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + # + # WARNING: + # as self.center is a float32, everything gets casted to float32 afterwards + # + # teacher_patch_tokens = teacher_patch_tokens.float() + # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) + + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + # this is experimental, keep everything in float16 and let's see what happens: + # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher( + self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3 + ): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp( + teacher_output / teacher_temp + ).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + # def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + # """ + # Cross-entropy between softmax outputs of the teacher and student networks. 
+ # student_patch_tokens: (B, N, D) tensor + # teacher_patch_tokens: (B, N, D) tensor + # student_masks_flat: (B, N) tensor + # """ + # t = teacher_patch_tokens + # s = student_patch_tokens + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + # loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum( + # dim=-1 + # ).clamp(min=1.0) + # return -loss.mean() + + # def forward_masked( + # self, + # student_patch_tokens_masked, + # teacher_patch_tokens_masked, + # student_masks_flat, + # n_masked_patches=None, + # masks_weight=None, + # ): + # t = teacher_patch_tokens_masked + # s = student_patch_tokens_masked + # # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + # loss = lossfunc(t, s, self.student_temp) + # if masks_weight is None: + # masks_weight = ( + # (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + # .unsqueeze(-1) + # .expand_as(student_masks_flat)[student_masks_flat] + # ) + # if n_masked_patches is not None: + # loss = loss[:n_masked_patches] + # loss = loss * masks_weight + # return -loss.sum() / student_masks_flat.shape[0] + + def forward(self, teacher_output): + # TODO deal with the iBOT head question, use the forward_masked + if self.teacher_style == "softmax_center": + processed_teacher_output = self.softmax_center_teacher( + teacher_output, self.teacher_temp + ) + self.update_center(teacher_output) + return processed_teacher_output + elif self.teacher_style == "sinkhorn_knopp": + return self.sinkhorn_knopp_teacher(teacher_output, self.teacher_temp) + else: + # this code should never be reached, see assert in __init__ + return teacher_output + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True + + +class DINOTargetProcessing(nn.Module): + """ + Code taken and adapted from the official DINOv2 implementation + https://github.com/facebookresearch/dinov2/tree/main + + Needs to be nn.Module because of the registered_buffer, it means we should have a forward + function, previously was the softmax computation, maybe we can make it the + softmax_center_teacher, etc + """ + + def __init__( + self, + out_dim, + student_temp=0.1, + center_momentum=0.9, + teacher_temp=0.1, + teacher_style="softmax_center", + ): + super().__init__() + self.student_temp = student_temp + self.teacher_temp = teacher_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + self.teacher_style = teacher_style + assert teacher_style in ["softmax_center", "sinkhorn_knopp"], f"{teacher_style} is unknown" + + @torch.no_grad() + def 
softmax_center_teacher(self, teacher_output, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): + teacher_output = teacher_output.float() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp( + teacher_output / teacher_temp + ).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for _it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, teacher_output): + # TODO deal with the DINO head question + if self.teacher_style == "softmax_center": + processed_teacher_output = self.softmax_center_teacher( + teacher_output, self.teacher_temp + ) + self.update_center(teacher_output) + return processed_teacher_output + elif self.teacher_style == "sinkhorn_knopp": + return self.sinkhorn_knopp_teacher(teacher_output, self.teacher_temp) + else: + # this code should never be reached, see assert in __init__ + return teacher_output + + # def forward(self, student_output_list, teacher_out_softmaxed_centered_list): + # """ + # Cross-entropy between softmax outputs of the teacher and student networks. 
+ # """ + # # TODO: Use cross_entropy_distribution here + # total_loss = 0 + # for s in student_output_list: + # lsm = F.log_softmax(s / self.student_temp, dim=-1) + # for t in teacher_out_softmaxed_centered_list: + # loss = torch.sum(t * lsm, dim=-1) + # total_loss -= loss.mean() + # return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True + + +class JEPATargetProcessing(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args): + return args diff --git a/src/weathergen/train/target_and_aux_module_base.py b/src/weathergen/train/target_and_aux_module_base.py new file mode 100644 index 000000000..7cca5f7bc --- /dev/null +++ b/src/weathergen/train/target_and_aux_module_base.py @@ -0,0 +1,35 @@ +from typing import Any + + +class TargetAndAuxModuleBase: + def __init__(self, model, rng, **kwargs): + pass + + def reset(self): + pass + + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + pass + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + pass + + def compute(self, *args, **kwargs) -> tuple[Any, Any]: + pass + + +class IdentityTargetAndAux(TargetAndAuxModuleBase): + def __init__(self, model, rng, config): + return + + def reset(self): + return + + def update_state_pre_backward(self, istep, batch, model, **kwargs): + return + + def update_state_post_opt_step(self, istep, batch, model, **kwargs): + return + + def compute(self, istep, batch, *args, **kwargs): + return batch[0], None diff --git a/src/weathergen/train/target_and_aux_ssl_teacher.py b/src/weathergen/train/target_and_aux_ssl_teacher.py new file mode 100644 index 000000000..6cbe3ed96 --- /dev/null +++ b/src/weathergen/train/target_and_aux_ssl_teacher.py @@ -0,0 +1,78 @@ +from typing import Any + +import torch + +from weathergen.train.target_and_aux_module_base import TargetAndAuxModuleBase +from weathergen.train.ssl_losses_utils import ( + iBOTPatchTargetProcessing, + DINOTargetProcessing, + JEPATargetProcessing, +) + + +class EMATeacher(TargetAndAuxModuleBase): + def __init__(self, model, rng, ema_model, batch_size, **kwargs): + # One of the issues is that the teacher model may have a different architecture + # to the student, e.g. JEPA. So we need quite a flexible way to instantiate the + # the teacher. Because of the device sharding etc that requires quite a bit of + # massaging we assume that the teacher creates the EMA model correctly. 
However, + # note that you cannot assume that model.state_dict equals ema_model.state_dict + self.ema_model = ema_model + self.batch_size = batch_size + + # is a dict of TargetProcessing classes as we may use several in parallel + self.postprocess_targets = get_target_postprocessing(kwargs["losses"], **kwargs) + + self.reset() + + def reset(self, batch_size=None): + self.ema_model.reset() + if batch_size is not None: + self.batch_size = batch_size + + def update_state_pre_backward(self, istep, batch, model, **kwargs) -> None: + return + + def update_state_post_opt_step(self, istep, batch, model, **kwargs) -> None: + self.ema_model.update(istep, self.batch_size) + + def compute( + self, bidx, batch, model_params, model, forecast_offset, forecast_steps + ) -> tuple[Any, Any]: + """ + Likely will gain in complexity as we actually implement things as different losses. + DINO, iBOT, JEPA will have different heads, which then probably should be computed + in the postprocess_targets modules, which are nn.Modules + """ + teacher_outputs = self.ema_model.forward_eval( + model_params, batch, forecast_offset, forecast_steps + ) + targets = {} + for loss_name, target_module in self.postprocess_targets.items(): + with torch.no_grad(): + targets[loss_name] = None # TODO: target_module(teacher_outputs[loss_name]) + return targets, None + + +def get_target_postprocessing(target_losses: list[str], **kwargs): + return_dict = {} + for loss_name in target_losses: + if loss_name == "iBOT": + return_dict[loss_name] = iBOTPatchTargetProcessing( + patch_out_dim=kwargs["ibot_patch_out_dim"], + center_momentum=kwargs["center_momentum"], + student_temp=kwargs["student_temp"], + teacher_temp=kwargs["teacher_temp"], + teacher_style=kwargs["teacher_style"], + ) + elif loss_name == "DINO": + return_dict[loss_name] = DINOTargetProcessing( + out_dim=kwargs["dino_out_dim"], + center_momentum=kwargs["center_momentum"], + student_temp=kwargs["student_temp"], + teacher_temp=kwargs["teacher_temp"], + teacher_style=kwargs["teacher_style"], + ) + elif loss_name == "JEPA": + return_dict[loss_name] = JEPATargetProcessing() + else: + # We skip losses that are not handled by the EMATeacher + continue + return return_dict diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..19d2f1dd2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -42,14 +42,14 @@ ) from weathergen.model.ema import EMAModel from weathergen.model.layers import MLP -from weathergen.model.model import Model, ModelParams +from weathergen.model.model import Model, ModelParams, get_model from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator from weathergen.train.lr_scheduler import LearningRateScheduler -from weathergen.train.trainer_base import TrainerBase +from weathergen.train.trainer_base import TrainerBase, get_target_and_aux_calculator from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger -from weathergen.utils.utils import get_dtype +from weathergen.utils.utils import get_batch_size, get_dtype from weathergen.utils.validation_io import write_output logger = logging.getLogger(__name__)
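Review note: the EMA update that keeps the teacher close to the student (ema.py above) computes beta = 0.5 ** (batch_size / (halflife_steps * 1e3)) after the ramp-up clamp. Worked example: with ema_halflife_in_thousands = 1.0 and batch_size = 1, beta = 0.5 ** (1/1000) ≈ 0.99931, so after 1000 optimizer steps the teacher retains exactly half of its old weights; the default of 1e-3 makes the half-life a single step, i.e. the teacher follows the student almost immediately.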
diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 3d847a671..19d2f1dd2 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -42,14 +42,14 @@
 )
 from weathergen.model.ema import EMAModel
 from weathergen.model.layers import MLP
-from weathergen.model.model import Model, ModelParams
+from weathergen.model.model import Model, ModelParams, get_model
 from weathergen.model.utils import freeze_weights
 from weathergen.train.loss_calculator import LossCalculator
 from weathergen.train.lr_scheduler import LearningRateScheduler
-from weathergen.train.trainer_base import TrainerBase
+from weathergen.train.trainer_base import TrainerBase, get_target_and_aux_calculator
 from weathergen.utils.distributed import all_gather_vlen, ddp_average, is_root
 from weathergen.utils.train_logger import TRAIN, VAL, Stage, TrainLogger
-from weathergen.utils.utils import get_dtype
+from weathergen.utils.utils import get_batch_size, get_dtype
 from weathergen.utils.validation_io import write_output
 
 logger = logging.getLogger(__name__)
@@ -157,13 +157,15 @@ def inference(self, cf, devices, run_id_trained, epoch):
         self.validate(epoch=0)
         logger.info(f"Finished inference run with id: {cf.run_id}")
 
-    def init_model_and_shard(self, cf, devices):
+    def init_model_and_shard(self, cf, student_or_teacher, devices):
         sources_size = self.dataset.get_sources_size()
         targets_num_channels = self.dataset.get_targets_num_channels()
         targets_coords_size = self.dataset.get_targets_coords_size()
 
         with torch.device("meta"):
-            model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
+            model = get_model(
+                student_or_teacher, cf, sources_size, targets_num_channels, targets_coords_size
+            )
 
         for name, module in model.named_modules():
             name = module.name if hasattr(module, "name") else name
@@ -294,7 +296,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
             self.dataset_val, **loader_params, sampler=None
         )
 
-        self.model, self.model_params = self.init_model_and_shard(cf, devices)
+        self.model, self.model_params = self.init_model_and_shard(cf, "student", devices)
 
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
@@ -313,8 +315,19 @@
         self.validate_with_ema = cf.get("validate_with_ema", False)
         self.ema_model = None
+        # validate_with_ema is incompatible with student-teacher training
+        self.validate_with_ema = False  # TODO: remove, disabled for testing only
         if self.validate_with_ema:
-            meta_ema_model = self.init_model_and_shard(cf, devices)[0]
+            meta_ema_model = self.init_model_and_shard(cf, "student", devices)[0]
+            self.ema_model = EMAModel(
+                self.model,
+                meta_ema_model,
+                halflife_steps=cf.get("ema_halflife_in_thousands", 1e-3),
+                rampup_ratio=cf.get("ema_ramp_up_ratio", 0.09),
+                is_model_sharded=(cf.with_ddp and cf.with_fsdp),
+            )
+        elif cf["training_mode"] == "student-teacher":
+            meta_ema_model = self.init_model_and_shard(cf, "teacher", devices)[0]
             self.ema_model = EMAModel(
                 self.model,
                 meta_ema_model,
@@ -323,6 +336,14 @@
                 is_model_sharded=(cf.with_ddp and cf.with_fsdp),
             )
 
+        self.target_and_aux_calculator = get_target_and_aux_calculator(
+            cf,
+            self.model,
+            None,
+            ema_model=self.ema_model,
+            batch_size=get_batch_size(cf, self.world_size_original),
+        )
+
         # if with_fsdp then parameter count is unreliable
         if (is_root() and not cf.with_fsdp) or not cf.with_ddp:
             self.model.print_num_parameters()
@@ -588,17 +609,33 @@ def train(self, epoch):
                 dtype=self.mixed_precision_dtype,
                 enabled=cf.with_mixed_precision,
             ):
-                preds, posteriors = self.model(
-                    self.model_params, batch, cf.forecast_offset, forecast_steps
-                )
+                output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps)
+
+                targets, aux_outputs = self.target_and_aux_calculator.compute(
+                    self.cf.istep,
+                    batch,
+                    self.model_params,
+                    self.model,
+                    cf.forecast_offset,
+                    forecast_steps,
+                )
+                # TODO (WIP): merge the calculator targets into the physical
+                # targets; for now the physical target overwrites them
+                targets = {"physical": batch[0]}
+                # TODO (WIP): also expose latent predictions to the loss, e.g.
+                # preds = {"physical": output.physical, "latent": output.latent}
 
                 loss_values = self.loss_calculator.compute_loss(
-                    preds=preds,
-                    streams_data=batch[0],
+                    preds=output.physical,
+                    targets=targets,
                 )
 
                 if cf.latent_noise_kl_weight > 0.0:
-                    kl = torch.cat([posterior.kl() for posterior in posteriors])
+                    kl = torch.cat([posterior.kl() for posterior in output.latent["posteriors"]])
                     loss_values.loss += cf.latent_noise_kl_weight * kl.mean()
 
+                self.target_and_aux_calculator.update_state_pre_backward(
+                    self.cf.istep, batch, self.model
+                )
+
             # backward pass
             self.optimizer.zero_grad()
             self.grad_scaler.scale(loss_values.loss).backward()
@@ -622,14 +659,16 @@ def train(self, epoch):
             self.grad_scaler.update()
             # self.optimizer.step()
 
+            self.target_and_aux_calculator.update_state_post_opt_step(
+                self.cf.istep, batch, self.model
+            )
+
             # update learning rate
             self.lr_scheduler.step()
 
             # EMA update
             if self.validate_with_ema:
                 self.ema_model.update(
-                    self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu,
-                    self.world_size_original * self.cf.batch_size_per_gpu,
+                    self.cf.istep * get_batch_size(self.cf, self.world_size_original),
+                    get_batch_size(self.cf, self.world_size_original),
                 )
 
             self.loss_unweighted_hist += [loss_values.losses_all]
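The per-step protocol the trainer follows with any target-and-aux calculator, shown with the identity implementation; a minimal sketch that assumes only the classes defined above.

    from weathergen.train.target_and_aux_module_base import IdentityTargetAndAux

    calc = IdentityTargetAndAux(model=None, rng=None, config=None)
    batch = (["physical targets"],)

    # 1. inside the autocast block: compute the targets for this step
    targets, aux = calc.compute(0, batch)
    assert targets == ["physical targets"] and aux is None
    # 2. before backward(): calc.update_state_pre_backward(istep, batch, model)
    # 3. after the optimizer step: calc.update_state_post_opt_step(istep, batch, model);
    #    for EMATeacher this is where the teacher weights are updated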
diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py
index 684b3b54b..53eb7f8e0 100644
--- a/src/weathergen/train/trainer_base.py
+++ b/src/weathergen/train/trainer_base.py
@@ -17,6 +17,8 @@
 import torch.multiprocessing
 
 from weathergen.common.config import Config
+from weathergen.train.target_and_aux_module_base import IdentityTargetAndAux
+from weathergen.train.target_and_aux_ssl_teacher import EMATeacher
 from weathergen.train.utils import str_to_tensor, tensor_to_str
 from weathergen.utils.distributed import is_root
 
@@ -167,3 +169,16 @@ def get_perf(self):
         perf_mem /= len(self.device_handles)
 
     return perf_gpu, perf_mem
+
+
+# TODO: move to its own file to avoid circular imports
+def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs):
+    target_and_aux_calc = config.training_mode_config.get("target_and_aux_calc", None)
+    if target_and_aux_calc is None or target_and_aux_calc == "identity":
+        return IdentityTargetAndAux(model, rng, config)
+    elif target_and_aux_calc == "EMATeacher":
+        return EMATeacher(
+            model, rng, kwargs["ema_model"], batch_size, **config.training_mode_config
+        )
+    else:
+        raise NotImplementedError(f"{target_and_aux_calc} is not implemented")
diff --git a/src/weathergen/utils/utils.py b/src/weathergen/utils/utils.py
index 5deba9287..c84f2d298 100644
--- a/src/weathergen/utils/utils.py
+++ b/src/weathergen/utils/utils.py
@@ -9,6 +9,8 @@
 
 import torch
 
+from weathergen.common.config import Config
+
 
 def get_dtype(value: str) -> torch.dtype:
     """
@@ -24,3 +26,7 @@
     raise NotImplementedError(
         f"Dtype {value} is not recognized, choose either, bf16, fp16, or fp32"
     )
+
+
+def get_batch_size(cf: Config, world_size: int) -> int:
+    """Global batch size, i.e. samples per optimizer step across all ranks."""
+    return world_size * cf.batch_size_per_gpu
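A hedged usage sketch of the factory and the batch-size helper; the SimpleNamespace below stands in for the real Config object and assumes only the attributes these helpers actually read.

    from types import SimpleNamespace

    from weathergen.train.trainer_base import get_target_and_aux_calculator
    from weathergen.utils.utils import get_batch_size

    cf = SimpleNamespace(
        batch_size_per_gpu=2,
        training_mode_config={"target_and_aux_calc": None},  # -> IdentityTargetAndAux
    )
    calc = get_target_and_aux_calculator(
        cf, model=None, rng=None, batch_size=get_batch_size(cf, world_size=4)
    )
    # with {"target_and_aux_calc": "EMATeacher"} an ema_model kwarg is required:
    #   get_target_and_aux_calculator(cf, model, rng, batch_size, ema_model=ema)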