Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

[Feature] New Edge Attribute: AttributeFromNode #95

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Keep it human-readable, your future self will thank you!
- feat: Support for multiple edge builders between two sets of nodes (#70)
- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93)
- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)
- feat: Add `AttributeFromSourceNode` and `AttributeFromTargetNode` edge attribute to copy attribute from source or target node. Set `node_attr_name` in the config to specify which attribute to copy from the source | target node (#94)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go in the Unreleased section.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed ✅


# Changed

Expand Down
81 changes: 78 additions & 3 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
class BaseEdgeAttribute(ABC, NormaliserMixin):
"""Base class for edge attributes."""

def __init__(self, norm: str | None = None) -> None:
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
self.norm = norm
self.dtype = dtype

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...
Expand All @@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
if values.ndim == 1:
values = values[:, np.newaxis]

normed_values = self.normalise(values)
norm_values = self.normalise(values)

return torch.tensor(normed_values, dtype=torch.float32)
return torch.tensor(norm_values.astype(self.dtype))

def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
Expand Down Expand Up @@ -155,3 +156,77 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
values = 1 - values

return values


class BooleanBaseEdgeAttribute(BaseEdgeAttribute):
"""Base class for boolean edge attributes."""

def __init__(self) -> None:
super().__init__(norm=None, dtype="bool")


class AttributeFromNode(BooleanBaseEdgeAttribute):
"""
Base class for Attribute from Node.

Copy an attribute of either the source or target node to the edge.
Accesses origin/target node attribute and propagates it to the edge.
Used for example to identify if an encoder edge originates from a LAM or global node.

Attributes
----------
node_attr_name : str
Name of the node attribute to propagate.

Methods
-------
get_node_name(source_name, target_name)
Return the name of the node to copy.

get_raw_values(graph, source_name, target_name)
Computes the edge attribute from the source or target node attribute.

"""

def __init__(self, node_attr_name: str) -> None:
super().__init__()
self.node_attr_name = node_attr_name
self.idx = None

@abstractmethod
def get_node_name(self, source_name: str, target_name: str): ...

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:

node_name = self.get_node_name(source_name, target_name)

edge_index = graph[(source_name, "to", target_name)].edge_index
assert hasattr(graph[node_name], self.node_attr_name)
val = getattr(graph[node_name], self.node_attr_name).numpy()[edge_index[self.idx]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that direct attribute selection is preferred when possible.

Suggested change
val = getattr(graph[node_name], self.node_attr_name).numpy()[edge_index[self.idx]]
return graph[node_name][self.node_attr_name].numpy()[edge_index[self.idx]]

In this case, if we want to test whether the attribute is in the graph, it may be better to use try/catch, with an appropriate error message like.
"{super().__class__.__name__} failed because the attribute {} is not defined for the {} nodes.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done ✅

return val


class AttributeFromSourceNode(AttributeFromNode):
"""
Copy an attribute of the source node to the edge.
"""

def __init__(self, node_attr_name: str) -> None:
super().__init__(node_attr_name)
self.idx = 0

def get_node_name(self, source_name: str, target_name: str):
return source_name


class AttributeFromTargetNode(AttributeFromNode):
"""
Copy an attribute of the target node to the edge.
"""

def __init__(self, node_attr_name: str) -> None:
super().__init__(node_attr_name)
self.idx = 1

def get_node_name(self, source_name: str, target_name: str):
return target_name
Loading