diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index aba62a2..7c36682 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -8,10 +8,12 @@ # import uuid +from typing import Optional import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate +from torch.distributed.distributed_c10d import ProcessGroup from torch_geometric.data import HeteroData from anemoi.models.preprocessing import Processors @@ -39,38 +41,55 @@ class AnemoiModelInterface(torch.nn.Module): Metadata for the model. data_indices : dict Indices for the data. - pre_processors : Processors - Pre-processing steps to apply to the data before passing it to the model. - post_processors : Processors - Post-processing steps to apply to the model's output. model : AnemoiModelEncProcDec The underlying Anemoi model. """ def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, + *, + config: DotDict, + graph_data: HeteroData, + statistics: dict, + data_indices: dict, + metadata: dict, + statistics_tendencies: Optional[dict] = None, ) -> None: super().__init__() self.config = config self.id = str(uuid.uuid4()) self.multi_step = self.config.training.multistep_input + self.prediction_strategy = self.config.training.prediction_strategy self.graph_data = graph_data self.statistics = statistics + self.statistics_tendencies = statistics_tendencies self.metadata = metadata self.data_indices = data_indices self._build_model() def _build_model(self) -> None: """Builds the model and pre- and post-processors.""" - # Instantiate processors - processors = [ - [name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)] - for name, processor in self.config.data.processors.items() + # Instantiate processors for state + processors_state = [ + [name, instantiate(processor, statistics=self.statistics, 
data_indices=self.data_indices)] + for name, processor in self.config.data.processors.state.items() ] # Assign the processor list pre- and post-processors - self.pre_processors = Processors(processors) - self.post_processors = Processors(processors, inverse=True) + self.pre_processors_state = Processors(processors_state) + self.post_processors_state = Processors(processors_state, inverse=True) + + # Instantiate processors for tendency + self.pre_processors_tendency = None + self.post_processors_tendency = None + if self.prediction_strategy == "tendency": + processors_tendency = [ + [name, instantiate(processor, statistics=self.statistics_tendencies, data_indices=self.data_indices)] + for name, processor in self.config.data.processors.tendency.items() + ] + + self.pre_processors_tendency = Processors(processors_tendency) + self.post_processors_tendency = Processors(processors_tendency, inverse=True) # Instantiate the model self.model = instantiate( @@ -81,8 +100,19 @@ def _build_model(self) -> None: _recursive_=False, # Disables recursive instantiation by Hydra ) - # Use the forward method of the model directly - self.forward = self.model.forward + def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor: + if self.prediction_strategy == "residual": + # Predict state by adding residual connection (just for the prognostic variables) + x_pred = self.model.forward(x, model_comm_group) + x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx] + else: + x_pred = self.model.forward(x, model_comm_group) + + for bounding in self.model.boundings: + # bounding performed in the order specified in the config file + x_pred = bounding(x_pred) + + return x_pred def predict_step(self, batch: torch.Tensor) -> torch.Tensor: """Prediction step for the model. @@ -97,17 +127,54 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor: torch.Tensor Predicted data. 
""" - batch = self.pre_processors(batch, in_place=False) with torch.no_grad(): assert ( len(batch.shape) == 4 ), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!" + x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False) + # Dimensions are - # batch, timesteps, horizonal space, variables - x = batch[:, 0 : self.multi_step, None, ...] # add dummy ensemble dimension as 3rd index + # batch, timesteps, horizontal space, variables + x = x[..., None, :, :] # add dummy ensemble dimension as 3rd index + if self.prediction_strategy == "tendency": + tendency_hat = self(x) + y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat) + else: + y_hat = self(x) + y_hat = self.post_processors_state(y_hat, in_place=False) + + return y_hat - y_hat = self(x) + def add_tendency_to_state(self, state_inp: torch.Tensor, tendency: torch.Tensor) -> torch.Tensor: + """Add the tendency to the state. + + Parameters + ---------- + state_inp : torch.Tensor + The input state tensor with full input variables and unprocessed. + tendency : torch.Tensor + The tendency tensor output from model. + + Returns + ------- + torch.Tensor + Predicted data. 
+ """ + + state_outp = self.post_processors_tendency( + tendency, in_place=False, data_index=self.data_indices.data.output.full + ) + + state_outp[..., self.data_indices.model.output.diagnostic] = self.post_processors_state( + tendency[..., self.data_indices.model.output.diagnostic], + in_place=False, + data_index=self.data_indices.data.output.diagnostic, + ) + + state_outp[..., self.data_indices.model.output.prognostic] += state_inp[ + ..., self.data_indices.model.input.prognostic + ] - return self.post_processors(y_hat, in_place=False) + return state_outp diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c77db6e..fb042e5 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -259,11 +259,4 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> .clone() ) - # residual connection (just for the prognostic variables) - x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] - - for bounding in self.boundings: - # bounding performed in the order specified in the config file - x_out = bounding(x_out) - return x_out diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index cc2cb4f..40c0396 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -95,7 +95,9 @@ def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[st for variable in variables } - def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: + def forward( + self, x, in_place: bool = True, inverse: bool = False, data_index: Optional[torch.Tensor] = None + ) -> Tensor: """Process the input tensor. 
 Parameters @@ -106,6 +108,8 @@ def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: Whether to process the tensor in place inverse : bool Whether to inverse transform the input + data_index : torch.Tensor, optional + Normalize only the specified indices, by default None Returns ------- @@ -113,8 +117,8 @@ def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor: Processed tensor """ if inverse: - return self.inverse_transform(x, in_place=in_place) - return self.transform(x, in_place=in_place) + return self.inverse_transform(x, in_place=in_place, data_index=data_index) + return self.transform(x, in_place=in_place, data_index=data_index) def transform(self, x, in_place: bool = True) -> Tensor: """Process the input tensor.""" @@ -155,7 +159,7 @@ def __init__(self, processors: list, inverse: bool = False) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__} [{'inverse' if self.inverse else 'forward'}]({self.processors})" - def forward(self, x, in_place: bool = True) -> Tensor: + def forward(self, x, in_place: bool = True, data_index: Optional[torch.Tensor] = None) -> Tensor: """Process the input tensor. 
