Feature/mask NaNs in training loss function #72

Draft
wants to merge 10 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ Keep it human-readable, your future self will thank you!
- Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50)
- Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48)
- Long Rollout Plots
- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)

### Fixed

11 changes: 11 additions & 0 deletions src/anemoi/training/losses/mse.py
@@ -43,6 +43,10 @@ def __init__(
self.avg_function = torch.nanmean if ignore_nans else torch.mean
self.sum_function = torch.nansum if ignore_nans else torch.sum

# register_buffer:
# 1. saves the tensor as part of the module
# 2. ensures the tensor is moved to the same device as the model
self.register_buffer("variable_node_mask", torch.ones(1), persistent=False) # not saved in state_dict
self.register_buffer("weights", node_weights, persistent=True)
if data_variances is not None:
self.register_buffer("ivar", data_variances, persistent=True)
@@ -76,6 +80,9 @@ def forward(
if hasattr(self, "ivar"):
out *= self.ivar

# apply the variable node mask (positions that were NaN in the input are weighted with 0)
out = self.variable_node_mask * out

# Squash by last dimension
if squash:
out = self.avg_function(out, dim=-1)
@@ -89,3 +96,7 @@ def forward(
# keep last dimension (variables) when summing weights
out /= self.sum_function(self.weights[..., None].expand_as(out), axis=(0, 1, 2))
return self.sum_function(out, axis=(0, 1, 2))

def update_variable_node_mask(self, variable_node_mask: torch.Tensor) -> None:
"""Update the variable node mask."""
self.variable_node_mask = variable_node_mask
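
The buffer-plus-mask pattern above (a non-persistent mask buffer multiplied into the squared error before averaging) can be illustrated with a small self-contained sketch. This is only an approximation: the class name `MaskedWeightedMSE`, the simplified forward, and the shapes are hypothetical and not the actual `WeightedMSELoss` API.

```python
import torch
from torch import nn


class MaskedWeightedMSE(nn.Module):
    """Toy stand-in for the masking pattern; not the real WeightedMSELoss."""

    def __init__(self, node_weights: torch.Tensor) -> None:
        super().__init__()
        # buffers follow the module across devices; persistent=False keeps the mask out of the state_dict
        self.register_buffer("variable_node_mask", torch.ones(1), persistent=False)
        self.register_buffer("weights", node_weights, persistent=True)

    def update_variable_node_mask(self, variable_node_mask: torch.Tensor) -> None:
        self.variable_node_mask = variable_node_mask

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        out = (pred - target) ** 2
        out = self.variable_node_mask * out  # zero the contribution of positions that were NaN in the input
        out = out.mean(dim=-1) * self.weights  # average over variables, weight the nodes
        return out.sum() / self.weights.sum()


# usage sketch: shapes are (nodes, variables); the mask is derived from the raw (unimputed) input
loss_fn = MaskedWeightedMSE(node_weights=torch.tensor([0.5, 1.0]))
x = torch.tensor([[1.0, float("nan"), 2.0], [0.0, 3.0, 4.0]])
loss_fn.update_variable_node_mask((~x.isnan()).float())  # 0 where the input was NaN
loss = loss_fn(torch.nan_to_num(x) + 0.1, torch.nan_to_num(x))
```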
23 changes: 22 additions & 1 deletion src/anemoi/training/train/forecaster.py
@@ -12,6 +12,7 @@
import os
from collections import defaultdict
from collections.abc import Mapping
from functools import cached_property

import numpy as np
import pytorch_lightning as pl
@@ -88,7 +89,10 @@ def __init__(
config,
data_indices,
)
self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling)
self.loss = WeightedMSELoss(
node_weights=self.loss_weights,
data_variances=loss_scaling,
)
self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True)

if config.training.loss_gradient_scaling:
@@ -127,6 +131,20 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x, self.model_comm_group)

@cached_property
def training_weights_for_imputed_variables(self) -> None:
LOGGER.info("Executing cached property training_weights_for_imputed_variables; this should appear only once")
loss_weights_mask = torch.ones_like(self.loss.variable_node_mask)
# iterate over all pre-processors and check if they have a loss_mask_training attribute
for pre_processor in self.model.pre_processors.processors.values():
if hasattr(pre_processor, "loss_mask_training"):
loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training
# if the pre-processor defines a transform_loss_mask function, apply it
if hasattr(pre_processor, "transform_loss_mask"):
loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask)
self.loss.update_variable_node_mask(loss_weights_mask)
return None

@staticmethod
def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, torch.Tensor]:
metric_ranges = defaultdict(list)
@@ -222,6 +240,9 @@ def _step(
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
# in validation mode, do not normalize in place because remappers cannot be applied in-place
batch = self.model.pre_processors(batch, in_place=not validation_mode)

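# evaluating the cached property (once) registers the pre-processors' loss mask on the loss function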
self.training_weights_for_imputed_variables

metrics = {}

# start rollout of preprocessed batch
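
For context on the pre-processor contract that `training_weights_for_imputed_variables` relies on, here is a hedged, self-contained sketch. Only the attribute names `loss_mask_training` and `transform_loss_mask` come from the diff above; the `ToyImputer` class and its behaviour are illustrative, not the real anemoi-models imputer.

```python
import torch


class ToyImputer:
    """Illustrative pre-processor that imputes NaNs and records where they were.

    Only the attribute names loss_mask_training and transform_loss_mask are taken
    from the diff above; everything else here is a simplification.
    """

    def __init__(self, fill_value: float = 0.0) -> None:
        self.fill_value = fill_value
        self.loss_mask_training = torch.ones(1)  # updated on every call

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        # 1 where a value was present, 0 where it had to be imputed
        self.loss_mask_training = (~batch.isnan()).float()
        return torch.nan_to_num(batch, nan=self.fill_value)

    def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor:
        # hook for remapping the mask to the model's variable layout; identity here
        return mask


# roughly what the cached property does with such pre-processors:
pre_processors = {"imputer": ToyImputer()}
batch = pre_processors["imputer"](torch.tensor([[1.0, float("nan")], [float("nan"), 2.0]]))

loss_weights_mask = torch.ones(1)
for pre_processor in pre_processors.values():
    if hasattr(pre_processor, "loss_mask_training"):
        loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training
    if hasattr(pre_processor, "transform_loss_mask"):
        loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask)
# loss.update_variable_node_mask(loss_weights_mask) would then register the mask on the loss
```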