diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1bbd55c..fba130a3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,7 @@ Keep it human-readable, your future self will thank you!
 - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50)
 - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)
 - Long Rollout Plots
+- Mask NaN values in training loss function [#72](https://github.com/ecmwf/anemoi-training/pull/72) and [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)

 ### Fixed

diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py
index b72a7aea..8bd24ae1 100644
--- a/src/anemoi/training/losses/mse.py
+++ b/src/anemoi/training/losses/mse.py
@@ -43,6 +43,10 @@ def __init__(
         self.avg_function = torch.nanmean if ignore_nans else torch.mean
         self.sum_function = torch.nansum if ignore_nans else torch.sum

+        # register_buffer:
+        # 1. saves the tensor as part of the module's state
+        # 2. ensures the tensor is moved to the same device as the model
+        self.register_buffer("variable_node_mask", torch.ones(1), persistent=False)  # not saved in state_dict
         self.register_buffer("weights", node_weights, persistent=True)
         if data_variances is not None:
             self.register_buffer("ivar", data_variances, persistent=True)
@@ -76,6 +80,9 @@ def forward(
         if hasattr(self, "ivar"):
             out *= self.ivar

+        # apply variable node mask (zero out input-NaN positions)
+        out = self.variable_node_mask * out
+
         # Squash by last dimension
         if squash:
             out = self.avg_function(out, dim=-1)
@@ -89,3 +96,7 @@ def forward(
         # keep last dimension (variables) when summing weights
         out /= self.sum_function(self.weights[..., None].expand_as(out), axis=(0, 1, 2))
         return self.sum_function(out, axis=(0, 1, 2))
+
+    def update_variable_node_mask(self, variable_node_mask: torch.Tensor) -> None:
+        """Update the variable node mask."""
+        self.variable_node_mask = variable_node_mask
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index ff1acfd7..6d1e3a12 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -12,6 +12,7 @@ import os

 from collections import defaultdict
 from collections.abc import Mapping
+from functools import cached_property

 import numpy as np
 import pytorch_lightning as pl
@@ -88,7 +89,10 @@ def __init__(
             config,
             data_indices,
         )
-        self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling)
+        self.loss = WeightedMSELoss(
+            node_weights=self.loss_weights,
+            data_variances=loss_scaling,
+        )
         self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True)

         if config.training.loss_gradient_scaling:
@@ -127,6 +131,20 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x, self.model_comm_group)

+    @cached_property
+    def training_weights_for_imputed_variables(self) -> None:
+        LOGGER.info("Building the training loss mask for imputed variables (cached property, executed once)")
+        loss_weights_mask = torch.ones_like(self.loss.variable_node_mask)
+        # iterate over all pre-processors and check if they have a loss_mask_training attribute
+        for pre_processor in self.model.pre_processors.processors.values():
+            if hasattr(pre_processor, "loss_mask_training"):
+                loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training
+            # if the pre-processor defines a transform_loss_mask function, apply it
+            if hasattr(pre_processor, "transform_loss_mask"):
+                loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask)
+        self.loss.update_variable_node_mask(loss_weights_mask)
+        return None
+
     @staticmethod
     def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, torch.Tensor]:
         metric_ranges = defaultdict(list)
@@ -222,6 +240,9 @@ def _step(
         loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
         # for validation not normalized in-place because remappers cannot be applied in-place
         batch = self.model.pre_processors(batch, in_place=not validation_mode)
+
+        self.training_weights_for_imputed_variables  # access the cached property once to set the loss mask
+
         metrics = {}

         # start rollout of preprocessed batch
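A minimal, standalone sketch of the masking mechanism this patch introduces. The ToyMaskedMSE class, tensor shapes, and example values below are illustrative assumptions; only the variable_node_mask buffer and the update_variable_node_mask() call mirror the changes above.

import torch


class ToyMaskedMSE(torch.nn.Module):
    """Toy weighted MSE with a non-persistent per-node/variable mask."""

    def __init__(self, node_weights: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("weights", node_weights, persistent=True)
        # placeholder mask of ones; replaced later via update_variable_node_mask()
        self.register_buffer("variable_node_mask", torch.ones(1), persistent=False)

    def update_variable_node_mask(self, mask: torch.Tensor) -> None:
        self.variable_node_mask = mask

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        out = (pred - target) ** 2             # (bs, ens, nodes, vars)
        out = self.variable_node_mask * out    # masked (NaN-imputed) positions contribute 0
        out = out.mean(dim=-1) * self.weights  # squash variables, apply node weights
        return out.sum() / self.weights.expand_as(out).sum()


# usage: zero the mask where the input value for a variable was NaN
pred = torch.randn(1, 1, 4, 2)
target = torch.randn(1, 1, 4, 2)
mask = torch.ones(4, 2)
mask[0, 1] = 0.0  # e.g. variable 1 was NaN (imputed) at node 0
loss_fn = ToyMaskedMSE(node_weights=torch.ones(4))
loss_fn.update_variable_node_mask(mask)
print(loss_fn(pred, target))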