Parameters @@ -164,6 +168,8 @@ def forward(self, x, in_place: bool = True) -> Tensor: Input tensor in_place : bool Whether to process the tensor in place + data_index : Optional[torch.Tensor], optional + Normalize only the specified indices, by default None Returns ------- @@ -171,7 +177,7 @@ def forward(self, x, in_place: bool = True) -> Tensor: Processed tensor """ for processor in self.processors.values(): - x = processor(x, in_place=in_place, inverse=self.inverse) + x = processor(x, in_place=in_place, inverse=self.inverse, data_index=data_index) if self.first_run: self.first_run = False diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index 6ef5adb..4e06397 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -103,7 +103,9 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor: """Expand the subset of the mask to the correct shape.""" return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1) - def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + def transform( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: x = x.clone() @@ -115,7 +117,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: self.nan_locations = torch.isnan(x[idx].squeeze()) # Choose correct index based on number of variables - if x.shape[-1] == self.num_training_input_vars: + if data_index is not None: + index = data_index + elif x.shape[-1] == self.num_training_input_vars: index = self.index_training_input elif x.shape[-1] == self.num_inference_input_vars: index = self.index_inference_input @@ -131,13 +135,18 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value return x - def 
inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + def inverse_transform( + self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None + ) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: x = x.clone() # Replace original nans with nan again - if x.shape[-1] == self.num_training_output_vars: + # Choose correct index based on number of variables + if data_index is not None: + index = data_index + elif x.shape[-1] == self.num_training_output_vars: index = self.index_training_output elif x.shape[-1] == self.num_inference_output_vars: index = self.index_inference_output diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py index 3a2327e..d552c66 100644 --- a/tests/preprocessing/test_preprocessor_normalizer.py +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -22,7 +22,9 @@ def input_normalizer(): { "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, "data": { - "normalizer": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}, + "normalizers": { + "state": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]} + }, "forcing": ["z", "q"], "diagnostic": ["other"], "remapped": {}, @@ -68,7 +70,7 @@ def remap_normalizer(): } name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} data_indices = IndexCollection(config=config, name_to_index=name_to_index) - return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics) + return InputNormalizer(config=config.data.normalizers.state, statistics=statistics, data_indices=data_indices) def test_normalizer_not_inplace(input_normalizer) -> None: