Skip to content

Commit

Permalink
15 remapping of one input variable to multiple new ones (#21)
Browse files Browse the repository at this point in the history
* feat: remapper and change to data indices when mapping one variable to several

* tests: update imputer and normalizer tests

* feat: include remapper as a preprocessor. update init for all preprocessors. add tests for the remapper.

* tests: update tests for index collection

* documentation and changelog

* feat: enable remapper for forcing variables

* tests: include remapping forcing variable and do not test with remapping variables at the end

* comment and warning about using in_place=True in remapper as this is not possible

* comments: incorporate changes/documentation requested by jesper

* change order of function inputs preprocessors, documentation for data indices and remapper

* style: dict in config files for defining the variables to be remapped. structure and additional assert in index collection.

* args in preprocessors
  • Loading branch information
sahahner authored Sep 9, 2024
1 parent 4f1de81 commit 80787ce
Show file tree
Hide file tree
Showing 14 changed files with 620 additions and 52 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Keep it human-readable, your future self will thank you!
### Added
- CI workflow to update the changelog on release

- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables.

### Changed

- Update CI to inherit from common infrastructue reusable workflows
Expand Down
25 changes: 23 additions & 2 deletions docs/modules/data_indices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,33 @@ config entry:
:alt: Schematic of IndexCollection with Data Indexing on Data and Model levels.
:align: center

The are two Index-levels:
Additionally, prognostic and forcing variables can be remapped and
converted to multiple variables. The conversion is then done by the
remapper-preprocessor.

.. code:: yaml
data:
remapped:
- d:
- "d_1"
- "d_2"
There are two main Index-levels:

- Data: The data at "Zarr"-level provided by Anemoi-Datasets
- Model: The "squeezed" tensors with irrelevant parts missing.

These are both split into two versions:
Additionally, there are two internal model levels (After preprocessor
and before postprocessor) that are necessary because of the possiblity
to remap variables to multiple variables.

- Internal Data: Variables from Data-level that are used internally in
the model, but not exposed to the user.
- Internal Model: Variables from Model-level that are used internally
in the model, but not exposed to the user.

All indices at the different levels are split into two versions:

- Input: The data going into training / model
- Output: The data produced by training / model
Expand Down
13 changes: 13 additions & 0 deletions docs/modules/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ following classes:
:members:
:no-undoc-members:
:show-inheritance:

**********
Remapper
**********

The remapper module is used to remap one variable to multiple other
variables that have been listed in data.remapped:. The module contains
the following classes:

.. automodule:: anemoi.models.preprocessing.remapper
:members:
:no-undoc-members:
:show-inheritance:
61 changes: 58 additions & 3 deletions src/anemoi/models/data_indices/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,74 @@ class IndexCollection:

def __init__(self, config, name_to_index) -> None:
self.config = OmegaConf.to_container(config, resolve=True)

self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
self.diagnostic = (
[] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
)
# config.data.remapped is a list of diccionaries: every remapper is one entry of the list
self.remapped = (
dict() if config.data.remapped is None else OmegaConf.to_container(config.data.remapped, resolve=True)
)
self.forcing_remapped = self.forcing.copy()

assert set(self.diagnostic).isdisjoint(self.forcing), (
f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ",
"Please drop them at a dataset-level to exclude them from the training data.",
)
self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
assert set(self.remapped).isdisjoint(self.diagnostic), (
"Remapped variable overlap with diagnostic variables. Not implemented.",
)
assert set(self.remapped).issubset(self.name_to_index), (
"Remapping a variable that does not exist in the dataset. Check for typos: ",
f"{set(self.remapped).difference(self.name_to_index)}",
)
name_to_index_model_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
}
name_to_index_model_output = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
}
# remove remapped variables from internal data and model indices
name_to_index_internal_data_input = {
name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.remapped)
}
name_to_index_internal_model_input = {
name: i for i, name in enumerate(key for key in name_to_index_model_input if key not in self.remapped)
}
name_to_index_internal_model_output = {
name: i for i, name in enumerate(key for key in name_to_index_model_output if key not in self.remapped)
}
# for all variables to be remapped we add the resulting remapped variables to the end of the tensors
# keep track of that in the index collections
for key in self.remapped:
for mapped in self.remapped[key]:
# add index of remapped variables to dictionary
name_to_index_internal_model_input[mapped] = len(name_to_index_internal_model_input)
name_to_index_internal_data_input[mapped] = len(name_to_index_internal_data_input)
if key not in self.forcing:
# do not include forcing variables in the remapped model output
name_to_index_internal_model_output[mapped] = len(name_to_index_internal_model_output)
else:
# add remapped forcing variables to forcing_remapped
self.forcing_remapped += [mapped]
if key in self.forcing:
# if key is in forcing we need to remove it from forcing_remapped after remapped variables have been added
self.forcing_remapped.remove(key)

