Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

[Feature] New Edge Attribute: AttributeFromNode #95

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Keep it human-readable, your future self will thank you!
- feat: Support for multiple edge builders between two sets of nodes (#70)
- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93)
- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)
- feat: Add `AttributeFromNode` edge attribute to copy attribute from source or destination node. Set `node_attr_name` and `node_type : src | dst` in the config to specify which attribute to copy from the source | destination node (#94)


# Changed

Expand Down
71 changes: 71 additions & 0 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,74 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
values = 1 - values

return values


class BooleanBaseEdgeAttribute:
"""Base class for boolean edge attributes."""

def __init__(self) -> None:
pass

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
return torch.tensor(values, dtype=torch.bool)

def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
source_name, _, target_name = edges_name
assert (
source_name in graph.node_types
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
assert (
target_name in graph.node_types
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."

values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
return self.post_process(values)


class AttributeFromNode(BooleanBaseEdgeAttribute):
"""
Copy an attribute of either the source or destination node to the edge.
Accesses origin/target node attribute and propagates it to the edge.
Used for example to identify if an encoder edge originates from a LAM or global node.

Attributes
----------
node_attr_name : str
Name of the node attribute to propagate.

node_type : str
Pick the node to copy from. Options: "src, dst"

Methods
-------
get_raw_values(graph, source_name, target_name)
Computes the edge attribute from the source or destination node attribute.
"""

def __init__(self, node_attr_name: str, node_type: str) -> None:
self.node_attr_name = node_attr_name
assert node_type in ["src", "dst"]
self.node_type = node_type

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:

edge_index = graph[(source_name, "to", target_name)].edge_index

if self.node_type == "src":
name_to_copy = source_name
idx = 0

else:
name_to_copy = target_name
idx = 1

assert hasattr(graph[name_to_copy], self.node_attr_name)

val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]]

return val
Loading