From ba1bc99ac6e2b0e1993ce743406eb4157bf7dbc8 Mon Sep 17 00:00:00 2001 From: Margriet Palm Date: Mon, 18 Nov 2024 15:19:28 +0100 Subject: [PATCH] Add manhole_id to gridadmin --- threedigrid_builder/interface/db.py | 2 +- threedigrid_builder/interface/gridadmin.py | 4 ++++ threedigrid_builder/tests/test_gridadmin.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/threedigrid_builder/interface/db.py b/threedigrid_builder/interface/db.py index 9d15807b..d68c3c3f 100644 --- a/threedigrid_builder/interface/db.py +++ b/threedigrid_builder/interface/db.py @@ -462,7 +462,7 @@ def get_channels(self) -> Channels: return Channels(**attr_dict) def get_connection_nodes(self) -> ConnectionNodes: - """Return ConnectionNodes (which are enriched using the manhole table)""" + """Return ConnectionNodes""" cols = [ models.ConnectionNode.geom, models.ConnectionNode.id, diff --git a/threedigrid_builder/interface/gridadmin.py b/threedigrid_builder/interface/gridadmin.py index 5bc8f208..139b955f 100644 --- a/threedigrid_builder/interface/gridadmin.py +++ b/threedigrid_builder/interface/gridadmin.py @@ -357,6 +357,10 @@ def write_nodes(self, nodes, group_name="nodes"): group, "display_name", to_bytes_array(nodes.display_name, 64) ) self.write_dataset(group, "zoom_category", nodes.zoom_category) + # Set manhole_id to match nodes.id when there is a manhole + self.write_dataset( + group, "manhole_id", np.where(nodes.is_manhole, nodes.id + 1, -9999) + ) self.write_dataset(group, "manhole_indicator", nodes.manhole_indicator) self.write_dataset(group, "shape", to_bytes_array(nodes.shape, 4)) self.write_dataset(group, "drain_level", nodes.drain_level) diff --git a/threedigrid_builder/tests/test_gridadmin.py b/threedigrid_builder/tests/test_gridadmin.py index ff8bbc45..bb1cf1cc 100644 --- a/threedigrid_builder/tests/test_gridadmin.py +++ b/threedigrid_builder/tests/test_gridadmin.py @@ -50,6 +50,7 @@ def h5_out_1d(tmpdir_factory, grid_all): ("id", (4,), "int32"), ("initial_waterlevel", (4,), "float64"), ("is_manhole", (4,), "int32"), + ("manhole_id", (4,), "int32"), ("manhole_indicator", (4,), "int32"), ("node_type", (4,), "int32"), ("pixel_coords", (4, 4), "int32"), @@ -89,6 +90,7 @@ def test_write_nodes(h5_out, dataset, shape, dtype): ("id", (3,), "int32"), ("initial_waterlevel", (3,), "float64"), ("is_manhole", (3,), "int32"), + ("manhole_id", (3,), "int32"), ("manhole_indicator", (3,), "int32"), ("node_type", (3,), "int32"), ("pixel_coords", (4, 3), "int32"),