From f00fd72c673af676a8a67c2b2a08710dbf0101a6 Mon Sep 17 00:00:00 2001 From: theissenhelen Date: Mon, 1 Jul 2024 14:03:44 +0000 Subject: [PATCH] fix: output type --- src/anemoi/graphs/nodes/builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 6ea6bf0..a477020 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -86,6 +86,7 @@ def __init__( resolution: Union[int, list[int]], np_dtype: np.dtype = np.float32, ) -> None: + # TODO: Discuss np_dtype self.np_dtype = np_dtype if isinstance(resolution, int): @@ -95,9 +96,9 @@ def __init__( super().__init__() - def get_coordinates(self) -> np.ndarray: + def get_coordinates(self) -> torch.Tensor: self.nx_graph, coords_rad, self.node_ordering = self.create_nodes() - return coords_rad[self.node_ordering] + return torch.tensor(coords_rad[self.node_ordering]) @abstractmethod def create_nodes(self) -> np.ndarray: ... @@ -122,4 +123,5 @@ class HexRefinedIcosahedralNodes(RefinedIcosahedralNodes): """It depends on the h3 Python library.""" def create_nodes(self) -> np.ndarray: + # TODO: AOI mask builder is not used in the current implementation. return create_hexagonal_nodes(self.resolutions)