This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
[Feature] New Edge Attribute: AttributeFromNode #95
Open
icedoom888
wants to merge
16
commits into
ecmwf:develop
Choose a base branch
from
MeteoSwiss:feature/edge_attr_from_node_attr
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
7353f83
Implemented new attribute
icedoom888 fa5ce52
changelog update
icedoom888 1bc0113
changelog update
icedoom888 919d0ad
Refactored following review
icedoom888 40394e4
Update src/anemoi/graphs/edges/attributes.py
icedoom888 77b2335
Update src/anemoi/graphs/edges/attributes.py
icedoom888 a58dae6
Update src/anemoi/graphs/edges/attributes.py
icedoom888 ab2f6a3
Update src/anemoi/graphs/edges/attributes.py
icedoom888 11387b7
Refactor after review
icedoom888 8602b15
Docstring done
icedoom888 18e9a75
Docstring done
icedoom888 89af39a
Docstring done
icedoom888 d3611d3
Fixed ABC issue
icedoom888 5cc3ca8
addressed changes in docs and exception error
icedoom888 7af61c2
Changed changelog
icedoom888 0a48b4f
Merge branch 'develop' into feature/edge_attr_from_node_attr
icedoom888 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -24,8 +24,9 @@ | |||||
class BaseEdgeAttribute(ABC, NormaliserMixin): | ||||||
"""Base class for edge attributes.""" | ||||||
|
||||||
def __init__(self, norm: str | None = None) -> None: | ||||||
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: | ||||||
self.norm = norm | ||||||
self.dtype = dtype | ||||||
|
||||||
@abstractmethod | ||||||
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... | ||||||
|
@@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: | |||||
if values.ndim == 1: | ||||||
values = values[:, np.newaxis] | ||||||
|
||||||
normed_values = self.normalise(values) | ||||||
norm_values = self.normalise(values) | ||||||
|
||||||
return torch.tensor(normed_values, dtype=torch.float32) | ||||||
return torch.tensor(norm_values.astype(self.dtype)) | ||||||
|
||||||
def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: | ||||||
"""Compute the edge attributes.""" | ||||||
|
@@ -155,3 +156,77 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: | |||||
values = 1 - values | ||||||
|
||||||
return values | ||||||
|
||||||
|
||||||
class BooleanBaseEdgeAttribute(BaseEdgeAttribute): | ||||||
icedoom888 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Base class for boolean edge attributes.""" | ||||||
|
||||||
def __init__(self) -> None: | ||||||
super().__init__(norm=None, dtype="bool") | ||||||
|
||||||
|
||||||
class AttributeFromNode(BooleanBaseEdgeAttribute): | ||||||
icedoom888 marked this conversation as resolved.
Show resolved
Hide resolved
icedoom888 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Base class for Attribute from Node. | ||||||
|
||||||
Copy an attribute of either the source or target node to the edge. | ||||||
Accesses origin/target node attribute and propagates it to the edge. | ||||||
icedoom888 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
Used for example to identify if an encoder edge originates from a LAM or global node. | ||||||
|
||||||
Attributes | ||||||
---------- | ||||||
node_attr_name : str | ||||||
Name of the node attribute to propagate. | ||||||
|
||||||
Methods | ||||||
------- | ||||||
get_node_name(source_name, target_name) | ||||||
Return the name of the node to copy. | ||||||
|
||||||
get_raw_values(graph, source_name, target_name) | ||||||
Computes the edge attribute from the source or target node attribute. | ||||||
|
||||||
""" | ||||||
|
||||||
def __init__(self, node_attr_name: str) -> None: | ||||||
super().__init__() | ||||||
self.node_attr_name = node_attr_name | ||||||
self.idx = None | ||||||
|
||||||
@abstractmethod | ||||||
def get_node_name(self, source_name: str, target_name: str): ... | ||||||
|
||||||
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: | ||||||
|
||||||
node_name = self.get_node_name(source_name, target_name) | ||||||
|
||||||
edge_index = graph[(source_name, "to", target_name)].edge_index | ||||||
assert hasattr(graph[node_name], self.node_attr_name) | ||||||
val = getattr(graph[node_name], self.node_attr_name).numpy()[edge_index[self.idx]] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would say that direct attribute selection is preferred when possible.
Suggested change
In this case, if we want to test whether the attribute is in the graph, it may be better to use try/catch, with an appropriate error message like. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done ✅ |
||||||
return val | ||||||
|
||||||
|
||||||
class AttributeFromSourceNode(AttributeFromNode): | ||||||
""" | ||||||
Copy an attribute of the source node to the edge. | ||||||
""" | ||||||
|
||||||
def __init__(self, node_attr_name: str) -> None: | ||||||
super().__init__(node_attr_name) | ||||||
self.idx = 0 | ||||||
|
||||||
def get_node_name(self, source_name: str, target_name: str): | ||||||
return source_name | ||||||
|
||||||
|
||||||
class AttributeFromTargetNode(AttributeFromNode): | ||||||
""" | ||||||
Copy an attribute of the target node to the edge. | ||||||
""" | ||||||
|
||||||
def __init__(self, node_attr_name: str) -> None: | ||||||
super().__init__(node_attr_name) | ||||||
self.idx = 1 | ||||||
|
||||||
def get_node_name(self, source_name: str, target_name: str): | ||||||
return target_name |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should go in the Unreleased section.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed ✅