self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
self.internal_data = DataIndex(
self.diagnostic,
self.forcing_remapped,
name_to_index_internal_data_input,
) # internal after the remapping applied to data (training)
self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)
self.internal_model = ModelIndex(
self.diagnostic,
self.forcing_remapped,
name_to_index_internal_model_input,
name_to_index_internal_model_output,
) # internal after the remapping applied to model (inference)

def __repr__(self) -> str:
return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"
Expand All @@ -54,7 +102,12 @@ def __eq__(self, other):
# don't attempt to compare against unrelated types
return NotImplemented

return self.model == other.model and self.data == other.data
return (
self.model == other.model
and self.data == other.data
and self.internal_model == other.internal_model
and self.internal_data == other.internal_data
)

def __getitem__(self, key):
return getattr(self, key)
Expand All @@ -63,6 +116,8 @@ def todict(self):
return {
"data": self.data.todict(),
"model": self.model.todict(),
"internal_model": self.internal_model.todict(),
"internal_data": self.internal_data.todict(),
}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Instantiate processors
processors = [
[name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)]
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)]
for name, processor in self.config.data.processors.items()
]

Expand Down
19 changes: 10 additions & 9 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,23 @@ def __init__(
)

def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
self.num_input_channels = len(data_indices.model.input)
self.num_output_channels = len(data_indices.model.output)
self._internal_input_idx = data_indices.model.input.prognostic
self._internal_output_idx = data_indices.model.output.prognostic
self.num_input_channels = len(data_indices.internal_model.input)
self.num_output_channels = len(data_indices.internal_model.output)
self._internal_input_idx = data_indices.internal_model.input.prognostic
self._internal_output_idx = data_indices.internal_model.output.prognostic

def _assert_matching_indices(self, data_indices: dict) -> None:

assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len(
data_indices.model.output.diagnostic
assert len(self._internal_output_idx) == len(data_indices.internal_model.output.full) - len(
data_indices.internal_model.output.diagnostic
), (
f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding "
f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})",
f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and "
f"the internal output indices excluding diagnostic variables "
f"({len(data_indices.internal_model.output.full) - len(data_indices.internal_model.output.diagnostic)})",
)
assert len(self._internal_input_idx) == len(
self._internal_output_idx,
), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}"
), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

def _define_tensor_sizes(self, config: DotDict) -> None:
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
Expand Down
12 changes: 8 additions & 4 deletions src/anemoi/models/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
#

import logging
from typing import TYPE_CHECKING
from typing import Optional

import torch
from torch import Tensor
from torch import nn

if TYPE_CHECKING:
from anemoi.models.data_indices.collection import IndexCollection

LOGGER = logging.getLogger(__name__)


Expand All @@ -23,19 +27,19 @@ class BasePreprocessor(nn.Module):
def __init__(
self,
config=None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
data_indices: Optional[dict] = None,
) -> None:
"""Initialize the preprocessor.
Parameters
----------
config : DotDict
configuration object
configuration object of the processor
data_indices : IndexCollection
Data indices for input and output variables
statistics : dict
Data statistics dictionary
data_indices : dict
Data indices for input and output variables
"""
super().__init__()

Expand Down
16 changes: 9 additions & 7 deletions src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ def __init__(
Parameters
----------
config : DotDict
configuration object
configuration object of the processor
data_indices : IndexCollection
Data indices for input and output variables
statistics : dict
Data statistics dictionary
data_indices : dict
Data indices for input and output variables
"""
super().__init__(config, statistics, data_indices)
super().__init__(config, data_indices, statistics)

self.nan_locations = None
self.data_indices = data_indices

def _validate_indices(self):
assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), (
Expand Down Expand Up @@ -174,8 +173,8 @@ class InputImputer(BaseImputer):
def __init__(
self,
config=None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
data_indices: Optional[dict] = None,
) -> None:
super().__init__(config, data_indices, statistics)

Expand All @@ -201,7 +200,10 @@ class ConstantImputer(BaseImputer):
"""

def __init__(
self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None
self,
config=None,
data_indices: Optional[IndexCollection] = None,
statistics: Optional[dict] = None,
) -> None:
super().__init__(config, data_indices, statistics)

Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/models/preprocessing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def __init__(
Parameters
----------
config : DotDict
configuration object
configuration object of the processor
data_indices : IndexCollection
Data indices for input and output variables
statistics : dict
Data statistics dictionary
data_indices : dict
Data indices for input and output variables
"""
super().__init__(config, statistics, data_indices)
super().__init__(config, data_indices, statistics)

name_to_index_training_input = self.data_indices.data.input.name_to_index

Expand Down
Loading

0 comments on commit 80787ce

Please sign in to comment.