Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
fix: output type
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed Jul 1, 2024
1 parent 39ee3ad commit f00fd72
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
resolution: Union[int, list[int]],
np_dtype: np.dtype = np.float32,
) -> None:
# TODO: Discuss np_dtype
self.np_dtype = np_dtype

if isinstance(resolution, int):
Expand All @@ -95,9 +96,9 @@ def __init__(

super().__init__()

def get_coordinates(self) -> np.ndarray:
def get_coordinates(self) -> torch.Tensor:
self.nx_graph, coords_rad, self.node_ordering = self.create_nodes()
return coords_rad[self.node_ordering]
return torch.tensor(coords_rad[self.node_ordering])

@abstractmethod
def create_nodes(self) -> np.ndarray: ...
Expand All @@ -122,4 +123,5 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes):
"""It depends on the h3 Python library."""

def create_nodes(self) -> np.ndarray:
# TODO: AOI mask builder is not used in the current implementation.
return create_hexagonal_nodes(self.resolutions)

0 comments on commit f00fd72

Please sign in to comment.