Skip to content

Commit

Permalink
schema 300 - 2d 1d2d (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
margrietpalm authored Sep 10, 2024
1 parent ee78533 commit dc79af3
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Changelog of threedigrid-builder
1.18.1 (unreleased)
-------------------

- Nothing changed yet.
- Adapt for changes in schema upgrade for 2d and 1d2d


1.18.0 (2024-09-09)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_version():

install_requires = [
"numpy>=1.15,<3.0",
"threedi-schema==0.225.*",
"threedi-schema==0.226.*",
"shapely>=2",
"pyproj>=3",
"condenser[geo]>=0.1.1",
Expand Down
100 changes: 65 additions & 35 deletions threedigrid_builder/interface/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
from typing import Callable, ContextManager
from typing import Callable, ContextManager, Dict, Optional

import numpy as np
import shapely
Expand Down Expand Up @@ -43,7 +43,7 @@
# hardcoded source projection
SOURCE_EPSG = 4326

MIN_SQLITE_VERSION = 225
MIN_SQLITE_VERSION = 226

DAY_IN_SECONDS = 24.0 * 3600.0

Expand Down Expand Up @@ -94,6 +94,20 @@ def _set_initialization_type(
dct[type_field] = None


def arr_to_attr_dict(
arr: np.ndarray, rename_dict: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""
Convert structured array to dict with optional rename of the keys
"""
if rename_dict is None:
rename_dict = {}
return {
rename_dict[name] if name in rename_dict else name: arr[name]
for name in arr.dtype.names
}


class SQLite:
def __init__(self, path: pathlib.Path, upgrade=False, convert_to_geopackage=False):
if not path.exists():
Expand Down Expand Up @@ -548,79 +562,82 @@ def get_exchange_lines(self) -> ExchangeLines:
session.query(
models.ExchangeLine.id,
models.ExchangeLine.channel_id,
models.ExchangeLine.the_geom,
models.ExchangeLine.geom,
models.ExchangeLine.exchange_level,
)
.order_by(models.ExchangeLine.id)
.as_structarray()
)

arr["the_geom"] = self.reproject(arr["the_geom"])

arr["geom"] = self.reproject(arr["geom"])
attr_dict = arr_to_attr_dict(arr, {"geom": "the_geom"})
# transform to a Channels object
return ExchangeLines(**{name: arr[name] for name in arr.dtype.names})
return ExchangeLines(**attr_dict)

def get_grid_refinements(self) -> GridRefinements:
"""Return Gridrefinement and GridRefinementArea concatenated into one array."""
with self.get_session() as session:
arr1 = (
session.query(
models.GridRefinement.the_geom,
models.GridRefinement.id,
models.GridRefinement.code,
models.GridRefinement.display_name,
models.GridRefinement.refinement_level,
models.GridRefinementLine.geom,
models.GridRefinementLine.id,
models.GridRefinementLine.code,
models.GridRefinementLine.display_name,
models.GridRefinementLine.grid_level,
)
.filter(
models.GridRefinement.the_geom.isnot(None),
models.GridRefinement.refinement_level.isnot(None),
models.GridRefinementLine.geom.isnot(None),
models.GridRefinementLine.grid_level.isnot(None),
)
.order_by(models.GridRefinement.id)
.order_by(models.GridRefinementLine.id)
.as_structarray()
)
arr2 = (
session.query(
models.GridRefinementArea.the_geom,
models.GridRefinementArea.geom,
models.GridRefinementArea.id,
models.GridRefinementArea.code,
models.GridRefinementArea.display_name,
models.GridRefinementArea.refinement_level,
models.GridRefinementArea.grid_level,
)
.filter(
models.GridRefinementArea.the_geom.isnot(None),
models.GridRefinementArea.refinement_level.isnot(None),
models.GridRefinementArea.geom.isnot(None),
models.GridRefinementArea.grid_level.isnot(None),
)
.order_by(models.GridRefinementArea.id)
.as_structarray()
)
arr = np.concatenate((arr1, arr2))

# reproject
arr["the_geom"] = self.reproject(arr["the_geom"])
arr["id"] = np.arange(len(arr["refinement_level"]))
arr["geom"] = self.reproject(arr["geom"])
arr["id"] = np.arange(len(arr["grid_level"]))

return GridRefinements(**{name: arr[name] for name in arr.dtype.names})
attr_dict = arr_to_attr_dict(
arr, {"geom": "the_geom", "grid_level": "refinement_level"}
)
return GridRefinements(**attr_dict)

def get_dem_average_areas(self) -> DemAverageAreas:
"""Return DemAverageAreas"""
with self.get_session() as session:
arr = (
session.query(
models.DemAverageArea.id,
models.DemAverageArea.the_geom,
models.DemAverageArea.geom,
)
.order_by(models.DemAverageArea.id)
.as_structarray()
)
arr["the_geom"] = self.reproject(arr["the_geom"])

return DemAverageAreas(**{name: arr[name] for name in arr.dtype.names})
arr["geom"] = self.reproject(arr["geom"])
attr_dict = arr_to_attr_dict(arr, {"geom": "the_geom"})
return DemAverageAreas(**attr_dict)

def get_obstacles(self) -> Obstacles:
with self.get_session() as session:
arr = (
session.query(
models.Obstacle.the_geom,
models.Obstacle.geom,
models.Obstacle.id,
models.Obstacle.crest_level,
)
Expand All @@ -629,9 +646,10 @@ def get_obstacles(self) -> Obstacles:
)

# reproject
arr["the_geom"] = self.reproject(arr["the_geom"])

return Obstacles(**{name: arr[name] for name in arr.dtype.names})
arr["geom"] = self.reproject(arr["geom"])
attr_dict = arr_to_attr_dict(arr, {"geom": "the_geom"})
# transform to a Channels object
return Obstacles(**attr_dict)

def get_orifices(self) -> Orifices:
"""Return Orifices"""
Expand Down Expand Up @@ -781,14 +799,14 @@ def get_potential_breaches(self) -> PotentialBreaches:
models.PotentialBreach.id,
models.PotentialBreach.code,
models.PotentialBreach.display_name,
models.PotentialBreach.the_geom,
models.PotentialBreach.geom,
models.PotentialBreach.channel_id,
]

if self.get_version() >= 212:
cols += [
models.PotentialBreach.exchange_level,
models.PotentialBreach.maximum_breach_depth,
models.PotentialBreach.initial_exchange_level,
models.PotentialBreach.final_exchange_level,
models.PotentialBreach.levee_material,
]

Expand All @@ -800,9 +818,21 @@ def get_potential_breaches(self) -> PotentialBreaches:
)

# reproject
arr["the_geom"] = self.reproject(arr["the_geom"])

return PotentialBreaches(**{name: arr[name] for name in arr.dtype.names})
arr["geom"] = self.reproject(arr["geom"])
# derive maximum_breach_depth from initial and final exchange level
# and overwrite final_exchange_level because adding a field is more work
arr["final_exchange_level"] = (
arr["initial_exchange_level"] - arr["final_exchange_level"]
)
attr_dict = arr_to_attr_dict(
arr,
{
"geom": "the_geom",
"initial_exchange_level": "exchange_level",
"final_exchange_level": "maximum_breach_depth",
},
)
return PotentialBreaches(**attr_dict)


# Constructing a Transformer takes quite long, so we use caching here. The
Expand Down
4 changes: 2 additions & 2 deletions threedigrid_builder/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_init(tmp_path):
with mock.patch(
"threedigrid_builder.interface.db.ThreediDatabase"
) as db, mock.patch.object(SQLite, "get_version") as get_version:
get_version.return_value = 225
get_version.return_value = 226
sqlite = SQLite(path)

db.assert_called_with(path)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_init_bad_version(tmp_path):


def test_get_version(db):
assert db.get_version() == 225
assert db.get_version() == 226


def test_get_boundary_conditions_1d(db):
Expand Down

0 comments on commit dc79af3

Please sign in to comment.