diff --git a/CHANGELOG.md b/CHANGELOG.md index 02a96e9..76ae6b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you! ### Added +- feat: Add `AttributeFromSourceNode` and `AttributeFromTargetNode` edge attribute to copy attribute from source or target node. Set `node_attr_name` in the config to specify which attribute to copy from the source | target node (#94) + +# Changed - feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93) - feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)  - feat: Support for defining nodes from lat& lon NumPy arrays (#98) diff --git a/docs/graphs/edge_attributes.rst b/docs/graphs/edge_attributes.rst index 1c99149..d2f4b47 100644 --- a/docs/graphs/edge_attributes.rst +++ b/docs/graphs/edge_attributes.rst @@ -4,7 +4,7 @@ Edges - Attributes #################### -There are 2 main edge attributes implemented in the `anemoi-graphs` +There are few edge attributes implemented in the `anemoi-graphs` package: ************* @@ -44,3 +44,49 @@ latitude and longitude coordinates of the source and target nodes. attributes: edge_length: _target_: anemoi.graphs.edges.attributes.EdgeDirection + +********************* + Attribute from Node +********************* + +Attributes can also be copied from nodes to edges. This is done using +the `AttributeFromNode` base class, with specialized versions for source +and target nodes. + +From Source +=========== + +This attribute copies a specific property of the source node to the +edge. Example usage for copying the cutout mask from nodes to edges in +the encoder: + +.. code:: yaml + + edges: + # Encoder + - source_name: data + target_name: hidden + edge_builders: ... + attributes: + cutout: # Assigned name to the edge attribute, can be different than node_attr_name + _target_: anemoi.graphs.edges.attributes.AttributeFromSourceNode + node_attr_name: cutout + +From Target +=========== + +This attribute copies a specific property of the target node to the +edge. Example usage for copying the coutout mask from nodes to edges in +the decoder: + +.. code:: yaml + + edges: + # Decoder + - source_name: hidden + target_name: data + edge_builders: ... + attributes: + cutout: # Assigned name to the edge attribute, can be different than node_attr_name + _target_: anemoi.graphs.edges.attributes.AttributeFromTargetNode + node_attr_name: cutout diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6be801b..9d43ac0 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -24,8 +24,9 @@ class BaseEdgeAttribute(ABC, NormaliserMixin): """Base class for edge attributes.""" - def __init__(self, norm: str | None = None) -> None: + def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: self.norm = norm + self.dtype = dtype @abstractmethod def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... @@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - normed_values = self.normalise(values) + norm_values = self.normalise(values) - return torch.tensor(normed_values, dtype=torch.float32) + return torch.tensor(norm_values.astype(self.dtype)) def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: """Compute the edge attributes.""" @@ -155,3 +156,81 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: values = 1 - values return values + + +class BooleanBaseEdgeAttribute(BaseEdgeAttribute, ABC): + """Base class for boolean edge attributes.""" + + def __init__(self) -> None: + super().__init__(norm=None, dtype="bool") + + +class AttributeFromNode(BooleanBaseEdgeAttribute, ABC): + """ + Base class for Attribute from Node. + + Copy an attribute of either the source or target node to the edge. + Accesses source/target node attribute and propagates it to the edge. + Used for example to identify if an encoder edge originates from a LAM or global node. + + Attributes + ---------- + node_attr_name : str + Name of the node attribute to propagate. + + Methods + ------- + get_node_name(source_name, target_name) + Return the name of the node to copy. + + get_raw_values(graph, source_name, target_name) + Computes the edge attribute from the source or target node attribute. + + """ + + def __init__(self, node_attr_name: str) -> None: + super().__init__() + self.node_attr_name = node_attr_name + self.idx = None + + @abstractmethod + def get_node_name(self, source_name: str, target_name: str): ... + + def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: + + node_name = self.get_node_name(source_name, target_name) + + edge_index = graph[(source_name, "to", target_name)].edge_index + try: + return graph[node_name][self.node_attr_name].numpy()[edge_index[self.idx]] + + except AttributeError: + raise AttributeError( + f"{self.__class__.__name__} failed because the attribute '{self.node_attr_name}' is not defined for the nodes." + ) + + +class AttributeFromSourceNode(AttributeFromNode): + """ + Copy an attribute of the source node to the edge. + """ + + def __init__(self, node_attr_name: str) -> None: + super().__init__(node_attr_name) + self.idx = 0 + + def get_node_name(self, source_name: str, target_name: str): + return source_name + + +class AttributeFromTargetNode(AttributeFromNode): + """ + Copy an attribute of the target node to the edge. + """ + + def __init__(self, node_attr_name: str) -> None: + super().__init__(node_attr_name) + self.idx = 1 + + def get_node_name(self, source_name: str, target_name: str): + return target_name