Skip to content

Commit

Permalink
#47 adding tendency variances
Browse files Browse the repository at this point in the history
Co-authored-by: Jakob Schloer <jakob.schloer@gmail.com>
  • Loading branch information
Rilwan-Adewoyin committed Sep 4, 2024
1 parent 3d822ca commit c325238
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/anemoi/training/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@

import torch
from torch import nn

from typing import Optional
LOGGER = logging.getLogger(__name__)


#TODO (rilwan-ade): remove the data_variances/tendency_variances and replace the name with feature_weights - then change loss_
class WeightedMSELoss(nn.Module):
"""Latitude-weighted MSE loss."""

def __init__(
self,
node_weights: torch.Tensor,
data_variances: torch.Tensor | None = None,
tendency_variances: Optional[torch.Tensor] = None,
ignore_nans: bool | None = False,
) -> None:
"""Latitude- and (inverse-)variance-weighted MSE Loss.
Expand All @@ -34,6 +35,8 @@ def __init__(
Weight of each node in the loss function
data_variances : Optional[torch.Tensor], optional
precomputed, per-variable stepwise variance estimate, by default None
tendency_variances : Optional[torch.Tensor], optional
precomputed, per-variable-level variance of time differences, by default None
ignore_nans : bool, optional
Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False
Expand All @@ -46,6 +49,8 @@ def __init__(
self.register_buffer("weights", node_weights, persistent=True)
if data_variances is not None:
self.register_buffer("ivar", data_variances, persistent=True)
if tendency_variances is not None:
self.register_buffer("tvar", tendency_variances, persistent=True)

def forward(
self,
Expand Down

0 comments on commit c325238

Please sign in to comment.