From 28d9b2264c7c5f51b6c642b32ded43e39f0f16e3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 4 Nov 2025 09:38:28 +0100 Subject: [PATCH 01/10] adding loss calculator base class --- src/weathergen/train/loss_calculator.py | 230 +++++++++++++++++++ src/weathergen/train/loss_calculator_base.py | 98 ++++++++ 2 files changed, 328 insertions(+) create mode 100644 src/weathergen/train/loss_calculator_base.py diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index f457d6454..1d54d49c9 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -17,6 +17,7 @@ import weathergen.train.loss as losses from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) @@ -318,3 +319,232 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + + +class LossCalculatorPhysical(LossCalculatorBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. 
when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + def compute_loss( + self, + preds: list[list[Tensor]], + streams_data: list[list[any]], + ) -> LossValues: + """ + Computes the total loss for a given batch of predictions and corresponding + stream data. + + The computed loss is: + + Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) + + This method orchestrates the calculation of the overall loss by iterating through + different data streams, forecast steps, channels, and configured loss functions. + It applies weighting, handles NaN values through masking, and accumulates + detailed loss metrics for logging. + + Args: + preds: A nested list of prediction tensors. The outer list represents forecast steps, + the inner list represents streams. Each tensor contains predictions for that + step and stream. + streams_data: A nested list representing the input batch data. The outer list is for + batch items, the inner list for streams. Each element provides an object + (e.g., dataclass instance) containing target data and metadata. + + Returns: + A ModelLoss dataclass instance containing: + - loss: The loss for back-propagation. + - losses_all: A dictionary mapping stream names to a tensor of per-channel and + per-loss-function losses, normalized by non-empty targets/forecast steps. + - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations + of predictions for channels with statistical loss functions, normalized. 
+ """ + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
+ pred = pred.reshape([pred.shape[0], *target.shape]) + assert pred.shape[1] > 0 + + # get weigths for current streams + stream_loss_weight, weights_channels = self._get_weights(stream_info) + + # get weights for locations + weights_locations = self._get_location_weights( + stream_info, stream_data, self.cf.forecast_offset, fstep + ) + + # get masks for sub-time steps + substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) + + # accumulate loss from different loss functions + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): + # loss for current loss function + loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( + loss_fct, + stream_info, + target, + pred, + substep_masks, + weights_channels, + weights_locations, + ) + losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + ( + loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight + ) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) + ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 + + # normalize by forecast step + losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 + + # replace channels without information by nan to exclude from further computations + losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan + stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + # normalize by all targets and forecast steps that were non-empty + # (with each having an expected loss of 1 for an uninitalized neural net) + loss = loss / ctr_streams + + # Return all computed loss components encapsulated in a ModelLoss dataclass + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py new file mode 100644 index 000000000..43091978c --- /dev/null +++ b/src/weathergen/train/loss_calculator_base.py @@ -0,0 +1,98 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import dataclasses + +import torch +from torch import Tensor + +from weathergen.common.config import Config +from weathergen.utils.train_logger import Stage + + +@dataclasses.dataclass +class LossValues: + """ + A dataclass to encapsulate the various loss components computed by the LossCalculator. + + This provides a structured way to return the primary loss used for optimization, + along with detailed per-stream/per-channel/per-loss-function losses for logging, + and standard deviations for ensemble scenarios. + """ + + # The primary scalar loss value for optimization. 
+ loss: Tensor + # Dictionaries containing detailed loss values for each stream, channel, and loss function, as + # well as standard deviations when operating with ensembles (e.g., when training with CRPS). + losses_all: dict[str, Tensor] + stddev_all: dict[str, Tensor] + + +class LossCalculatorBase: + def __init__(self): + """ + Initializes the LossCalculator. + + This sets up the configuration, the operational stage (training or validation), + the device for tensor operations, and initializes the list of loss functions + based on the provided configuration. + + Args: + cf: The OmegaConf DictConfig object containing model and training configurations. + It should specify 'loss_fcts' for training and 'loss_fcts_val' for validation. + stage: The current operational stage, either TRAIN or VAL. + This dictates which set of loss functions (training or validation) will be used. + device: The computation device, such as 'cpu' or 'cuda:0', where tensors will reside. + """ + self.cf: Config | None = None + self.stage: Stage + self.loss_fcts = [] + + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + + # def _get_weights(self, stream_info): + + # def _update_weights(self, stream_info): From f1e71321ffedc5c4a272a36b8d76231affd41204 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Wed, 5 Nov 2025 18:12:01 +0100 Subject: [PATCH 02/10] abstract loss calc structure --- config/default_config.yml | 5 +- src/weathergen/train/loss_calculator.py | 524 ++---------------- src/weathergen/train/loss_calculator_base.py | 5 + .../train/loss_calculator_classes.py | 285 ++++++++++ src/weathergen/train/trainer.py | 6 +- 5 files changed, 329 insertions(+), 496 deletions(-) create mode 100644 src/weathergen/train/loss_calculator_classes.py diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..26f9382c2 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -70,9 +70,12 @@ latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True loss_fcts: - - + - - "mse" - 1.0 + # - + # - "latent:mse" + # - 0.3 loss_fcts_val: - - "mse" diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 1d54d49c9..ffd018f22 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -1,3 +1,5 @@ +# ruff: noqa: T201 + # (C) Copyright 2025 WeatherGenerator contributors. 
# # This software is licensed under the terms of the Apache Licence Version 2.0 @@ -7,49 +9,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import dataclasses import logging -import numpy as np -import torch from omegaconf import DictConfig -from torch import Tensor -import weathergen.train.loss as losses -from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues -from weathergen.utils.train_logger import TRAIN, VAL, Stage +from weathergen.train.loss_calculator_base import LossValues +from weathergen.train.loss_calculator_classes import LossCalculatorLatent, LossCalculatorPhysical +from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) -@dataclasses.dataclass -class LossValues: - """ - A dataclass to encapsulate the various loss components computed by the LossCalculator. - - This provides a structured way to return the primary loss used for optimization, - along with detailed per-stream/per-channel/per-loss-function losses for logging, - and standard deviations for ensemble scenarios. - """ - - # The primary scalar loss value for optimization. - loss: Tensor - # Dictionaries containing detailed loss values for each stream, channel, and loss function, as - # well as standard deviations when operating with ensembles (e.g., when training with CRPS). - losses_all: dict[str, Tensor] - stddev_all: dict[str, Tensor] - - class LossCalculator: """ Manages and computes the overall loss for a WeatherGenerator model during training and validation stages. - - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. """ def __init__( @@ -76,475 +50,39 @@ def __init__( self.stage = stage self.device = device - # Dynamically load loss functions based on configuration and stage loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _get_weights(self, stream_info): - """ - Get weights for current stream - """ - - device = self.device - - # Determine stream and channel loss weights based on the current stage - if self.stage == TRAIN: - # set loss_weights to 1. 
when not specified - stream_info_loss_weight = stream_info.get("loss_weight", 1.0) - weights_channels = ( - torch.tensor(stream_info["target_channel_weights"]).to( - device=device, non_blocking=True - ) - if "target_channel_weights" in stream_info - else None - ) - elif self.stage == VAL: - # in validation mode, always unweighted loss - stream_info_loss_weight = 1.0 - weights_channels = None - - return stream_info_loss_weight, weights_channels - - def _get_fstep_weights(self, forecast_steps): - timestep_weight_config = self.cf.get("timestep_weight") - if timestep_weight_config is None: - return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) - return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) - - def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): - location_weight_type = stream_info.get("location_weight", None) - if location_weight_type is None: - return None - weights_locations_fct = getattr(losses, location_weight_type) - weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) - weights_locations = weights_locations.to(device=self.device, non_blocking=True) - - return weights_locations - - def _get_substep_masks(self, stream_info, fstep, stream_data): - """ - Find substeps and create corresponding masks (reused across loss functions) - """ - - tok_spacetime = stream_info.get("tokenize_spacetime", None) - target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] - target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] - substep_masks = [] - for t in target_times_unique: - # find substep - mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) - substep_masks.append(mask_t) - - return substep_masks - - @staticmethod - def _loss_per_loss_function( - loss_fct, - stream_info, - target: torch.Tensor, - pred: torch.Tensor, - substep_masks: list[torch.Tensor], - weights_channels: torch.Tensor, - weights_locations: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) - losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) - - ctr_substeps = 0 - for mask_t in substep_masks: - assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True - - loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations - ) - - # accumulate loss - loss_lfct = loss_lfct + loss - losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs - ctr_substeps += 1 if loss > 0.0 else 0 - - # normalize over forecast steps in window - losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 - - # TODO: substep weight - loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) - - return loss_lfct, losses_chs - - def compute_loss( - self, - preds: list[list[Tensor]], - streams_data: list[list[any]], - ) -> LossValues: - """ - Computes the total loss for a given batch of predictions and corresponding - stream data. - - The computed loss is: - - Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) - - This method orchestrates the calculation of the overall loss by iterating through - different data streams, forecast steps, channels, and configured loss functions. - It applies weighting, handles NaN values through masking, and accumulates - detailed loss metrics for logging. 
- - Args: - preds: A nested list of prediction tensors. The outer list represents forecast steps, - the inner list represents streams. Each tensor contains predictions for that - step and stream. - streams_data: A nested list representing the input batch data. The outer list is for - batch items, the inner list for streams. Each element provides an object - (e.g., dataclass instance) containing target data and metadata. - - Returns: - A ModelLoss dataclass instance containing: - - loss: The loss for back-propagation. - - losses_all: A dictionary mapping stream names to a tensor of per-channel and - per-loss-function losses, normalized by non-empty targets/forecast steps. - - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations - of predictions for channels with statistical loss functions, normalized. - """ - - # gradient loss - loss = torch.tensor(0.0, device=self.device, requires_grad=True) - # counter for non-empty targets - ctr_streams = 0 - - # initialize dictionaries for detailed loss tracking and standard deviation statistics - # create tensor for each stream - losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), - device=self.device, - ) - for st in self.cf.streams - } - stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams - } - - # TODO: iterate over batch dimension - i_batch = 0 - for i_stream_info, stream_info in enumerate(self.cf.streams): - # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] - - stream_data = streams_data[i_batch][i_stream_info] - - fstep_loss_weights = self._get_fstep_weights(len(targets)) - - loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps = 0 - - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() - if stream_is_spoof: - spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) - else: - spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) - - for fstep, (target, fstep_weight) in enumerate( - zip(targets, fstep_loss_weights, strict=False) - ): - # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # reshape prediction tensor to match target's dimensions: extract data/coords and - # remove token dimension if it exists. - # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
- pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - # get weigths for current streams - stream_loss_weight, weights_channels = self._get_weights(stream_info) - - # get weights for locations - weights_locations = self._get_location_weights( - stream_info, stream_data, self.cf.forecast_offset, fstep - ) - - # get masks for sub-time steps - substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) - - # accumulate loss from different loss functions - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): - # loss for current loss function - loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( - loss_fct, - stream_info, - target, - pred, - substep_masks, - weights_channels, - weights_locations, - ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + ( - loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight - ) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) - ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) - ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 - - # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - - # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan - # normalize by all targets and forecast steps that were non-empty - # (with each having an expected loss of 1 for an uninitalized neural net) - loss = loss / ctr_streams - - # Return all computed loss components encapsulated in a ModelLoss dataclass - return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) - - -class LossCalculatorPhysical(LossCalculatorBase): - """ - Manages and computes the overall loss for a WeatherGenerator model during - training and validation stages. - - This class handles the initialization and application of various loss functions, - applies channel-specific weights, constructs masks for missing data, and - aggregates losses across different data streams, channels, and forecast steps. - It provides both the main loss for backpropagation and detailed loss metrics for logging. 
- """ - - def __init__( - self, - cf: DictConfig, - stage: Stage, - device: str, - ): - LossCalculatorBase.__init__(self) - self.cf = cf - self.stage = stage - self.device = device - - # Dynamically load loss functions based on configuration and stage - loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts + loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"] + loss_fcts_latent = [ + [name.split(":")[1], w] for name, w in loss_fcts if name.split(":")[0] == "latent" ] - def _get_weights(self, stream_info): - """ - Get weights for current stream - """ - - device = self.device - - # Determine stream and channel loss weights based on the current stage - if self.stage == TRAIN: - # set loss_weights to 1. when not specified - stream_info_loss_weight = stream_info.get("loss_weight", 1.0) - weights_channels = ( - torch.tensor(stream_info["target_channel_weights"]).to( - device=device, non_blocking=True - ) - if "target_channel_weights" in stream_info - else None - ) - elif self.stage == VAL: - # in validation mode, always unweighted loss - stream_info_loss_weight = 1.0 - weights_channels = None - - return stream_info_loss_weight, weights_channels - - def _get_fstep_weights(self, forecast_steps): - timestep_weight_config = self.cf.get("timestep_weight") - if timestep_weight_config is None: - return [1.0 for _ in range(forecast_steps)] - weights_timestep_fct = getattr(losses, timestep_weight_config[0]) - return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) - - def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): - location_weight_type = stream_info.get("location_weight", None) - if location_weight_type is None: - return None - weights_locations_fct = getattr(losses, location_weight_type) - weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) - weights_locations = weights_locations.to(device=self.device, non_blocking=True) - - return weights_locations - - def _get_substep_masks(self, stream_info, fstep, stream_data): - """ - Find substeps and create corresponding masks (reused across loss functions) - """ + calculator_configs = [] - tok_spacetime = stream_info.get("tokenize_spacetime", None) - target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] - target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] - substep_masks = [] - for t in target_times_unique: - # find substep - mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) - substep_masks.append(mask_t) + if loss_fcts_physical: + calculator_configs.append((LossCalculatorPhysical, loss_fcts_physical, "physical")) + if loss_fcts_latent: + calculator_configs.append((LossCalculatorLatent, loss_fcts_latent, "latent")) - return substep_masks + self.loss_calculators = [ + (Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device), type) + for (Cls, losses, type) in calculator_configs + ] def compute_loss( self, - preds: list[list[Tensor]], - streams_data: list[list[any]], - ) -> LossValues: - """ - Computes the total loss for a given batch of predictions and corresponding - stream data. 
- - The computed loss is: - - Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weigths) ))) - - This method orchestrates the calculation of the overall loss by iterating through - different data streams, forecast steps, channels, and configured loss functions. - It applies weighting, handles NaN values through masking, and accumulates - detailed loss metrics for logging. - - Args: - preds: A nested list of prediction tensors. The outer list represents forecast steps, - the inner list represents streams. Each tensor contains predictions for that - step and stream. - streams_data: A nested list representing the input batch data. The outer list is for - batch items, the inner list for streams. Each element provides an object - (e.g., dataclass instance) containing target data and metadata. - - Returns: - A ModelLoss dataclass instance containing: - - loss: The loss for back-propagation. - - losses_all: A dictionary mapping stream names to a tensor of per-channel and - per-loss-function losses, normalized by non-empty targets/forecast steps. - - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations - of predictions for channels with statistical loss functions, normalized. - """ - - # gradient loss - loss = torch.tensor(0.0, device=self.device, requires_grad=True) - # counter for non-empty targets - ctr_streams = 0 - - # initialize dictionaries for detailed loss tracking and standard deviation statistics - # create tensor for each stream - losses_all: dict[str, Tensor] = { - st.name: torch.zeros( - (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), - device=self.device, - ) - for st in self.cf.streams - } - stddev_all: dict[str, Tensor] = { - st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams - } - - # TODO: iterate over batch dimension - i_batch = 0 - for i_stream_info, stream_info in enumerate(self.cf.streams): - # extract target tokens for current stream from the specified forecast offset onwards - targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] - - stream_data = streams_data[i_batch][i_stream_info] - - fstep_loss_weights = self._get_fstep_weights(len(targets)) - - loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps = 0 - - stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() - if stream_is_spoof: - spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) - else: - spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) - - for fstep, (target, fstep_weight) in enumerate( - zip(targets, fstep_loss_weights, strict=False) - ): - # skip if either target or prediction has no data points - pred = preds[fstep][i_stream_info] - if not (target.shape[0] > 0 and pred.shape[0] > 0): - continue - - # reshape prediction tensor to match target's dimensions: extract data/coords and - # remove token dimension if it exists. - # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
- pred = pred.reshape([pred.shape[0], *target.shape]) - assert pred.shape[1] > 0 - - # get weigths for current streams - stream_loss_weight, weights_channels = self._get_weights(stream_info) - - # get weights for locations - weights_locations = self._get_location_weights( - stream_info, stream_data, self.cf.forecast_offset, fstep - ) - - # get masks for sub-time steps - substep_masks = self._get_substep_masks(stream_info, fstep, stream_data) - - # accumulate loss from different loss functions - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts): - # loss for current loss function - loss_lfct, loss_lfct_chs = LossCalculator._loss_per_loss_function( - loss_fct, - stream_info, - target, - pred, - substep_masks, - weights_channels, - weights_locations, - ) - losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + ( - loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight - ) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) - ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) - ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0 - - # normalize by forecast step - losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 - - # replace channels without information by nan to exclude from further computations - losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan - stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan - - # normalize by all targets and forecast steps that were non-empty - # (with each having an expected loss of 1 for an uninitalized neural net) - loss = loss / ctr_streams - - # Return all computed loss components encapsulated in a ModelLoss dataclass + preds: dict, + targets: dict, + ): + loss_values = {} + loss = 0 + for calculator, type in self.loss_calculators: + loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) + loss += loss_values[type].loss + + losses_all = {} + stddev_all = {} + for _, v in loss_values.items(): + losses_all.update(v.losses_all) + stddev_all.update(v.stddev_all) return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py index 43091978c..720116baa 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_calculator_base.py @@ -17,6 +17,11 @@ from weathergen.common.config import Config from weathergen.utils.train_logger import Stage +# @dataclasses.dataclass +# class InputOutputStructure: + +# targets.latent + @dataclasses.dataclass class LossValues: diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py new file mode 100644 index 000000000..62ed45a7d --- /dev/null +++ b/src/weathergen/train/loss_calculator_classes.py @@ -0,0 +1,285 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss as losses +from weathergen.train.loss import stat_loss_fcts +from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues +from weathergen.utils.train_logger import TRAIN, VAL, Stage + +_logger = logging.getLogger(__name__) + + +class LossCalculatorPhysical(LossCalculatorBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. + + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. 
when not specified
+            stream_info_loss_weight = stream_info.get("loss_weight", 1.0)
+            weights_channels = (
+                torch.tensor(stream_info["target_channel_weights"]).to(
+                    device=device, non_blocking=True
+                )
+                if "target_channel_weights" in stream_info
+                else None
+            )
+        elif self.stage == VAL:
+            # in validation mode, always unweighted loss
+            stream_info_loss_weight = 1.0
+            weights_channels = None
+
+        return stream_info_loss_weight, weights_channels
+
+    def _get_fstep_weights(self, forecast_steps):
+        timestep_weight_config = self.cf.get("timestep_weight")
+        if timestep_weight_config is None:
+            return [1.0 for _ in range(forecast_steps)]
+        weights_timestep_fct = getattr(losses, timestep_weight_config[0])
+        return weights_timestep_fct(forecast_steps, timestep_weight_config[1])
+
+    def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep):
+        location_weight_type = stream_info.get("location_weight", None)
+        if location_weight_type is None:
+            return None
+        weights_locations_fct = getattr(losses, location_weight_type)
+        weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep)
+        weights_locations = weights_locations.to(device=self.device, non_blocking=True)
+
+        return weights_locations
+
+    def _get_substep_masks(self, stream_info, fstep, stream_data):
+        """
+        Find substeps and create corresponding masks (reused across loss functions)
+        """
+
+        tok_spacetime = stream_info.get("tokenize_spacetime", None)
+        target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep]
+        target_times_unique = np.unique(target_times) if tok_spacetime else [target_times]
+        substep_masks = []
+        for t in target_times_unique:
+            # find substep
+            mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True)
+            substep_masks.append(mask_t)
+
+        return substep_masks
+
+    def compute_loss(
+        self,
+        preds: list[list[Tensor]],
+        targets: list[list[any]],
+    ) -> LossValues:
+        """
+        Computes the total loss for a given batch of predictions and corresponding
+        stream data.
+
+        The computed loss is:
+
+        Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weights) )))
+
+        This method orchestrates the calculation of the overall loss by iterating through
+        different data streams, forecast steps, channels, and configured loss functions.
+        It applies weighting, handles NaN values through masking, and accumulates
+        detailed loss metrics for logging.
+
+        Args:
+            preds: A nested list of prediction tensors. The outer list represents forecast steps,
+                the inner list represents streams. Each tensor contains predictions for that
+                step and stream.
+            targets: A nested list representing the input batch data. The outer list is for
+                batch items, the inner list for streams. Each element provides an object
+                (e.g., dataclass instance) containing target data and metadata.
+
+        Returns:
+            A LossValues dataclass instance containing:
+            - loss: The loss for back-propagation.
+            - losses_all: A dictionary mapping stream names to a tensor of per-channel and
+                per-loss-function losses, normalized by non-empty targets/forecast steps.
+            - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations
+                of predictions for channels with statistical loss functions, normalized.
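+
+        Example (illustrative sketch only; names, shapes and values are hypothetical):
+
+            >>> calc = LossCalculatorPhysical(cf, loss_fcts=[["mse", 1.0]], stage=TRAIN, device="cpu")
+            >>> # preds[fstep][i_stream] has shape [ens, num_points, num_channels]
+            >>> loss_values = calc.compute_loss(preds, targets)
+            >>> loss_values.loss.backward()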
+ """ + + streams_data = targets + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
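+                # e.g. (hypothetical sizes): with ensemble size 4 and a target of shape
+                # [32, 3], a prediction holding 4 * 32 * 3 values is viewed as [4, 32, 3]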
+                pred = pred.reshape([pred.shape[0], *target.shape])
+                assert pred.shape[1] > 0
+
+                # get weights for the current stream
+                stream_loss_weight, weights_channels = self._get_weights(stream_info)
+
+                # get weights for locations
+                weights_locations = self._get_location_weights(
+                    stream_info, stream_data, self.cf.forecast_offset, fstep
+                )
+
+                # get masks for sub-time steps
+                substep_masks = self._get_substep_masks(stream_info, fstep, stream_data)
+
+                # accumulate loss from different loss functions
+                loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True)
+                ctr_loss_fcts = 0
+                for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts):
+                    # loss for current loss function
+                    loss_lfct, loss_lfct_chs = self._loss_per_loss_function(
+                        loss_fct,
+                        target,
+                        pred,
+                        substep_masks,
+                        weights_channels,
+                        weights_locations,
+                    )
+                    losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs
+
+                    # Add the weighted and normalized loss from this loss function to the total
+                    # batch loss
+                    loss_fstep = loss_fstep + (
+                        loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight
+                    )
+                    ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0
+
+                loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0)
+                ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0
+
+            loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0))
+            ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0
+
+            # normalize by forecast step
+            losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+            stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+
+            # replace channels without information by nan to exclude from further computations
+            losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan
+            stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan
+
+        # normalize by all targets and forecast steps that were non-empty
+        # (with each having an expected loss of 1 for an uninitialized neural net)
+        loss = loss / ctr_streams
+
+        # Return all computed loss components encapsulated in a LossValues dataclass
+        return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)
+
+
+class LossCalculatorLatent(LossCalculatorBase):
+    """
+    Manages and computes the overall loss for a WeatherGenerator model during
+    training and validation stages.
+
+    This class handles the initialization and application of various loss functions,
+    applies channel-specific weights, constructs masks for missing data, and
+    aggregates losses across different data streams, channels, and forecast steps.
+    It provides both the main loss for backpropagation and detailed loss metrics for logging.
+ """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossCalculatorBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..3c31daed6 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -588,12 +588,14 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, posteriors = self.model( + predictions, posteriors = self.model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + targets = {"physical": batch[0]} + preds = {"physical": predictions, "latent": posteriors} loss_values = self.loss_calculator.compute_loss( preds=preds, - streams_data=batch[0], + targets=targets, ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) From e822e12928605c11482ede836b409decca1d658b Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 6 Nov 2025 16:45:38 +0100 Subject: [PATCH 03/10] add abstract method to loss calculator base class --- src/weathergen/train/loss_calculator_base.py | 53 ++++--------------- .../train/loss_calculator_classes.py | 37 +++++++++++++ 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_calculator_base.py index 720116baa..13ad0394d 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_calculator_base.py @@ -10,8 +10,8 @@ # nor does it submit to any jurisdiction. import dataclasses +from abc import abstractmethod -import torch from torch import Tensor from weathergen.common.config import Config @@ -44,11 +44,7 @@ class LossValues: class LossCalculatorBase: def __init__(self): """ - Initializes the LossCalculator. - - This sets up the configuration, the operational stage (training or validation), - the device for tensor operations, and initializes the list of loss functions - based on the provided configuration. + Base class for loss calculators. Args: cf: The OmegaConf DictConfig object containing model and training configurations. @@ -61,43 +57,14 @@ def __init__(self): self.stage: Stage self.loss_fcts = [] - @staticmethod - def _loss_per_loss_function( - loss_fct, - target: torch.Tensor, - pred: torch.Tensor, - substep_masks: list[torch.Tensor], - weights_channels: torch.Tensor, - weights_locations: torch.Tensor, - ): + @abstractmethod + def compute_loss( + self, + preds: dict, + targets: dict, + ) -> LossValues: """ - Compute loss for given loss function + Computes loss given predictions and targets and returns values of LossValues dataclass. 
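+
+        The expected dict structure is an assumption based on the current trainer,
+        which passes e.g. preds = {"physical": predictions, "latent": posteriors}
+        and targets = {"physical": batch[0]}; subclasses pick the keys they need.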
""" - loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) - losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) - - ctr_substeps = 0 - for mask_t in substep_masks: - assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True - - loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations - ) - - # accumulate loss - loss_lfct = loss_lfct + loss - losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs - ctr_substeps += 1 if loss > 0.0 else 0 - - # normalize over forecast steps in window - losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 - - # TODO: substep weight - loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) - - return loss_lfct, losses_chs - - # def _get_weights(self, stream_info): - - # def _update_weights(self, stream_info): + raise NotImplementedError() diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py index 62ed45a7d..6edab4bf4 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ b/src/weathergen/train/loss_calculator_classes.py @@ -112,6 +112,43 @@ def _get_substep_masks(self, stream_info, fstep, stream_data): return substep_masks + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = loss_lfct + loss + losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs + ctr_substeps += 1 if loss > 0.0 else 0 + + # normalize over forecast steps in window + losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0 + + # TODO: substep weight + loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) + + return loss_lfct, losses_chs + def compute_loss( self, preds: list[list[Tensor]], From d24ef486279fa784a494d3e94b407c0ba2604a09 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Thu, 6 Nov 2025 17:21:11 +0100 Subject: [PATCH 04/10] add latent loss class --- src/weathergen/train/loss_calculator.py | 2 + .../train/loss_calculator_classes.py | 69 +++++++++++++++++-- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index ffd018f22..b2bdcc88e 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -80,6 +80,8 @@ def compute_loss( loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) loss += loss_values[type].loss + # Bring all loss values together + # TODO: keys should tell what type of loss was used, e.g loss_mse.latent.loss_2t losses_all = {} stddev_all = {} for _, v in loss_values.items(): diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_calculator_classes.py index 6edab4bf4..13b9ae76a 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ 
b/src/weathergen/train/loss_calculator_classes.py
@@ -294,13 +294,7 @@ def compute_loss(
 
 class LossCalculatorLatent(LossCalculatorBase):
     """
-    Manages and computes the overall loss for a WeatherGenerator model during
-    training and validation stages.
-
-    This class handles the initialization and application of various loss functions,
-    applies channel-specific weights, constructs masks for missing data, and
-    aggregates losses across different data streams, channels, and forecast steps.
-    It provides both the main loss for backpropagation and detailed loss metrics for logging.
+    Calculates loss in latent space.
     """
 
     def __init__(
@@ -320,3 +314,64 @@ def __init__(
         [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w]
         for name, w in loss_fcts
     ]
+
+    def _loss_per_loss_function(
+        self,
+        loss_fct,
+        target: torch.Tensor,
+        pred: torch.Tensor,
+    ):
+        """
+        Compute loss for given loss function
+        """
+
+        loss_val = loss_fct(target=target, ens=None, mu=pred)
+
+        return loss_val
+
+    def compute_loss(
+        self,
+        preds: list[list[Tensor]],
+        targets: list[list[any]],
+    ) -> LossValues:
+        losses_all: Tensor = torch.zeros(
+            len(self.loss_fcts),
+            device=self.device,
+        )
+
+        loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True)
+        ctr_fsteps_lat = 0
+        # TODO: KCT, do we need the below per fstep?
+        for fstep in range(
+            1, len(preds)
+        ):  # the first entry in preds is the source window itself, so skip it
+            loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True)
+            ctr_loss_fcts = 0
+            # if forecast_offset == 0, the time points correspond directly; otherwise the
+            # targets do not encode the source time step, so the index shifts by one
+            fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1
+            for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts):
+                loss_lfct = self._loss_per_loss_function(
+                    loss_fct,
+                    target=targets[fstep_targs],
+                    pred=preds[fstep],
+                )
+
+                losses_all[i_lfct] += loss_lfct  # TODO: break into fsteps
+
+                # Add the weighted and normalized loss from this loss function to the total
+                # batch loss
+                loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct)
+                ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0
+
+            loss_fsteps_lat = loss_fsteps_lat + (
+                loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0
+            )
+            ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0
+
+        loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0)
+
+        losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0
+        losses_all[losses_all == 0.0] = torch.nan
+
+        # wrap the results in dicts to satisfy the LossValues contract (losses_all and
+        # stddev_all are dict[str, Tensor]) so LossCalculator can merge calculators
+        return LossValues(loss=loss, losses_all={"latent": losses_all}, stddev_all={})
From c259c20421ce559d0bc1a4530ada37a196198b2d Mon Sep 17 00:00:00 2001
From: Jubeku
Date: Fri, 7 Nov 2025 16:15:33 +0100
Subject: [PATCH 05/10] update loss calc config and rename files

---
 config/default_config.yml                     | 11 ++++--
 .../weathergen/evaluate/export_inference.py   |  6 ++-
 src/weathergen/train/loss_calculator.py       | 32 +++++++--------
 ...s_calculator_classes.py => loss_module.py} | 39 ++++++++++++++-----
 ...calculator_base.py => loss_module_base.py} |  7 +---
 5 files changed, 56 insertions(+), 39 deletions(-)
 rename src/weathergen/train/{loss_calculator_classes.py => loss_module.py} (94%)
 rename src/weathergen/train/{loss_calculator_base.py => loss_module_base.py} (95%)

diff --git a/config/default_config.yml b/config/default_config.yml
index 26f9382c2..e2de3ff21 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -73,9 +73,6 @@ loss_fcts:
   - - "mse"
     - 1.0
-  # -
-  # - "latent:mse"
-  # - 0.3
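+# note: which loss modules run, and with which loss functions, is now configured
+# via training_mode_config / validation_mode_config further below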
 
 loss_fcts_val:
   - - "mse"
     - 1.0
@@ -97,6 +94,14 @@ ema_halflife_in_thousands: 1e-3
 # training mode: "forecast" or "masking" (masked token modeling)
 # for "masking" to train with auto-encoder mode, forecast_offset should be 0
 training_mode: "masking"
+training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],}
+                      }
+# training_mode_config: {"losses": {LossPhysical: [['mse', 0.7]],
+#                                   LossLatent: [['mse', 0.3]],
+#                                   LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],}
+# }
+validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],}
+                        }
 # masking rate when training mode is "masking"; ignored in forecast mode
 masking_rate: 0.6
 # sample the masking rate (with normal distribution centered at masking_rate)
diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py
index 2c0cb4243..4e0bd7d6e 100755
--- a/packages/evaluate/src/weathergen/evaluate/export_inference.py
+++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py
@@ -61,6 +61,7 @@ def detect_grid_type(input_data_array: xr.DataArray) -> str:
     # Otherwise it's Gaussian (irregular spacing or reduced grid)
     return "gaussian"
 
+
 def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]:
     """
     Find all the pressure levels for each variable using regex and returns a dictionary
@@ -90,6 +91,7 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]:
     pl = list(set(pl))
     return var_dict, pl
 
+
 def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset:
     """
     Reshape dataset while preserving grid structure (regular or Gaussian).
@@ -176,8 +178,6 @@ def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) ->
     return ds
 
-
-
 def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset:
     """
     Add CF conventions to the dataset attributes.
@@ -201,6 +201,7 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset:
     ds.attrs["Conventions"] = "CF-1.12"
     return ds
 
+
 def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset:
     """
     Modified CF parser that handles both regular and Gaussian grids.
@@ -323,6 +324,7 @@ def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: return dataset + def output_filename( prefix: str, run_id: str, diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index b2bdcc88e..dfd582ec8 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -13,8 +13,8 @@ from omegaconf import DictConfig -from weathergen.train.loss_calculator_base import LossValues -from weathergen.train.loss_calculator_classes import LossCalculatorLatent, LossCalculatorPhysical +import weathergen.train.loss_module as LossModule +from weathergen.train.loss_module_base import LossValues from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) @@ -50,23 +50,17 @@ def __init__( self.stage = stage self.device = device - loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val + calculator_configs = ( + cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses + ) - loss_fcts_physical = [[name, w] for name, w in loss_fcts if name.split(":")[0] != "latent"] - loss_fcts_latent = [ - [name.split(":")[1], w] for name, w in loss_fcts if name.split(":")[0] == "latent" + calculator_configs = [ + (getattr(LossModule, Cls), losses) for (Cls, losses) in calculator_configs.items() ] - calculator_configs = [] - - if loss_fcts_physical: - calculator_configs.append((LossCalculatorPhysical, loss_fcts_physical, "physical")) - if loss_fcts_latent: - calculator_configs.append((LossCalculatorLatent, loss_fcts_latent, "latent")) - self.loss_calculators = [ - (Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device), type) - for (Cls, losses, type) in calculator_configs + Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device) + for (Cls, losses) in calculator_configs ] def compute_loss( @@ -76,12 +70,12 @@ def compute_loss( ): loss_values = {} loss = 0 - for calculator, type in self.loss_calculators: - loss_values[type] = calculator.compute_loss(preds=preds[type], targets=targets[type]) - loss += loss_values[type].loss + for calculator in self.loss_calculators: + loss_values[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) + loss += loss_values[calculator.name].loss # Bring all loss values together - # TODO: keys should tell what type of loss was used, e.g loss_mse.latent.loss_2t + # TODO: make sure keys are explicit, e.g loss_mse.latent.loss_2t losses_all = {} stddev_all = {} for _, v in loss_values.items(): diff --git a/src/weathergen/train/loss_calculator_classes.py b/src/weathergen/train/loss_module.py similarity index 94% rename from src/weathergen/train/loss_calculator_classes.py rename to src/weathergen/train/loss_module.py index 13b9ae76a..2b345c3fe 100644 --- a/src/weathergen/train/loss_calculator_classes.py +++ b/src/weathergen/train/loss_module.py @@ -18,13 +18,13 @@ import weathergen.train.loss as losses from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_calculator_base import LossCalculatorBase, LossValues +from weathergen.train.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) -class LossCalculatorPhysical(LossCalculatorBase): +class LossPhysical(LossModuleBase): """ Manages and computes the overall loss for a WeatherGenerator model during training and validation stages. 
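# Illustration (self-contained sketch; the Demo* names are hypothetical and not
# part of weathergen): the getattr-based dispatch used by LossCalculator above,
# which resolves loss-module class names from the config at runtime:
import types

class DemoLoss:
    def __init__(self, loss_fcts):
        self.loss_fcts = loss_fcts

# stands in for the loss-module namespace that getattr() is called on
modules = types.SimpleNamespace(DemoLoss=DemoLoss)
config = {"DemoLoss": [["mse", 1.0]]}

calculators = [getattr(modules, name)(loss_fcts=fcts) for name, fcts in config.items()]
assert isinstance(calculators[0], DemoLoss)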
@@ -42,13 +42,13 @@ def __init__( stage: Stage, device: str, ): - LossCalculatorBase.__init__(self) + LossModuleBase.__init__(self) self.cf = cf self.stage = stage self.device = device + self.name = "LossPhysical" # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [ [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] for name, w in loss_fcts @@ -151,8 +151,8 @@ def _loss_per_loss_function( def compute_loss( self, - preds: list[list[Tensor]], - targets: list[list[any]], + preds: dict, + targets: dict, ) -> LossValues: """ Computes the total loss for a given batch of predictions and corresponding @@ -184,7 +184,8 @@ def compute_loss( of predictions for channels with statistical loss functions, normalized. """ - streams_data = targets + preds = preds["physical"] + streams_data = targets["physical"] # gradient loss loss = torch.tensor(0.0, device=self.device, requires_grad=True) @@ -292,7 +293,7 @@ def compute_loss( return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) -class LossCalculatorLatent(LossCalculatorBase): +class LossLatent(LossModuleBase): """ Calculates loss in latent space. """ @@ -304,10 +305,11 @@ def __init__( stage: Stage, device: str, ): - LossCalculatorBase.__init__(self) + LossModuleBase.__init__(self) self.cf = cf self.stage = stage self.device = device + self.name = "LossLatent" # Dynamically load loss functions based on configuration and stage self.loss_fcts = [ @@ -375,3 +377,22 @@ def compute_loss( losses_all[losses_all == 0.0] = torch.nan return LossValues(loss=loss, losses_all=losses_all) + + +class LossStudentTeacher(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + self.name = "LossStudentTeacher" + raise NotImplementedError() + + def compute_loss(self, preds, targets): + return super().compute_loss(preds, targets) diff --git a/src/weathergen/train/loss_calculator_base.py b/src/weathergen/train/loss_module_base.py similarity index 95% rename from src/weathergen/train/loss_calculator_base.py rename to src/weathergen/train/loss_module_base.py index 13ad0394d..de66bda28 100644 --- a/src/weathergen/train/loss_calculator_base.py +++ b/src/weathergen/train/loss_module_base.py @@ -17,11 +17,6 @@ from weathergen.common.config import Config from weathergen.utils.train_logger import Stage -# @dataclasses.dataclass -# class InputOutputStructure: - -# targets.latent - @dataclasses.dataclass class LossValues: @@ -41,7 +36,7 @@ class LossValues: stddev_all: dict[str, Tensor] -class LossCalculatorBase: +class LossModuleBase: def __init__(self): """ Base class for loss calculators. 
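A minimal, self-contained sketch of the LossValues contract that each loss
module returns per patch 05 (illustrative only; ToyLoss and LossValuesSketch
are hypothetical names, not part of the patch series):

import dataclasses
import torch

@dataclasses.dataclass
class LossValuesSketch:
    loss: torch.Tensor                   # scalar used for backpropagation
    losses_all: dict[str, torch.Tensor]  # per-stream diagnostics for logging
    stddev_all: dict[str, torch.Tensor]  # per-stream prediction spread

class ToyLoss:
    name = "ToyLoss"

    def compute_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> LossValuesSketch:
        err = (preds - targets) ** 2
        return LossValuesSketch(
            loss=err.mean(),
            losses_all={"toy_stream": err.mean(dim=0)},
            stddev_all={"toy_stream": preds.std(dim=0)},
        )

lv = ToyLoss().compute_loss(torch.randn(4, 3), torch.randn(4, 3))
assert lv.loss.ndim == 0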
From a19ee1658f65d1e0074ddc137ac58e70f66e6622 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Tue, 11 Nov 2025 15:41:29 +0100 Subject: [PATCH 06/10] restructure loss modules --- src/weathergen/train/loss_calculator.py | 6 +- src/weathergen/train/loss_modules/__init__.py | 5 + .../train/{ => loss_modules}/loss.py | 0 .../{ => loss_modules}/loss_module_base.py | 0 .../train/loss_modules/loss_module_latent.py | 112 ++++++++++++++++++ .../loss_module_physical.py} | 111 +---------------- .../train/loss_modules/loss_module_ssl.py | 38 ++++++ 7 files changed, 161 insertions(+), 111 deletions(-) create mode 100644 src/weathergen/train/loss_modules/__init__.py rename src/weathergen/train/{ => loss_modules}/loss.py (100%) rename src/weathergen/train/{ => loss_modules}/loss_module_base.py (100%) create mode 100644 src/weathergen/train/loss_modules/loss_module_latent.py rename src/weathergen/train/{loss_module.py => loss_modules/loss_module_physical.py} (77%) create mode 100644 src/weathergen/train/loss_modules/loss_module_ssl.py diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index dfd582ec8..2eda80fce 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -13,8 +13,8 @@ from omegaconf import DictConfig -import weathergen.train.loss_module as LossModule -from weathergen.train.loss_module_base import LossValues +import weathergen.train.loss_modules as LossModules +from weathergen.train.loss_modules.loss_module_base import LossValues from weathergen.utils.train_logger import TRAIN, Stage _logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def __init__( ) calculator_configs = [ - (getattr(LossModule, Cls), losses) for (Cls, losses) in calculator_configs.items() + (getattr(LossModules, Cls), losses) for (Cls, losses) in calculator_configs.items() ] self.loss_calculators = [ diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py new file mode 100644 index 000000000..7f5fc906d --- /dev/null +++ b/src/weathergen/train/loss_modules/__init__.py @@ -0,0 +1,5 @@ +from .loss_module_latent import LossLatent +from .loss_module_physical import LossPhysical +from .loss_module_ssl import LossStudentTeacher + +__all__ = [LossLatent, LossPhysical, LossStudentTeacher] diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss_modules/loss.py similarity index 100% rename from src/weathergen/train/loss.py rename to src/weathergen/train/loss_modules/loss.py diff --git a/src/weathergen/train/loss_module_base.py b/src/weathergen/train/loss_modules/loss_module_base.py similarity index 100% rename from src/weathergen/train/loss_module_base.py rename to src/weathergen/train/loss_modules/loss_module_base.py diff --git a/src/weathergen/train/loss_modules/loss_module_latent.py b/src/weathergen/train/loss_modules/loss_module_latent.py new file mode 100644 index 000000000..6daf472bb --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_latent.py @@ -0,0 +1,112 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging + +import torch +from omegaconf import DictConfig +from torch import Tensor + +import weathergen.train.loss_modules.loss as losses +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues +from weathergen.utils.train_logger import Stage + +_logger = logging.getLogger(__name__) + + +class LossLatent(LossModuleBase): + """ + Calculates loss in latent space. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossLatent" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _loss_per_loss_function( + self, + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, ens=None, mu=pred) + + return loss_val + + def compute_loss( + self, + preds: list[list[Tensor]], + targets: list[list[any]], + ) -> LossValues: + return super().compute_loss(preds, targets) + + ### FROM KEREM's PR + # losses_all: Tensor = torch.zeros( + # len(self.loss_fcts), + # device=self.device, + # ) + + # loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) + # ctr_fsteps_lat = 0 + # # TODO: KCT, do we need the below per fstep? + # for fstep in range( + # 1, len(preds) + # ): # the first entry in tokens_all is the source itself, so skip it + # loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + # ctr_loss_fcts = 0 + # # if forecast_offset==0, then the timepoints correspond. + # # Otherwise targets don't encode the source timestep, so we don't need to skip + # fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 + # for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): + # loss_lfct = self._loss_per_loss_function( + # loss_fct, + # stream_info=None, + # target=targets[fstep_targs], + # pred=preds[fstep], + # ) + + # losses_all[i_lfct] += loss_lfct # TODO: break into fsteps + + # # Add the weighted and normalized loss from this loss function to the total + # # batch loss + # loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + # ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + # loss_fsteps_lat = loss_fsteps_lat + ( + # loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 + # ) + # ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 + + # loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) + + # losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 + # losses_all[losses_all == 0.0] = torch.nan + + # return LossValues(loss=loss, losses_all=losses_all) diff --git a/src/weathergen/train/loss_module.py b/src/weathergen/train/loss_modules/loss_module_physical.py similarity index 77% rename from src/weathergen/train/loss_module.py rename to src/weathergen/train/loss_modules/loss_module_physical.py index 2b345c3fe..db4917550 100644 --- a/src/weathergen/train/loss_module.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -16,9 +16,9 @@ from omegaconf import DictConfig from torch import Tensor -import weathergen.train.loss as losses -from weathergen.train.loss import stat_loss_fcts -from weathergen.train.loss_module_base import LossModuleBase, LossValues +import weathergen.train.loss_modules.loss as losses +from weathergen.train.loss_modules.loss import 
stat_loss_fcts +from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues from weathergen.utils.train_logger import TRAIN, VAL, Stage _logger = logging.getLogger(__name__) @@ -291,108 +291,3 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) - - -class LossLatent(LossModuleBase): - """ - Calculates loss in latent space. - """ - - def __init__( - self, - cf: DictConfig, - loss_fcts: list, - stage: Stage, - device: str, - ): - LossModuleBase.__init__(self) - self.cf = cf - self.stage = stage - self.device = device - self.name = "LossLatent" - - # Dynamically load loss functions based on configuration and stage - self.loss_fcts = [ - [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] - for name, w in loss_fcts - ] - - def _loss_per_loss_function( - self, - loss_fct, - target: torch.Tensor, - pred: torch.Tensor, - ): - """ - Compute loss for given loss function - """ - - loss_val = loss_fct(target=target, ens=None, mu=pred) - - return loss_val - - def compute_loss( - self, - preds: list[list[Tensor]], - targets: list[list[any]], - ) -> LossValues: - losses_all: Tensor = torch.zeros( - len(self.loss_fcts), - device=self.device, - ) - - loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_fsteps_lat = 0 - # TODO: KCT, do we need the below per fstep? - for fstep in range( - 1, len(preds) - ): # the first entry in tokens_all is the source itself, so skip it - loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) - ctr_loss_fcts = 0 - # if forecast_offset==0, then the timepoints correspond. Otherwise targets don't encode the source timestep, so we don't need to skip - fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep - 1 - for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): - loss_lfct = self._loss_per_loss_function( - loss_fct, - stream_info=None, - target=targets[fstep_targs], - pred=preds[fstep], - ) - - losses_all[i_lfct] += loss_lfct # TODO: break into fsteps - - # Add the weighted and normalized loss from this loss function to the total - # batch loss - loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) - ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 - - loss_fsteps_lat = loss_fsteps_lat + ( - loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0 - ) - ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 - - loss = loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0) - - losses_all /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 - losses_all[losses_all == 0.0] = torch.nan - - return LossValues(loss=loss, losses_all=losses_all) - - -class LossStudentTeacher(LossModuleBase): - """ - Calculates loss in latent space. - """ - - def __init__( - self, - cf: DictConfig, - loss_fcts: list, - stage: Stage, - device: str, - ): - self.name = "LossStudentTeacher" - raise NotImplementedError() - - def compute_loss(self, preds, targets): - return super().compute_loss(preds, targets) diff --git a/src/weathergen/train/loss_modules/loss_module_ssl.py b/src/weathergen/train/loss_modules/loss_module_ssl.py new file mode 100644 index 000000000..240a2e27d --- /dev/null +++ b/src/weathergen/train/loss_modules/loss_module_ssl.py @@ -0,0 +1,38 @@ +# ruff: noqa: T201 + +# (C) Copyright 2025 WeatherGenerator contributors. 
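# Illustration (sketch only, not part of the patch): the kind of objective this
# new module is a placeholder for, namely a consistency loss between student
# latents and detached EMA-teacher latents, in the spirit of the iBOT/JEPA
# entries in the config template from patch 05:
import torch

def student_teacher_mse(student: torch.Tensor, teacher: torch.Tensor) -> torch.Tensor:
    # the teacher acts as a fixed target, so no gradient flows through it
    return torch.nn.functional.mse_loss(student, teacher.detach())

loss = student_teacher_mse(torch.randn(4, 8, requires_grad=True), torch.randn(4, 8))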
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+
+from omegaconf import DictConfig
+
+from weathergen.train.loss_modules.loss_module_base import LossModuleBase
+from weathergen.utils.train_logger import Stage
+
+_logger = logging.getLogger(__name__)
+
+
+class LossStudentTeacher(LossModuleBase):
+    """
+    Calculates a student-teacher self-supervised loss on latent representations
+    (e.g. iBOT- or JEPA-style objectives); not implemented yet.
+    """
+
+    def __init__(
+        self,
+        cf: DictConfig,
+        loss_fcts: list,
+        stage: Stage,
+        device: str,
+    ):
+        self.name = "LossStudentTeacher"
+        raise NotImplementedError()
+
+    def compute_loss(self, preds, targets):
+        return super().compute_loss(preds, targets)

From bf3e128b28c6042feb53a45e4f1a481cd72fa1a1 Mon Sep 17 00:00:00 2001
From: Jubeku
Date: Tue, 11 Nov 2025 16:09:20 +0100
Subject: [PATCH 07/10] add ModelOutput dataclass

---
 src/weathergen/model/model.py                 | 16 ++++++++-
 .../loss_modules/loss_module_physical.py      |  2 +-
 src/weathergen/train/trainer.py               | 33 +++++++------------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py
index 000f36735..ba5e2bb89 100644
--- a/src/weathergen/model/model.py
+++ b/src/weathergen/model/model.py
@@ -9,6 +9,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import dataclasses
 import logging
 import math
 import warnings
@@ -42,6 +43,16 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class ModelOutput:
+    """
+    A dataclass to encapsulate the model output and give a clear API.
+    """
+
+    physical: list[list[torch.Tensor]]  # predictions per [forecast step][stream]
+    latent: dict[str, list]  # latent-space outputs, e.g. {"posteriors": [...]}
+
+
 class ModelParams(torch.nn.Module):
     """Creation of query and embedding parameters of the model."""
 
@@ -653,7 +664,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca
             )
         ]
 
-        return preds_all, posteriors
+        latents = {}
+        latents["posteriors"] = posteriors
+
+        return ModelOutput(physical=preds_all, latent=latents)
 
     #########################################
     def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor:
diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py
index db4917550..54d30acc1 100644
--- a/src/weathergen/train/loss_modules/loss_module_physical.py
+++ b/src/weathergen/train/loss_modules/loss_module_physical.py
@@ -184,7 +184,7 @@ def compute_loss(
                 of predictions for channels with statistical loss functions, normalized.
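# Illustration (runnable sketch, assuming only the dataclass added above): how a
# caller unpacks the new ModelOutput instead of the old (preds, posteriors) tuple.
import dataclasses
import torch

@dataclasses.dataclass
class ModelOutputSketch:
    physical: list  # per-forecast-step, per-stream prediction tensors
    latent: dict    # e.g. {"posteriors": [...]}

output = ModelOutputSketch(physical=[[torch.zeros(2, 3)]], latent={"posteriors": []})
preds = output.physical                   # what LossPhysical reads via preds.physical
posteriors = output.latent["posteriors"]  # what the trainer's KL term consumes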
""" - preds = preds["physical"] + preds = preds.physical streams_data = targets["physical"] # gradient loss diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3c31daed6..6070e9263 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -588,17 +588,14 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - predictions, posteriors = self.model( - self.model_params, batch, cf.forecast_offset, forecast_steps - ) + output = self.model(self.model_params, batch, cf.forecast_offset, forecast_steps) targets = {"physical": batch[0]} - preds = {"physical": predictions, "latent": posteriors} loss_values = self.loss_calculator.compute_loss( - preds=preds, + preds=output, targets=targets, ) if cf.latent_noise_kl_weight > 0.0: - kl = torch.cat([posterior.kl() for posterior in posteriors]) + kl = torch.cat([posterior.kl() for posterior in output.latent]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() # backward pass @@ -681,17 +678,17 @@ def validate(self, epoch): if self.ema_model is None else self.ema_model.forward_eval ) - preds, _ = model_forward( + output = model_forward( self.model_params, batch, cf.forecast_offset, forecast_steps ) - - # compute loss and log output + targets = {"physical": batch[0]} + # compute loss + loss_values = self.loss_calculator_val.compute_loss( + preds=output, + targets=targets, + ) + # log output if bidx < cf.log_validation: - loss_values = self.loss_calculator_val.compute_loss( - preds=preds, - streams_data=batch[0], - ) - # TODO: Move _prepare_logging into write_validation by passing streams_data ( preds_all, @@ -700,7 +697,7 @@ def validate(self, epoch): targets_times_all, targets_lens, ) = self._prepare_logging( - preds=preds, + preds=output, forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, streams_data=batch[0], @@ -718,12 +715,6 @@ def validate(self, epoch): targets_lens, ) - else: - loss_values = self.loss_calculator_val.compute_loss( - preds=preds, - streams_data=batch[0], - ) - self.loss_unweighted_hist += [loss_values.losses_all] self.loss_model_hist += [loss_values.loss.item()] self.stdev_unweighted_hist += [loss_values.stddev_all] From cab9fbe9a6745fa6b9cc4b9d7288e69bd0e2bf77 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 14 Nov 2025 10:41:18 +0100 Subject: [PATCH 08/10] mv streams_data declaration under if condition --- src/weathergen/train/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 695150e3b..8cf2c067a 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -692,7 +692,7 @@ def validate(self, epoch): ) targets = {"physical": batch[0]} - streams_data: list[list[StreamData]] = batch[0] + # compute loss loss_values = self.loss_calculator_val.compute_loss( preds=output, @@ -701,6 +701,7 @@ def validate(self, epoch): # log output if bidx < cf.log_validation: # TODO: Move _prepare_logging into write_validation by passing streams_data + streams_data: list[list[StreamData]] = batch[0] ( preds_all, targets_all, From 20da55574f91eef716abcb55882131e29880e3d3 Mon Sep 17 00:00:00 2001 From: Jubeku Date: Fri, 14 Nov 2025 12:07:10 +0100 Subject: [PATCH 09/10] add weight to loss config, add toy loss class LossPhysicalTwo --- config/default_config.yml | 6 +- src/weathergen/train/loss_calculator.py | 40 +-- src/weathergen/train/loss_modules/__init__.py | 4 +- .../train/loss_modules/loss_module_base.py 
| 2 +- .../loss_modules/loss_module_physical.py | 269 ++++++++++++++++++ 5 files changed, 299 insertions(+), 22 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 3da835e5a..e99d9f423 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -94,13 +94,15 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 training_mode: "masking" -training_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} +training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 1.0]]}, + LossPhysicalTwo: {weight: 0.3, loss_fcts: [['mse', 1.0]]}, + } } # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], # LossLatent: [['mse', 0.3]], # LossStudentTeacher: [{'iBOT': {}, 'JEPA': {options}}],} # } -validation_mode_config: {"losses": {LossPhysical: [['mse', 1.0]],} +validation_mode_config: {"losses": {LossPhysical: {weight: 1.0, loss_fcts: [['mse', 1.0]]},} } # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 2eda80fce..fbfaebdb0 100644 --- a/src/weathergen/train/loss_calculator.py +++ b/src/weathergen/train/loss_calculator.py @@ -9,9 +9,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import dataclasses import logging from omegaconf import DictConfig +from torch import Tensor import weathergen.train.loss_modules as LossModules from weathergen.train.loss_modules.loss_module_base import LossValues @@ -20,6 +22,18 @@ _logger = logging.getLogger(__name__) +@dataclasses.dataclass +class LossTerms: + """ + A dataclass which combines the LossValues of all loss modules + """ + + # The primary scalar loss value for optimization. + loss: Tensor + # Dictionary containing the LossValues of each loss module. 
+ loss_terms: dict[str, LossValues] + + class LossCalculator: """ Manages and computes the overall loss for a WeatherGenerator model during @@ -53,14 +67,13 @@ def __init__( calculator_configs = ( cf.training_mode_config.losses if stage == TRAIN else cf.validation_mode_config.losses ) - calculator_configs = [ - (getattr(LossModules, Cls), losses) for (Cls, losses) in calculator_configs.items() + (getattr(LossModules, Cls), config) for (Cls, config) in calculator_configs.items() ] self.loss_calculators = [ - Cls(cf=cf, loss_fcts=losses, stage=stage, device=self.device) - for (Cls, losses) in calculator_configs + (config.weight, Cls(cf=cf, loss_fcts=config.loss_fcts, stage=stage, device=self.device)) + for (Cls, config) in calculator_configs ] def compute_loss( @@ -68,17 +81,10 @@ def compute_loss( preds: dict, targets: dict, ): - loss_values = {} + loss_terms = {} loss = 0 - for calculator in self.loss_calculators: - loss_values[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) - loss += loss_values[calculator.name].loss - - # Bring all loss values together - # TODO: make sure keys are explicit, e.g loss_mse.latent.loss_2t - losses_all = {} - stddev_all = {} - for _, v in loss_values.items(): - losses_all.update(v.losses_all) - stddev_all.update(v.stddev_all) - return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + for weight, calculator in self.loss_calculators: + loss_terms[calculator.name] = calculator.compute_loss(preds=preds, targets=targets) + loss += weight * loss_terms[calculator.name].loss + + return LossTerms(loss=loss, loss_terms=loss_terms) diff --git a/src/weathergen/train/loss_modules/__init__.py b/src/weathergen/train/loss_modules/__init__.py index 7f5fc906d..43be4dfe1 100644 --- a/src/weathergen/train/loss_modules/__init__.py +++ b/src/weathergen/train/loss_modules/__init__.py @@ -1,5 +1,5 @@ from .loss_module_latent import LossLatent -from .loss_module_physical import LossPhysical +from .loss_module_physical import LossPhysical, LossPhysicalTwo from .loss_module_ssl import LossStudentTeacher -__all__ = [LossLatent, LossPhysical, LossStudentTeacher] +__all__ = [LossLatent, LossPhysical, LossPhysicalTwo, LossStudentTeacher] diff --git a/src/weathergen/train/loss_modules/loss_module_base.py b/src/weathergen/train/loss_modules/loss_module_base.py index de66bda28..8e6ad3b5d 100644 --- a/src/weathergen/train/loss_modules/loss_module_base.py +++ b/src/weathergen/train/loss_modules/loss_module_base.py @@ -21,7 +21,7 @@ @dataclasses.dataclass class LossValues: """ - A dataclass to encapsulate the various loss components computed by the LossCalculator. + A dataclass to encapsulate the loss components returned by each loss module. This provides a structured way to return the primary loss used for optimization, along with detailed per-stream/per-channel/per-loss-function losses for logging, diff --git a/src/weathergen/train/loss_modules/loss_module_physical.py b/src/weathergen/train/loss_modules/loss_module_physical.py index 54d30acc1..1e900f25f 100644 --- a/src/weathergen/train/loss_modules/loss_module_physical.py +++ b/src/weathergen/train/loss_modules/loss_module_physical.py @@ -291,3 +291,272 @@ def compute_loss( # Return all computed loss components encapsulated in a ModelLoss dataclass return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + + +class LossPhysicalTwo(LossModuleBase): + """ + Manages and computes the overall loss for a WeatherGenerator model during + training and validation stages. 
+ + This class handles the initialization and application of various loss functions, + applies channel-specific weights, constructs masks for missing data, and + aggregates losses across different data streams, channels, and forecast steps. + It provides both the main loss for backpropagation and detailed loss metrics for logging. + """ + + def __init__( + self, + cf: DictConfig, + loss_fcts: list, + stage: Stage, + device: str, + ): + LossModuleBase.__init__(self) + self.cf = cf + self.stage = stage + self.device = device + self.name = "LossPhysicalTwo" + + # Dynamically load loss functions based on configuration and stage + self.loss_fcts = [ + [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] + for name, w in loss_fcts + ] + + def _get_weights(self, stream_info): + """ + Get weights for current stream + """ + + device = self.device + + # Determine stream and channel loss weights based on the current stage + if self.stage == TRAIN: + # set loss_weights to 1. when not specified + stream_info_loss_weight = stream_info.get("loss_weight", 1.0) + weights_channels = ( + torch.tensor(stream_info["target_channel_weights"]).to( + device=device, non_blocking=True + ) + if "target_channel_weights" in stream_info + else None + ) + elif self.stage == VAL: + # in validation mode, always unweighted loss + stream_info_loss_weight = 1.0 + weights_channels = None + + return stream_info_loss_weight, weights_channels + + def _get_fstep_weights(self, forecast_steps): + timestep_weight_config = self.cf.get("timestep_weight") + if timestep_weight_config is None: + return [1.0 for _ in range(forecast_steps)] + weights_timestep_fct = getattr(losses, timestep_weight_config[0]) + return weights_timestep_fct(forecast_steps, timestep_weight_config[1]) + + def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep): + location_weight_type = stream_info.get("location_weight", None) + if location_weight_type is None: + return None + weights_locations_fct = getattr(losses, location_weight_type) + weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep) + weights_locations = weights_locations.to(device=self.device, non_blocking=True) + + return weights_locations + + def _get_substep_masks(self, stream_info, fstep, stream_data): + """ + Find substeps and create corresponding masks (reused across loss functions) + """ + + tok_spacetime = stream_info.get("tokenize_spacetime", None) + target_times = stream_data.target_times_raw[self.cf.forecast_offset + fstep] + target_times_unique = np.unique(target_times) if tok_spacetime else [target_times] + substep_masks = [] + for t in target_times_unique: + # find substep + mask_t = torch.tensor(t == target_times).to(self.device, non_blocking=True) + substep_masks.append(mask_t) + + return substep_masks + + @staticmethod + def _loss_per_loss_function( + loss_fct, + target: torch.Tensor, + pred: torch.Tensor, + substep_masks: list[torch.Tensor], + weights_channels: torch.Tensor, + weights_locations: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_lfct = torch.tensor(0.0, device=target.device, requires_grad=True) + losses_chs = torch.zeros(target.shape[-1], device=target.device, dtype=torch.float32) + + ctr_substeps = 0 + for mask_t in substep_masks: + assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True + + loss, loss_chs = loss_fct( + target[mask_t], pred[:, mask_t], weights_channels, weights_locations + ) + + # accumulate loss + loss_lfct = 
loss_lfct + loss
+            losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs
+            ctr_substeps += 1 if loss > 0.0 else 0
+
+        # normalize over forecast steps in window
+        losses_chs /= ctr_substeps if ctr_substeps > 0 else 1.0
+
+        # TODO: substep weight
+        loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0)
+
+        return loss_lfct, losses_chs
+
+    def compute_loss(
+        self,
+        preds: dict,
+        targets: dict,
+    ) -> LossValues:
+        """
+        Computes the total loss for a given batch of predictions and corresponding
+        stream data.
+
+        The computed loss is:
+
+        Mean_{stream}( Mean_{fsteps}( Mean_{loss_fcts}( loss_fct( target, pred, weights) )))
+
+        This method orchestrates the calculation of the overall loss by iterating through
+        different data streams, forecast steps, channels, and configured loss functions.
+        It applies weighting, handles NaN values through masking, and accumulates
+        detailed loss metrics for logging.
+
+        Args:
+            preds: The model output; the physical predictions are read from preds.physical
+                as a nested list of tensors, with the outer list over forecast steps and
+                the inner list over streams.
+            targets: The input batch data; the physical targets are read from
+                targets["physical"] as a nested list, with the outer list over batch items
+                and the inner list over streams, each element providing an object
+                (e.g., dataclass instance) containing target data and metadata.
+
+        Returns:
+            A LossValues dataclass instance containing:
+            - loss: The loss for back-propagation.
+            - losses_all: A dictionary mapping stream names to a tensor of per-channel and
+                per-loss-function losses, normalized by non-empty targets/forecast steps.
+            - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations
+                of predictions for channels with statistical loss functions, normalized.
+ """ + + preds = preds.physical + streams_data = targets["physical"] + + # gradient loss + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + # counter for non-empty targets + ctr_streams = 0 + + # initialize dictionaries for detailed loss tracking and standard deviation statistics + # create tensor for each stream + losses_all: dict[str, Tensor] = { + st.name: torch.zeros( + (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)), + device=self.device, + ) + for st in self.cf.streams + } + stddev_all: dict[str, Tensor] = { + st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams + } + + # TODO: iterate over batch dimension + i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + # extract target tokens for current stream from the specified forecast offset onwards + targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] + + stream_data = streams_data[i_batch][i_stream_info] + + fstep_loss_weights = self._get_fstep_weights(len(targets)) + + loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps = 0 + + stream_is_spoof = streams_data[i_batch][i_stream_info].is_spoof() + if stream_is_spoof: + spoof_weight = torch.tensor(0.0, device=self.device, requires_grad=False) + else: + spoof_weight = torch.tensor(1.0, device=self.device, requires_grad=False) + + for fstep, (target, fstep_weight) in enumerate( + zip(targets, fstep_loss_weights, strict=False) + ): + # skip if either target or prediction has no data points + pred = preds[fstep][i_stream_info] + if not (target.shape[0] > 0 and pred.shape[0] > 0): + continue + + # reshape prediction tensor to match target's dimensions: extract data/coords and + # remove token dimension if it exists. + # expected final shape of pred is [ensemble_size, num_samples, num_channels]. 
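# Illustration (runnable sketch of the reshape below): predictions arrive
# flattened per stream and are viewed as [ensemble_size, num_samples,
# num_channels] so they line up with the target (sizes here are hypothetical).
import torch

ens, n, c = 2, 5, 3
pred = torch.randn(ens, n * c)   # hypothetical flattened layout
target = torch.randn(n, c)
pred = pred.reshape([pred.shape[0], *target.shape])
assert pred.shape == (ens, n, c)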
+                pred = pred.reshape([pred.shape[0], *target.shape])
+                assert pred.shape[1] > 0
+
+                # get weights for the current stream
+                stream_loss_weight, weights_channels = self._get_weights(stream_info)
+
+                # get weights for locations
+                weights_locations = self._get_location_weights(
+                    stream_info, stream_data, self.cf.forecast_offset, fstep
+                )
+
+                # get masks for sub-time steps
+                substep_masks = self._get_substep_masks(stream_info, fstep, stream_data)
+
+                # accumulate loss from different loss functions
+                loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True)
+                ctr_loss_fcts = 0
+                for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts):
+                    # loss for current loss function
+                    loss_lfct, loss_lfct_chs = self._loss_per_loss_function(
+                        loss_fct,
+                        target,
+                        pred,
+                        substep_masks,
+                        weights_channels,
+                        weights_locations,
+                    )
+                    losses_all[stream_info.name][:, i_lfct] += spoof_weight * loss_lfct_chs
+
+                    # Add the weighted and normalized loss from this loss function to the total
+                    # batch loss
+                    loss_fstep = loss_fstep + (
+                        loss_fct_weight * loss_lfct * stream_loss_weight * fstep_weight
+                    )
+                    ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0
+
+                loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0)
+                ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0
+
+            loss = loss + ((spoof_weight * loss_fsteps) / (ctr_fsteps if ctr_fsteps > 0 else 1.0))
+            ctr_streams += 1 if ctr_fsteps > 0 and not stream_is_spoof else 0
+
+            # normalize by forecast step
+            losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+            stddev_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+
+            # replace channels without information by nan to exclude from further computations
+            losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan
+            stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan
+
+        # normalize by all targets and forecast steps that were non-empty
+        # (with each having an expected loss of 1 for an uninitialized neural net)
+        loss = loss / ctr_streams
+
+        # Return all computed loss components encapsulated in a LossValues dataclass
+        return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)

From d7b326ba72a31d52939b5434861a1bf2b29cd8f9 Mon Sep 17 00:00:00 2001
From: Jubeku
Date: Fri, 14 Nov 2025 23:52:13 +0100
Subject: [PATCH 10/10] fixed trainer for multiple terms in losses_all, still
 need to fix logging

---
 src/weathergen/train/trainer.py | 56 +++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 20 deletions(-)

diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 8cf2c067a..63b1d07d5 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -640,17 +640,23 @@ def train(self, epoch):
                     self.world_size_original * self.cf.batch_size_per_gpu,
                 )
 
-                self.loss_unweighted_hist += [loss_values.losses_all]
+                if bidx == 0:
+                    self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()}
+                    self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()}
+                    self.loss_model_hist = []
+                for name, loss_terms in loss_values.loss_terms.items():
+                    self.loss_unweighted_hist[name].append(loss_terms.losses_all)
+                    self.stdev_unweighted_hist[name].append(loss_terms.stddev_all)
                 self.loss_model_hist += [loss_values.loss.item()]
-                self.stdev_unweighted_hist += [loss_values.stddev_all]
 
                 perf_gpu, perf_mem = self.get_perf()
                 self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item()
                 self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item()
 
-            self._log_terminal(bidx, epoch, TRAIN)
-            if bidx % self.train_log_freq.metrics == 0:
-                self._log(TRAIN)
+            # NEED TO FIX LOGGING
+            # self._log_terminal(bidx, epoch, TRAIN)
+            # if bidx % self.train_log_freq.metrics == 0:
+            #     self._log(TRAIN)
 
             # save model checkpoint (with designation _latest)
             if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
@@ -665,7 +671,6 @@ def validate(self, epoch):
         self.model.eval()
 
         dataset_val_iter = iter(self.data_loader_validation)
-        self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
 
         with torch.no_grad():
             # print progress bar but only in interactive mode, i.e. when without ddp
@@ -730,14 +735,20 @@ def validate(self, epoch):
                        sample_idxs,
                    )
 
-                self.loss_unweighted_hist += [loss_values.losses_all]
+                if bidx == 0:
+                    self.loss_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()}
+                    self.stdev_unweighted_hist = {k: [] for k in loss_values.loss_terms.keys()}
+                    self.loss_model_hist = []
+                for name, loss_terms in loss_values.loss_terms.items():
+                    self.loss_unweighted_hist[name].append(loss_terms.losses_all)
+                    self.stdev_unweighted_hist[name].append(loss_terms.stddev_all)
                 self.loss_model_hist += [loss_values.loss.item()]
 
                 pbar.update(self.cf.batch_size_validation_per_gpu)
 
-            self._log_terminal(bidx, epoch, VAL)
-            self._log(VAL)
+            # NEED TO FIX LOGGING
+            # self._log_terminal(bidx, epoch, VAL)
+            # self._log(VAL)
 
             # avoid that there is a systematic bias in the validation subset
             self.dataset_val.advance()
@@ -961,21 +974,24 @@ def _prepare_losses_for_logging(
             stddev_all (dict[str, torch.Tensor]): Dictionary mapping each stream name to
                 its per-channel standard deviation tensor.
         """
-        losses_all: dict[str, Tensor] = {}
-        stddev_all: dict[str, Tensor] = {}
+        losses_all: dict[str, dict[str, Tensor]] = {}
+        stddev_all: dict[str, dict[str, Tensor]] = {}
 
        # Make list of losses into a tensor. 
This is individual tensor per rank real_loss = torch.tensor(self.loss_model_hist, device=self.device) # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for stream in self.cf.streams: # Loop over all streams - stream_hist = [losses_all[stream.name] for losses_all in self.loss_unweighted_hist] - stream_all = torch.stack(stream_hist).to(torch.float64) - losses_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) - stream_hist = [stddev_all[stream.name] for stddev_all in self.stdev_unweighted_hist] - stream_all = torch.stack(stream_hist).to(torch.float64) - stddev_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) + for name in self.loss_unweighted_hist.keys(): + losses_all[name] = {} + stddev_all[name] = {} + for stream in self.cf.streams: # Loop over all streams + stream_hist = [losses[stream.name] for losses in self.loss_unweighted_hist[name]] + stream_all = torch.stack(stream_hist).to(torch.float64) + losses_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) + stream_hist = [stddevs[stream.name] for stddevs in self.stdev_unweighted_hist[name]] + stream_all = torch.stack(stream_hist).to(torch.float64) + stddev_all[name][stream.name] = torch.cat(all_gather_vlen(stream_all)) return real_loss, losses_all, stddev_all @@ -1010,7 +1026,7 @@ def _log(self, stage: Stage): self.perf_mem, ) - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.loss_unweighted_hist, self.loss_model_hist = [], [] def _get_tensor_item(self, tensor): """