From 6e64c91d0a8b0876860888f2c32bf094dc928ed2 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 08:54:23 +0000 Subject: [PATCH 1/5] feat: add NamedNodesAttributes --- src/anemoi/models/layers/graph.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 71703d9..332a4a3 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -11,6 +11,7 @@ import torch from torch import Tensor from torch import nn +from torch_geometric.data import HeteroData class TrainableTensor(nn.Module): @@ -35,8 +36,36 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None: def forward(self, x: Tensor, batch_size: int) -> Tensor: latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)] if self.trainable is not None: - latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size)) + latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size)) return torch.cat( latent, dim=-1, # feature dimension ) + + +class NamedNodesAttributes(torch.nn.Module): + """Named Node Attributes Module.""" + + def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: + """Initialize NamedNodesAttributes.""" + self.num_trainable_params = num_trainable_params + self.nodes_names = list(graph_data.node_types) + + self.trainable_tensors = nn.ModuleDict() + for nodes_name in self.nodes_names: + self.register_coordinates(nodes_name, graph_data[nodes_name].x) + self.register_tensor(nodes_name, graph_data[nodes_name].num_nodes) + + def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: + """Register coordinates.""" + sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) + self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) + + def register_tensor(self, name: str, 
tensor_size: int) -> None: + """Register a trainable tensor.""" + self.trainable_tensors[name] = TrainableTensor(tensor_size, self.num_trainable_params) + + def forward(self, name: str, batch_size: int) -> Tensor: + """Forward pass.""" + latlons = getattr(self, f"latlons_{name}") + return self.trainable_tensors[name](latlons, batch_size) From 5c83f5f6f283e94c19429bfeeb81bf7570cbcfaa Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 09:42:56 +0000 Subject: [PATCH 2/5] feat: use NamedNodesAttributes in AnemoiModelEncProcDec --- src/anemoi/models/layers/graph.py | 17 +++-- .../models/encoder_processor_decoder.py | 65 ++++--------------- 2 files changed, 27 insertions(+), 55 deletions(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 332a4a3..7a4e0d4 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -48,22 +48,31 @@ class NamedNodesAttributes(torch.nn.Module): def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: """Initialize NamedNodesAttributes.""" + super().__init__() + self.num_trainable_params = num_trainable_params - self.nodes_names = list(graph_data.node_types) + self.register_fixed_attributes(graph_data) self.trainable_tensors = nn.ModuleDict() for nodes_name in self.nodes_names: self.register_coordinates(nodes_name, graph_data[nodes_name].x) - self.register_tensor(nodes_name, graph_data[nodes_name].num_nodes) + self.register_tensor(nodes_name) + + def register_fixed_attributes(self, graph_data: HeteroData) -> None: + """Register fixed attributes.""" + self.nodes_names = list(graph_data.node_types) + self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} + self.coord_dims = {2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} + self.attr_ndims = {self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} def register_coordinates(self, 
name: str, node_coords: torch.Tensor) -> None: """Register coordinates.""" sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - def register_tensor(self, name: str, tensor_size: int) -> None: + def register_tensor(self, name: str) -> None: """Register a trainable tensor.""" - self.trainable_tensors[name] = TrainableTensor(tensor_size, self.num_trainable_params) + self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params) def forward(self, name: str, batch_size: int) -> Tensor: """Forward pass.""" diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c77db6e..592d6d4 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -21,7 +21,7 @@ from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards -from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.graph import NamedNodesAttributes LOGGER = logging.getLogger(__name__) @@ -55,33 +55,24 @@ def __init__( self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - - self.multi_step = model_config.training.multistep_input - - self._define_tensor_sizes(model_config) - - # Create trainable tensors - self._create_trainable_attributes() - - # Register lat/lon of nodes - self._register_latlon("data", self._graph_name_data) - self._register_latlon("hidden", self._graph_name_hidden) - self.data_indices = data_indices + self.multi_step = model_config.training.multistep_input self.num_channels = model_config.model.num_channels - input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, 
self._graph_data) + + input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] # Encoder data -> hidden self.encoder = instantiate( model_config.model.encoder, in_channels_src=input_dim, - in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden], hidden_dim=self.num_channels, sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], - src_grid_size=self._data_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Processor hidden -> hidden @@ -89,8 +80,8 @@ def __init__( model_config.model.processor, num_channels=self.num_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Decoder hidden -> data @@ -101,8 +92,8 @@ def __init__( hidden_dim=self.num_channels, out_channels_dst=self.num_output_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._data_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -132,34 +123,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotDict) -> None: - 
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes - self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes - - self.trainable_data_size = config.model.trainable_parameters.data - self.trainable_hidden_size = config.model.trainable_parameters.hidden - - def _register_latlon(self, name: str, nodes: str) -> None: - """Register lat/lon buffers. - - Parameters - ---------- - name : str - Name to store the lat-lon coordinates of the nodes. - nodes : str - Name of nodes to map - """ - coords = self._graph_data[nodes].x - sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - - def _create_trainable_attributes(self) -> None: - """Create all trainable attributes.""" - self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) - self.trainable_hidden = TrainableTensor( - trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size - ) - def _run_mapper( self, mapper: nn.Module, @@ -209,12 +172,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_data_latent = torch.cat( ( einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), - self.trainable_data(self.latlons_data, batch_size=batch_size), + self.node_attributes(self._graph_name_data, batch_size=batch_size), ), dim=-1, # feature dimension ) - x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size) # get shard shapes shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) From fdbf92f2e76189297d3953ec8eadf109444f50ce Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 10:14:37 +0000 Subject: [PATCH 3/5] fix: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 57078fd..963c62d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you! - configurabilty of the dropout probability in the the MultiHeadSelfAttention module - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) +- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) ### Changed - Bugfixes for CI From 659652f38f31c8bf49a061c2658aed26397fce0d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 10:42:31 +0000 Subject: [PATCH 4/5] fix: typo --- src/anemoi/models/layers/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 7a4e0d4..c3608c5 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -62,8 +62,8 @@ def register_fixed_attributes(self, graph_data: HeteroData) -> None: """Register fixed attributes.""" self.nodes_names = list(graph_data.node_types) self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} - self.coord_dims = {2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} - self.attr_ndims = {self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} + self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} + self.attr_ndims = {nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: """Register coordinates.""" From f10810cd4fedfdac4dba069ee5316b8564def589 Mon Sep 17 00:00:00 2001 
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:42:57 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index c3608c5..5d96e73 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -63,7 +63,9 @@ def register_fixed_attributes(self, graph_data: HeteroData) -> None: self.nodes_names = list(graph_data.node_types) self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} - self.attr_ndims = {nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} + self.attr_ndims = { + nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names + } def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: """Register coordinates."""