From e6997097356b0a07ca996e018ca2c0c7b3b3705b Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Fri, 3 Jan 2025 15:50:55 +0100 Subject: [PATCH 01/10] Added reusable netcdf validator --- .../object_model/hazard/forcing/netcdf.py | 21 +++ .../object_model/hazard/interface/forcing.py | 3 +- .../test_events/test_forcing/test_netcdf.py | 123 ++++++++++++++++++ 3 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 flood_adapt/object_model/hazard/forcing/netcdf.py create mode 100644 tests/test_object_model/test_events/test_forcing/test_netcdf.py diff --git a/flood_adapt/object_model/hazard/forcing/netcdf.py b/flood_adapt/object_model/hazard/forcing/netcdf.py new file mode 100644 index 000000000..f01a9b29f --- /dev/null +++ b/flood_adapt/object_model/hazard/forcing/netcdf.py @@ -0,0 +1,21 @@ +import xarray as xr + + +@staticmethod +def validate_netcdf_forcing( + ds: xr.Dataset, required_vars: set[str], required_coords: set[str] +) -> xr.Dataset: + """Validate a forcing dataset by checking for required variables and coordinates.""" + if not required_vars.issubset(ds.data_vars): + missing_vars = required_vars - set(ds.data_vars) + raise ValueError( + f"Missing required variables for netcdf forcing: {missing_vars}" + ) + + if not required_coords.issubset(ds.coords): + missing_coords = required_coords - set(ds.coords) + raise ValueError( + f"Missing required coordinates for netcdf forcing: {missing_coords}" + ) + + return ds diff --git a/flood_adapt/object_model/hazard/interface/forcing.py b/flood_adapt/object_model/hazard/interface/forcing.py index 2ea3bc898..e43c87a13 100644 --- a/flood_adapt/object_model/hazard/interface/forcing.py +++ b/flood_adapt/object_model/hazard/interface/forcing.py @@ -41,7 +41,8 @@ class ForcingSource(str, Enum): MODEL = "MODEL" # 'our' hindcast/ sfincs offshore model TRACK = "TRACK" # 'our' hindcast/ sfincs offshore model + (shifted) hurricane - CSV = "CSV" # user imported data + CSV = "CSV" # user provided csv file + NETCDF = "NETCDF" # user provided netcdf file SYNTHETIC = "SYNTHETIC" # synthetic data CONSTANT = "CONSTANT" # synthetic data diff --git a/tests/test_object_model/test_events/test_forcing/test_netcdf.py b/tests/test_object_model/test_events/test_forcing/test_netcdf.py new file mode 100644 index 000000000..6ac416d06 --- /dev/null +++ b/tests/test_object_model/test_events/test_forcing/test_netcdf.py @@ -0,0 +1,123 @@ +from datetime import timedelta +from typing import Optional + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from flood_adapt.object_model.hazard.forcing.netcdf import validate_netcdf_forcing +from flood_adapt.object_model.hazard.interface.models import TimeModel + + +def get_test_dataset( + excluded_coord: Optional[str] = None, + data_vars=["wind10_u", "wind10_v", "press_msl", "precip"], +) -> xr.Dataset: + gen = np.random.default_rng(42) + time = TimeModel() + + gen = np.random.default_rng(42) + lat = [90, 110] + lon = [0, 20] + _time = pd.date_range( + start=time.start_time, + end=time.end_time, + freq=timedelta(hours=1), + name="time", + ) + + coords = { + "time": _time, + "lat": lat, + "lon": lon, + } + if excluded_coord: + coords.pop(excluded_coord) + dims = list(coords.keys()) + + def _generate_data(dimensions): + shape = tuple(len(coords[dim]) for dim in dimensions if dim in coords) + return gen.random(shape) + + _data_vars = {name: (dims, _generate_data(dims)) for name in data_vars} + + ds = xr.Dataset( + data_vars=_data_vars, + coords=coords, + attrs={ + "crs": 4326, + }, + ) + ds.raster.set_crs(4326) + + return ds + + +def test_all_datavars_all_coords(): + # Arrange + vars = ["wind10_u", "wind10_v", "press_msl", "precip"] + required_vars = set(vars) + + coords = {"time", "lat", "lon"} + required_coords = coords + + ds = get_test_dataset( + excluded_coord=None, + data_vars=vars, + ) + + # Act + result = validate_netcdf_forcing( + ds, required_vars=required_vars, required_coords=required_coords + ) + + # Assert + assert result.equals(ds) + + +def test_missing_datavar_all_coords_raises_validation_error(): + # Arrange + vars = ["wind10_u", "wind10_v", "press_msl", "precip"] + required_vars = set(vars) + required_vars.add("missing_var") + + coords = {"time", "lat", "lon"} + required_coords = coords + + ds = get_test_dataset( + excluded_coord=None, + data_vars=vars, + ) + + # Act + with pytest.raises(ValueError) as e: + validate_netcdf_forcing( + ds, required_vars=required_vars, required_coords=required_coords + ) + + # Assert + assert "missing_var" in str(e.value) + assert "Missing required variables for netcdf forcing:" in str(e.value) + + +@pytest.mark.parametrize("excluded_coord", ["time", "lat", "lon"]) +def test_all_datavar_missing_coords_raises_validation_error(excluded_coord): + vars = ["wind10_u", "wind10_v", "press_msl", "precip"] + required_vars = set(vars) + + coords = {"time", "lat", "lon"} + required_coords = coords.copy() + coords.remove(excluded_coord) + + ds = get_test_dataset(excluded_coord=excluded_coord, data_vars=vars) + + # Act + with pytest.raises(ValueError) as e: + validate_netcdf_forcing( + ds, required_vars=required_vars, required_coords=required_coords + ) + + # Assert + assert "Missing required coordinates for netcdf forcing:" in str(e.value) + assert excluded_coord in str(e.value) From 86ecce4fc5fff4fa448443a4cf2b3eb62bb41a99 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Fri, 3 Jan 2025 16:48:33 +0100 Subject: [PATCH 02/10] implemented NETCDF forcings + tests TODO fix tests for adding them to sfincsadapter --- flood_adapt/adapter/sfincs_adapter.py | 43 ++++++++++++++----- .../object_model/hazard/forcing/rainfall.py | 28 ++++++++++++ .../object_model/hazard/forcing/wind.py | 28 ++++++++++++ .../object_model/hazard/interface/models.py | 2 +- tests/test_adapter/test_sfincs_adapter.py | 31 ++++++++++++- .../test_events/test_forcing/test_netcdf.py | 16 +++---- 6 files changed, 127 insertions(+), 21 deletions(-) diff --git a/flood_adapt/adapter/sfincs_adapter.py b/flood_adapt/adapter/sfincs_adapter.py index 54a09356a..4fc564b66 100644 --- a/flood_adapt/adapter/sfincs_adapter.py +++ b/flood_adapt/adapter/sfincs_adapter.py @@ -36,6 +36,7 @@ RainfallConstant, RainfallCSV, RainfallMeteo, + RainfallNetCDF, RainfallSynthetic, RainfallTrack, ) @@ -52,6 +53,7 @@ from flood_adapt.object_model.hazard.forcing.wind import ( WindConstant, WindMeteo, + WindNetCDF, WindSynthetic, WindTrack, ) @@ -339,6 +341,10 @@ def add_projection(self, projection: IProjection): ) ### GETTERS ### + def get_model_time(self) -> TimeModel: + t0, t1 = self._model.get_model_time() + return TimeModel(start_time=t0, end_time=t1) + def get_model_root(self) -> Path: return Path(self._model.root) @@ -905,8 +911,7 @@ def _add_forcing_wind( const_dir : float, optional direction of time-invariant wind forcing [deg], by default None """ - t0, t1 = self._model.get_model_time() - + time_frame = self.get_model_time() if isinstance(wind, WindConstant): # HydroMT function: set wind forcing from constant magnitude and direction self._model.setup_wind_forcing( @@ -915,7 +920,7 @@ def _add_forcing_wind( direction=wind.direction.value, ) elif isinstance(wind, WindSynthetic): - df = wind.to_dataframe(time_frame=TimeModel(start_time=t0, end_time=t1)) + df = wind.to_dataframe(time_frame=time_frame) df["mag"] *= us.UnitfulVelocity( value=1.0, units=Settings().unit_system.velocity ).convert(us.UnitTypesVelocity.mps) @@ -928,7 +933,7 @@ def _add_forcing_wind( timeseries=tmp_path, magnitude=None, direction=None ) elif isinstance(wind, WindMeteo): - ds = MeteoHandler().read(TimeModel(start_time=t0, end_time=t1)) + ds = MeteoHandler().read(time_frame) # data already in metric units so no conversion needed # HydroMT function: set wind forcing from grid @@ -938,6 +943,14 @@ def _add_forcing_wind( raise ValueError("No path to rainfall track file provided.") # data already in metric units so no conversion needed self._add_forcing_spw(wind.path) + elif isinstance(wind, WindNetCDF): + ds = wind.read() + # TODO timeframe + conversion = us.UnitfulVelocity(value=1.0, units=wind.unit).convert( + us.UnitTypesVelocity.mps + ) + ds *= conversion + self._model.setup_precip_forcing_from_grid(precip=ds, aggregate=False) else: self.logger.warning( f"Unsupported wind forcing type: {wind.__class__.__name__}" @@ -954,8 +967,7 @@ def _add_forcing_rain(self, rainfall: IRainfall): const_intensity : float, optional time-invariant precipitation intensity [mm_hr], by default None """ - t0, t1 = self._model.get_model_time() - time_frame = TimeModel(start_time=t0, end_time=t1) + time_frame = self.get_model_time() if isinstance(rainfall, RainfallConstant): self._model.setup_precip_forcing( timeseries=None, @@ -980,13 +992,23 @@ def _add_forcing_rain(self, rainfall: IRainfall): self._model.setup_precip_forcing(timeseries=tmp_path) elif isinstance(rainfall, RainfallMeteo): ds = MeteoHandler().read(time_frame) - # data already in metric units so no conversion needed + # MeteoHandler always return metric so no conversion needed + ds["precip"] *= self._current_scenario.event.attrs.rainfall_multiplier self._model.setup_precip_forcing_from_grid(precip=ds, aggregate=False) elif isinstance(rainfall, RainfallTrack): if rainfall.path is None: raise ValueError("No path to rainfall track file provided.") # data already in metric units so no conversion needed + # TODO rainfall multiplier self._add_forcing_spw(rainfall.path) + elif isinstance(rainfall, RainfallNetCDF): + ds = rainfall.read() + # TODO timeframe + conversion = us.UnitfulIntensity(value=1.0, units=rainfall.unit).convert( + us.UnitTypesIntensity.mm_hr + ) + ds *= self._current_scenario.event.attrs.rainfall_multiplier * conversion + self._model.setup_precip_forcing_from_grid(precip=ds, aggregate=False) else: self.logger.warning( f"Unsupported rainfall forcing type: {rainfall.__class__.__name__}" @@ -1011,8 +1033,7 @@ def _add_forcing_discharge(self, forcing: IDischarge): ) def _add_forcing_waterlevels(self, forcing: IWaterlevel): - t0, t1 = self._model.get_model_time() - time_frame = TimeModel(start_time=t0, end_time=t1) + time_frame = self.get_model_time() if isinstance(forcing, WaterlevelSynthetic): df_ts = forcing.to_dataframe(time_frame=time_frame) conversion = us.UnitfulLength( @@ -1218,8 +1239,8 @@ def _set_single_river_forcing(self, discharge: IDischarge): return self.logger.info(f"Setting discharge forcing for river: {discharge.river.name}") - t0, t1 = self._model.get_model_time() - time_frame = TimeModel(start_time=t0, end_time=t1) + + time_frame = self.get_model_time() model_rivers = self._read_river_locations() # Check that the river is defined in the model and that the coordinates match diff --git a/flood_adapt/object_model/hazard/forcing/rainfall.py b/flood_adapt/object_model/hazard/forcing/rainfall.py index d3e41de8d..51b33b5e4 100644 --- a/flood_adapt/object_model/hazard/forcing/rainfall.py +++ b/flood_adapt/object_model/hazard/forcing/rainfall.py @@ -4,8 +4,10 @@ from typing import Optional import pandas as pd +import xarray as xr from pydantic import Field +from flood_adapt.object_model.hazard.forcing.netcdf import validate_netcdf_forcing from flood_adapt.object_model.hazard.forcing.timeseries import ( CSVTimeseries, SyntheticTimeseries, @@ -111,3 +113,29 @@ def save_additional(self, output_dir: Path | str | os.PathLike) -> None: @classmethod def default(cls) -> "RainfallCSV": return RainfallCSV(path="path/to/rainfall.csv") + + +class RainfallNetCDF(IRainfall): + source: ForcingSource = ForcingSource.NETCDF + unit: us.UnitTypesIntensity = us.UnitTypesIntensity.mm_hr + + path: Path + + def read(self) -> xr.Dataset: + ds = xr.open_dataset(self.path) + required_vars = {"precip"} + required_coords = {"time", "lat", "lon"} + return validate_netcdf_forcing(ds, required_vars, required_coords) + + def save_additional(self, output_dir: Path | str | os.PathLike) -> None: + if self.path: + output_dir = Path(output_dir) + if self.path == output_dir / self.path.name: + return + output_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.path, output_dir) + self.path = output_dir / self.path.name + + @classmethod + def default(cls) -> "RainfallNetCDF": + return RainfallNetCDF(Path("path/to/forcing.nc")) diff --git a/flood_adapt/object_model/hazard/forcing/wind.py b/flood_adapt/object_model/hazard/forcing/wind.py index 8d96de86c..e599b6976 100644 --- a/flood_adapt/object_model/hazard/forcing/wind.py +++ b/flood_adapt/object_model/hazard/forcing/wind.py @@ -4,8 +4,10 @@ from typing import Optional import pandas as pd +import xarray as xr from pydantic import Field +from flood_adapt.object_model.hazard.forcing.netcdf import validate_netcdf_forcing from flood_adapt.object_model.hazard.forcing.timeseries import SyntheticTimeseries from flood_adapt.object_model.hazard.interface.forcing import ( ForcingSource, @@ -138,3 +140,29 @@ class WindMeteo(IWind): @classmethod def default(cls) -> "WindMeteo": return WindMeteo() + + +class WindNetCDF(IWind): + source: ForcingSource = ForcingSource.NETCDF + unit: us.UnitTypesVelocity = us.UnitTypesVelocity.mps + + path: Path + + def read(self) -> xr.Dataset: + ds = xr.open_dataset(self.path) + required_vars = {"wind10_v", "wind10_u", "press_msl"} + required_coords = {"time", "lat", "lon"} + return validate_netcdf_forcing(ds, required_vars, required_coords) + + def save_additional(self, output_dir: Path | str | os.PathLike) -> None: + if self.path: + output_dir = Path(output_dir) + if self.path == output_dir / self.path.name: + return + output_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.path, output_dir) + self.path = output_dir / self.path.name + + @classmethod + def default(cls) -> "WindNetCDF": + return WindNetCDF(Path("path/to/forcing.nc")) diff --git a/flood_adapt/object_model/hazard/interface/models.py b/flood_adapt/object_model/hazard/interface/models.py index 66797ca79..e848faccf 100644 --- a/flood_adapt/object_model/hazard/interface/models.py +++ b/flood_adapt/object_model/hazard/interface/models.py @@ -12,7 +12,7 @@ class TimeModel(BaseModel): start_time: datetime = REFERENCE_TIME end_time: datetime = REFERENCE_TIME + timedelta(days=1) - time_step: timedelta = timedelta(seconds=10) + time_step: timedelta = timedelta(minutes=10) @field_validator("start_time", "end_time", mode="before") @classmethod diff --git a/tests/test_adapter/test_sfincs_adapter.py b/tests/test_adapter/test_sfincs_adapter.py index ed61098e7..7bfd1de8f 100644 --- a/tests/test_adapter/test_sfincs_adapter.py +++ b/tests/test_adapter/test_sfincs_adapter.py @@ -34,6 +34,7 @@ from flood_adapt.object_model.hazard.forcing.wind import ( WindConstant, WindMeteo, + WindNetCDF, WindSynthetic, WindTrack, ) @@ -63,6 +64,9 @@ from flood_adapt.object_model.io import unit_system as us from flood_adapt.object_model.projection import Projection from tests.fixtures import TEST_DATA_DIR +from tests.test_object_model.test_events.test_forcing.test_netcdf import ( + get_test_dataset, +) @pytest.fixture() @@ -73,7 +77,11 @@ def default_sfincs_adapter(test_db) -> SfincsAdapter: duration = timedelta(hours=3) adapter.set_timing( - TimeModel(start_time=start_time, end_time=start_time + duration) + TimeModel( + start_time=start_time, + end_time=start_time + duration, + time_step=timedelta(hours=1), + ) ) adapter.logger = mock.Mock() adapter.logger.handlers = [] @@ -91,6 +99,7 @@ def sfincs_adapter_with_dummy_scn(default_sfincs_adapter): dummy_event.attrs.rainfall_multiplier = 2 dummy_scn.event = dummy_event default_sfincs_adapter._current_scenario = dummy_scn + default_sfincs_adapter.ensure_no_existing_forcings() yield default_sfincs_adapter @@ -123,6 +132,7 @@ def sfincs_adapter_2_rivers(test_db: IDatabase) -> tuple[IDatabase, SfincsAdapte adapter._logger = mock.Mock() adapter.logger.handlers = [] adapter.logger.warning = mock.Mock() + adapter.ensure_no_existing_forcings() return adapter, test_db @@ -374,6 +384,25 @@ def test_add_forcing_wind_from_meteo( assert default_sfincs_adapter.wind is not None + def test_add_forcing_wind_from_netcdf( + self, test_db: IDatabase, default_sfincs_adapter: SfincsAdapter + ): + # Arrange + path = Path(tempfile.gettempdir()) / "wind_netcdf.nc" + ds = get_test_dataset( + time=default_sfincs_adapter.get_model_time(), + lat=int(test_db.site.attrs.lat), + lon=int(test_db.site.attrs.lon), + ) + ds.to_netcdf(path) + forcing = WindNetCDF(path=path) + + # Act + default_sfincs_adapter.add_forcing(forcing) + + # Assert + assert default_sfincs_adapter.wind is not None + def test_add_forcing_wind_from_track( self, test_db, tmp_path, default_sfincs_adapter: SfincsAdapter ): diff --git a/tests/test_object_model/test_events/test_forcing/test_netcdf.py b/tests/test_object_model/test_events/test_forcing/test_netcdf.py index 6ac416d06..9c2118492 100644 --- a/tests/test_object_model/test_events/test_forcing/test_netcdf.py +++ b/tests/test_object_model/test_events/test_forcing/test_netcdf.py @@ -1,4 +1,3 @@ -from datetime import timedelta from typing import Optional import numpy as np @@ -11,26 +10,27 @@ def get_test_dataset( + lat: int = -80, + lon: int = 32, + time: TimeModel = TimeModel(), excluded_coord: Optional[str] = None, data_vars=["wind10_u", "wind10_v", "press_msl", "precip"], ) -> xr.Dataset: gen = np.random.default_rng(42) - time = TimeModel() - gen = np.random.default_rng(42) - lat = [90, 110] - lon = [0, 20] + _lat = np.arange(lat - 10, lat + 10, 1) + _lon = np.arange(lon - 10, lon + 10, 1) _time = pd.date_range( start=time.start_time, end=time.end_time, - freq=timedelta(hours=1), + freq=time.time_step, name="time", ) coords = { "time": _time, - "lat": lat, - "lon": lon, + "lat": _lat, + "lon": _lon, } if excluded_coord: coords.pop(excluded_coord) From 999519f4a361f082e6d565bc83aa50e75de5b65e Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Mon, 6 Jan 2025 10:32:42 +0100 Subject: [PATCH 03/10] Implement add_forcing methods for NetCDF forcing classes --- flood_adapt/adapter/sfincs_adapter.py | 2 +- tests/test_adapter/test_sfincs_adapter.py | 37 +++++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/flood_adapt/adapter/sfincs_adapter.py b/flood_adapt/adapter/sfincs_adapter.py index 4fc564b66..a9c5b7120 100644 --- a/flood_adapt/adapter/sfincs_adapter.py +++ b/flood_adapt/adapter/sfincs_adapter.py @@ -950,7 +950,7 @@ def _add_forcing_wind( us.UnitTypesVelocity.mps ) ds *= conversion - self._model.setup_precip_forcing_from_grid(precip=ds, aggregate=False) + self._model.setup_wind_forcing_from_grid(wind=ds) else: self.logger.warning( f"Unsupported wind forcing type: {wind.__class__.__name__}" diff --git a/tests/test_adapter/test_sfincs_adapter.py b/tests/test_adapter/test_sfincs_adapter.py index 7bfd1de8f..b7f2a7c0a 100644 --- a/tests/test_adapter/test_sfincs_adapter.py +++ b/tests/test_adapter/test_sfincs_adapter.py @@ -21,6 +21,7 @@ from flood_adapt.object_model.hazard.forcing.rainfall import ( RainfallConstant, RainfallMeteo, + RainfallNetCDF, RainfallSynthetic, ) from flood_adapt.object_model.hazard.forcing.waterlevels import ( @@ -94,12 +95,12 @@ def default_sfincs_adapter(test_db) -> SfincsAdapter: @pytest.fixture() def sfincs_adapter_with_dummy_scn(default_sfincs_adapter): + # Mock scenario to get a rainfall multiplier dummy_scn = mock.Mock() dummy_event = mock.Mock() dummy_event.attrs.rainfall_multiplier = 2 dummy_scn.event = dummy_event default_sfincs_adapter._current_scenario = dummy_scn - default_sfincs_adapter.ensure_no_existing_forcings() yield default_sfincs_adapter @@ -389,8 +390,14 @@ def test_add_forcing_wind_from_netcdf( ): # Arrange path = Path(tempfile.gettempdir()) / "wind_netcdf.nc" + + # TODO remove 2 lines below + # investigate why hydromt-sfincs raises if the timestep is < 1 hour + time = TimeModel(time_step=timedelta(hours=1)) + default_sfincs_adapter.set_timing(time) + ds = get_test_dataset( - time=default_sfincs_adapter.get_model_time(), + time=time, lat=int(test_db.site.attrs.lat), lon=int(test_db.site.attrs.lon), ) @@ -497,6 +504,32 @@ def test_add_forcing_from_meteo( # Assert assert adapter.rainfall is not None + def test_add_forcing_rainfall_from_netcdf( + self, test_db: IDatabase, sfincs_adapter_with_dummy_scn: SfincsAdapter + ): + # Arrange + adapter = sfincs_adapter_with_dummy_scn + path = Path(tempfile.gettempdir()) / "wind_netcdf.nc" + + # TODO remove 2 lines below + # investigate why hydromt-sfincs raises if the timestep is < 1 hour + time = TimeModel(time_step=timedelta(hours=1)) + adapter.set_timing(time) + + ds = get_test_dataset( + time=time, + lat=int(test_db.site.attrs.lat), + lon=int(test_db.site.attrs.lon), + ) + ds.to_netcdf(path) + forcing = RainfallNetCDF(path=path) + + # Act + adapter.add_forcing(forcing) + + # Assert + assert adapter.rainfall is not None + def test_add_forcing_unsupported( self, sfincs_adapter_with_dummy_scn: SfincsAdapter ): From 8c4045c52985bdd9116295efe3813fb93f8e5f04 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Mon, 6 Jan 2025 11:36:24 +0100 Subject: [PATCH 04/10] add check for timestep < 1H to netcdf validator --- .../object_model/hazard/forcing/netcdf.py | 9 ++++++- tests/test_adapter/test_sfincs_adapter.py | 4 --- .../test_events/test_forcing/test_netcdf.py | 27 ++++++++++++++++++- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/flood_adapt/object_model/hazard/forcing/netcdf.py b/flood_adapt/object_model/hazard/forcing/netcdf.py index f01a9b29f..2bf7c3330 100644 --- a/flood_adapt/object_model/hazard/forcing/netcdf.py +++ b/flood_adapt/object_model/hazard/forcing/netcdf.py @@ -1,3 +1,5 @@ +import numpy as np +import pandas as pd import xarray as xr @@ -11,11 +13,16 @@ def validate_netcdf_forcing( raise ValueError( f"Missing required variables for netcdf forcing: {missing_vars}" ) - if not required_coords.issubset(ds.coords): missing_coords = required_coords - set(ds.coords) raise ValueError( f"Missing required coordinates for netcdf forcing: {missing_coords}" ) + ts = pd.to_timedelta(np.diff(ds.time).mean()) + if ts < pd.to_timedelta("1H"): + raise ValueError( + f"SFINCS NetCDF forcing time step cannot be less than 1 hour: {ts}" + ) + return ds diff --git a/tests/test_adapter/test_sfincs_adapter.py b/tests/test_adapter/test_sfincs_adapter.py index b7f2a7c0a..93dbbc38e 100644 --- a/tests/test_adapter/test_sfincs_adapter.py +++ b/tests/test_adapter/test_sfincs_adapter.py @@ -391,8 +391,6 @@ def test_add_forcing_wind_from_netcdf( # Arrange path = Path(tempfile.gettempdir()) / "wind_netcdf.nc" - # TODO remove 2 lines below - # investigate why hydromt-sfincs raises if the timestep is < 1 hour time = TimeModel(time_step=timedelta(hours=1)) default_sfincs_adapter.set_timing(time) @@ -511,8 +509,6 @@ def test_add_forcing_rainfall_from_netcdf( adapter = sfincs_adapter_with_dummy_scn path = Path(tempfile.gettempdir()) / "wind_netcdf.nc" - # TODO remove 2 lines below - # investigate why hydromt-sfincs raises if the timestep is < 1 hour time = TimeModel(time_step=timedelta(hours=1)) adapter.set_timing(time) diff --git a/tests/test_object_model/test_events/test_forcing/test_netcdf.py b/tests/test_object_model/test_events/test_forcing/test_netcdf.py index 9c2118492..4b1d24369 100644 --- a/tests/test_object_model/test_events/test_forcing/test_netcdf.py +++ b/tests/test_object_model/test_events/test_forcing/test_netcdf.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Optional import numpy as np @@ -12,7 +13,7 @@ def get_test_dataset( lat: int = -80, lon: int = 32, - time: TimeModel = TimeModel(), + time: TimeModel = TimeModel(time_step=timedelta(hours=1)), excluded_coord: Optional[str] = None, data_vars=["wind10_u", "wind10_v", "press_msl", "precip"], ) -> xr.Dataset: @@ -121,3 +122,27 @@ def test_all_datavar_missing_coords_raises_validation_error(excluded_coord): # Assert assert "Missing required coordinates for netcdf forcing:" in str(e.value) assert excluded_coord in str(e.value) + + +def test_netcdf_timestep_less_than_1_hour_raises(): + # Arrange + vars = ["wind10_u", "wind10_v", "press_msl", "precip"] + required_vars = set(vars) + + coords = {"time", "lat", "lon"} + required_coords = coords + + ds = get_test_dataset( + time=TimeModel(time_step=timedelta(minutes=30)), + excluded_coord=None, + data_vars=vars, + ) + + # Act + with pytest.raises(ValueError) as e: + validate_netcdf_forcing( + ds, required_vars=required_vars, required_coords=required_coords + ) + + # Assert + assert "SFINCS NetCDF forcing time step cannot be less than 1 hour" in str(e.value) From f2518d932673fbab5e2002ea9811774e328cdfcd Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Mon, 6 Jan 2025 15:18:42 +0100 Subject: [PATCH 05/10] implement review comments --- flood_adapt/adapter/sfincs_adapter.py | 4 +- .../object_model/hazard/forcing/discharge.py | 2 +- .../object_model/hazard/forcing/netcdf.py | 22 +++- .../object_model/hazard/forcing/rainfall.py | 6 +- .../hazard/forcing/waterlevels.py | 2 +- .../object_model/hazard/forcing/wind.py | 6 +- .../test_events/test_forcing/test_netcdf.py | 108 +++++++++--------- 7 files changed, 78 insertions(+), 72 deletions(-) diff --git a/flood_adapt/adapter/sfincs_adapter.py b/flood_adapt/adapter/sfincs_adapter.py index a9c5b7120..a2abe7c8e 100644 --- a/flood_adapt/adapter/sfincs_adapter.py +++ b/flood_adapt/adapter/sfincs_adapter.py @@ -945,7 +945,7 @@ def _add_forcing_wind( self._add_forcing_spw(wind.path) elif isinstance(wind, WindNetCDF): ds = wind.read() - # TODO timeframe + # time slicing to time_frame not needed, hydromt-sfincs handles it conversion = us.UnitfulVelocity(value=1.0, units=wind.unit).convert( us.UnitTypesVelocity.mps ) @@ -1003,7 +1003,7 @@ def _add_forcing_rain(self, rainfall: IRainfall): self._add_forcing_spw(rainfall.path) elif isinstance(rainfall, RainfallNetCDF): ds = rainfall.read() - # TODO timeframe + # time slicing to time_frame not needed, hydromt-sfincs handles it conversion = us.UnitfulIntensity(value=1.0, units=rainfall.unit).convert( us.UnitTypesIntensity.mm_hr ) diff --git a/flood_adapt/object_model/hazard/forcing/discharge.py b/flood_adapt/object_model/hazard/forcing/discharge.py index 6ece991d6..ec4bbb612 100644 --- a/flood_adapt/object_model/hazard/forcing/discharge.py +++ b/flood_adapt/object_model/hazard/forcing/discharge.py @@ -92,7 +92,7 @@ def to_dataframe(self, time_frame: TimeModel) -> pd.DataFrame: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) diff --git a/flood_adapt/object_model/hazard/forcing/netcdf.py b/flood_adapt/object_model/hazard/forcing/netcdf.py index 2bf7c3330..418aeaf2c 100644 --- a/flood_adapt/object_model/hazard/forcing/netcdf.py +++ b/flood_adapt/object_model/hazard/forcing/netcdf.py @@ -5,24 +5,36 @@ @staticmethod def validate_netcdf_forcing( - ds: xr.Dataset, required_vars: set[str], required_coords: set[str] + ds: xr.Dataset, required_vars: tuple[str, ...], required_coords: tuple[str, ...] ) -> xr.Dataset: """Validate a forcing dataset by checking for required variables and coordinates.""" - if not required_vars.issubset(ds.data_vars): - missing_vars = required_vars - set(ds.data_vars) + # Check variables + _required_vars = set(required_vars) + if not _required_vars.issubset(ds.data_vars): + missing_vars = _required_vars - set(ds.data_vars) raise ValueError( f"Missing required variables for netcdf forcing: {missing_vars}" ) - if not required_coords.issubset(ds.coords): - missing_coords = required_coords - set(ds.coords) + + # Check coordinates + _required_coords = set(required_coords) + if not _required_coords.issubset(ds.coords): + missing_coords = _required_coords - set(ds.coords) raise ValueError( f"Missing required coordinates for netcdf forcing: {missing_coords}" ) + # Check time step ts = pd.to_timedelta(np.diff(ds.time).mean()) if ts < pd.to_timedelta("1H"): raise ValueError( f"SFINCS NetCDF forcing time step cannot be less than 1 hour: {ts}" ) + for var in ds.data_vars: + # Check order of dimensions + if ds[var].dims != required_coords: + raise ValueError( + f"Order of dimensions for variable {var} must be {required_coords}" + ) return ds diff --git a/flood_adapt/object_model/hazard/forcing/rainfall.py b/flood_adapt/object_model/hazard/forcing/rainfall.py index 51b33b5e4..ac315d5a0 100644 --- a/flood_adapt/object_model/hazard/forcing/rainfall.py +++ b/flood_adapt/object_model/hazard/forcing/rainfall.py @@ -78,7 +78,7 @@ class RainfallTrack(IRainfall): def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) @@ -103,7 +103,7 @@ def to_dataframe(self, time_frame: TimeModel) -> pd.DataFrame: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) @@ -129,7 +129,7 @@ def read(self) -> xr.Dataset: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) diff --git a/flood_adapt/object_model/hazard/forcing/waterlevels.py b/flood_adapt/object_model/hazard/forcing/waterlevels.py index ffde71ce0..c72347c6c 100644 --- a/flood_adapt/object_model/hazard/forcing/waterlevels.py +++ b/flood_adapt/object_model/hazard/forcing/waterlevels.py @@ -118,7 +118,7 @@ def to_dataframe(self, time_frame: TimeModel) -> pd.DataFrame: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) diff --git a/flood_adapt/object_model/hazard/forcing/wind.py b/flood_adapt/object_model/hazard/forcing/wind.py index e599b6976..56238d36e 100644 --- a/flood_adapt/object_model/hazard/forcing/wind.py +++ b/flood_adapt/object_model/hazard/forcing/wind.py @@ -99,7 +99,7 @@ class WindTrack(IWind): def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) @@ -122,7 +122,7 @@ def to_dataframe(self, time_frame: TimeModel) -> pd.DataFrame: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) @@ -156,7 +156,7 @@ def read(self) -> xr.Dataset: def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: - output_dir = Path(output_dir) + output_dir = Path(output_dir).resolve() if self.path == output_dir / self.path.name: return output_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_object_model/test_events/test_forcing/test_netcdf.py b/tests/test_object_model/test_events/test_forcing/test_netcdf.py index 4b1d24369..f8ed5ecd9 100644 --- a/tests/test_object_model/test_events/test_forcing/test_netcdf.py +++ b/tests/test_object_model/test_events/test_forcing/test_netcdf.py @@ -1,5 +1,5 @@ +from copy import copy from datetime import timedelta -from typing import Optional import numpy as np import pandas as pd @@ -10,12 +10,22 @@ from flood_adapt.object_model.hazard.interface.models import TimeModel +@pytest.fixture +def required_vars(): + return ("wind10_u", "wind10_v", "press_msl", "precip") + + +@pytest.fixture +def required_coords(): + return ("time", "lat", "lon") + + def get_test_dataset( lat: int = -80, lon: int = 32, time: TimeModel = TimeModel(time_step=timedelta(hours=1)), - excluded_coord: Optional[str] = None, - data_vars=["wind10_u", "wind10_v", "press_msl", "precip"], + coords: tuple[str, ...] = ("time", "lat", "lon"), + data_vars: tuple[str, ...] = ("wind10_u", "wind10_v", "press_msl", "precip"), ) -> xr.Dataset: gen = np.random.default_rng(42) @@ -28,24 +38,22 @@ def get_test_dataset( name="time", ) - coords = { + _coords = { "time": _time, "lat": _lat, "lon": _lon, } - if excluded_coord: - coords.pop(excluded_coord) - dims = list(coords.keys()) + coords_dict = {name: _coords.get(name, np.arange(10)) for name in coords} def _generate_data(dimensions): - shape = tuple(len(coords[dim]) for dim in dimensions if dim in coords) + shape = tuple(len(coords_dict[dim]) for dim in dimensions if dim in coords_dict) return gen.random(shape) - _data_vars = {name: (dims, _generate_data(dims)) for name in data_vars} + _data_vars = {name: (coords, _generate_data(coords)) for name in data_vars} ds = xr.Dataset( data_vars=_data_vars, - coords=coords, + coords=coords_dict, attrs={ "crs": 4326, }, @@ -55,18 +63,9 @@ def _generate_data(dimensions): return ds -def test_all_datavars_all_coords(): +def test_all_datavars_all_coords(required_vars, required_coords): # Arrange - vars = ["wind10_u", "wind10_v", "press_msl", "precip"] - required_vars = set(vars) - - coords = {"time", "lat", "lon"} - required_coords = coords - - ds = get_test_dataset( - excluded_coord=None, - data_vars=vars, - ) + ds = get_test_dataset() # Act result = validate_netcdf_forcing( @@ -77,41 +76,43 @@ def test_all_datavars_all_coords(): assert result.equals(ds) -def test_missing_datavar_all_coords_raises_validation_error(): +def test_missing_datavar_all_coords_raises_validation_error( + required_coords, required_vars +): # Arrange - vars = ["wind10_u", "wind10_v", "press_msl", "precip"] - required_vars = set(vars) - required_vars.add("missing_var") - - coords = {"time", "lat", "lon"} - required_coords = coords - - ds = get_test_dataset( - excluded_coord=None, - data_vars=vars, - ) + vars = tuple(copy(required_vars) + ("missing_var",)) + ds = get_test_dataset(data_vars=required_vars) # Act with pytest.raises(ValueError) as e: - validate_netcdf_forcing( - ds, required_vars=required_vars, required_coords=required_coords - ) + validate_netcdf_forcing(ds, required_vars=vars, required_coords=required_coords) # Assert assert "missing_var" in str(e.value) assert "Missing required variables for netcdf forcing:" in str(e.value) -@pytest.mark.parametrize("excluded_coord", ["time", "lat", "lon"]) -def test_all_datavar_missing_coords_raises_validation_error(excluded_coord): - vars = ["wind10_u", "wind10_v", "press_msl", "precip"] - required_vars = set(vars) +def test_all_datavar_missing_coords_raises_validation_error( + required_vars, required_coords +): + # Arrange + coords = tuple(copy(required_coords) + ("missing_coord",)) + ds = get_test_dataset(coords=required_coords, data_vars=required_vars) + + # Act + with pytest.raises(ValueError) as e: + validate_netcdf_forcing(ds, required_vars=required_vars, required_coords=coords) + + # Assert + assert "Missing required coordinates for netcdf forcing:" in str(e.value) + assert "missing_coord" in str(e.value) - coords = {"time", "lat", "lon"} - required_coords = coords.copy() - coords.remove(excluded_coord) - ds = get_test_dataset(excluded_coord=excluded_coord, data_vars=vars) +def test_netcdf_timestep_less_than_1_hour_raises(required_vars, required_coords): + # Arrange + ds = get_test_dataset( + time=TimeModel(time_step=timedelta(minutes=30)), + ) # Act with pytest.raises(ValueError) as e: @@ -120,22 +121,14 @@ def test_all_datavar_missing_coords_raises_validation_error(excluded_coord): ) # Assert - assert "Missing required coordinates for netcdf forcing:" in str(e.value) - assert excluded_coord in str(e.value) + assert "SFINCS NetCDF forcing time step cannot be less than 1 hour" in str(e.value) -def test_netcdf_timestep_less_than_1_hour_raises(): +def test_netcdf_incorrect_coord_order_raises(required_vars, required_coords): # Arrange - vars = ["wind10_u", "wind10_v", "press_msl", "precip"] - required_vars = set(vars) - - coords = {"time", "lat", "lon"} - required_coords = coords - ds = get_test_dataset( - time=TimeModel(time_step=timedelta(minutes=30)), - excluded_coord=None, - data_vars=vars, + coords=required_coords[::-1], # reverse order + data_vars=required_vars, ) # Act @@ -145,4 +138,5 @@ def test_netcdf_timestep_less_than_1_hour_raises(): ) # Assert - assert "SFINCS NetCDF forcing time step cannot be less than 1 hour" in str(e.value) + assert "Order of dimensions for variable" in str(e.value) + assert f"must be {tuple(required_coords)}" in str(e.value) From c9c4e2f72c2c2e54df4e052ffaaaf5ea22dd7e0a Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Mon, 6 Jan 2025 15:34:20 +0100 Subject: [PATCH 06/10] removed TODO and made issue --- flood_adapt/adapter/sfincs_adapter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flood_adapt/adapter/sfincs_adapter.py b/flood_adapt/adapter/sfincs_adapter.py index a2abe7c8e..ff27ea3f9 100644 --- a/flood_adapt/adapter/sfincs_adapter.py +++ b/flood_adapt/adapter/sfincs_adapter.py @@ -999,7 +999,6 @@ def _add_forcing_rain(self, rainfall: IRainfall): if rainfall.path is None: raise ValueError("No path to rainfall track file provided.") # data already in metric units so no conversion needed - # TODO rainfall multiplier self._add_forcing_spw(rainfall.path) elif isinstance(rainfall, RainfallNetCDF): ds = rainfall.read() From b633f23b6da745305155fcf62ab0e52003176653 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Wed, 8 Jan 2025 17:00:47 +0100 Subject: [PATCH 07/10] ensure proper closing of .nc files when they are opened --- flood_adapt/dbs_classes/database.py | 11 +++++------ flood_adapt/object_model/hazard/forcing/rainfall.py | 9 +++++---- flood_adapt/object_model/hazard/forcing/wind.py | 9 +++++---- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/flood_adapt/dbs_classes/database.py b/flood_adapt/dbs_classes/database.py index 2c5722f51..0f7982fa3 100644 --- a/flood_adapt/dbs_classes/database.py +++ b/flood_adapt/dbs_classes/database.py @@ -7,11 +7,11 @@ import geopandas as gpd import numpy as np import pandas as pd +import xarray as xr from cht_cyclones.tropical_cyclone import TropicalCyclone from geopandas import GeoDataFrame from plotly.express import line from plotly.express.colors import sample_colorscale -from xarray import open_dataarray, open_dataset from flood_adapt.dbs_classes.dbs_benefit import DbsBenefit from flood_adapt.dbs_classes.dbs_event import DbsEvent @@ -484,17 +484,16 @@ def get_max_water_level( "Flooding", "max_water_level_map.nc", ) - map = open_dataarray(map_path) - - zsmax = map.to_numpy() - + with xr.open_dataarray(map_path) as map: + zsmax = map.to_numpy() else: file_path = self.scenarios.output_path.joinpath( scenario_name, "Flooding", f"RP_{return_period:04d}_maps.nc", ) - zsmax = open_dataset(file_path)["risk_map"][:, :].to_numpy().T + with xr.open_dataset(file_path) as ds: + zsmax = ds["risk_map"][:, :].to_numpy().T return zsmax def get_fiat_footprints(self, scenario_name: str) -> GeoDataFrame: diff --git a/flood_adapt/object_model/hazard/forcing/rainfall.py b/flood_adapt/object_model/hazard/forcing/rainfall.py index ac315d5a0..3dc589ae4 100644 --- a/flood_adapt/object_model/hazard/forcing/rainfall.py +++ b/flood_adapt/object_model/hazard/forcing/rainfall.py @@ -122,10 +122,11 @@ class RainfallNetCDF(IRainfall): path: Path def read(self) -> xr.Dataset: - ds = xr.open_dataset(self.path) - required_vars = {"precip"} - required_coords = {"time", "lat", "lon"} - return validate_netcdf_forcing(ds, required_vars, required_coords) + required_vars = ("precip",) + required_coords = ("time", "lat", "lon") + with xr.open_dataset(self.path) as ds: + validated_ds = validate_netcdf_forcing(ds, required_vars, required_coords) + return validated_ds def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: diff --git a/flood_adapt/object_model/hazard/forcing/wind.py b/flood_adapt/object_model/hazard/forcing/wind.py index 56238d36e..4020d4305 100644 --- a/flood_adapt/object_model/hazard/forcing/wind.py +++ b/flood_adapt/object_model/hazard/forcing/wind.py @@ -149,10 +149,11 @@ class WindNetCDF(IWind): path: Path def read(self) -> xr.Dataset: - ds = xr.open_dataset(self.path) - required_vars = {"wind10_v", "wind10_u", "press_msl"} - required_coords = {"time", "lat", "lon"} - return validate_netcdf_forcing(ds, required_vars, required_coords) + required_vars = ("wind10_v", "wind10_u", "press_msl") + required_coords = ("time", "lat", "lon") + with xr.open_dataset(self.path) as ds: + validated_ds = validate_netcdf_forcing(ds, required_vars, required_coords) + return validated_ds def save_additional(self, output_dir: Path | str | os.PathLike) -> None: if self.path: From 812ead0e7b55fd8dec0def7d2a6ced260c94edc5 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Mon, 20 Jan 2025 16:56:49 +0100 Subject: [PATCH 08/10] pre-commit --- flood_adapt/object_model/hazard/forcing/rainfall.py | 1 - flood_adapt/object_model/hazard/forcing/wind.py | 1 - 2 files changed, 2 deletions(-) diff --git a/flood_adapt/object_model/hazard/forcing/rainfall.py b/flood_adapt/object_model/hazard/forcing/rainfall.py index 0c413a28c..1defe4b22 100644 --- a/flood_adapt/object_model/hazard/forcing/rainfall.py +++ b/flood_adapt/object_model/hazard/forcing/rainfall.py @@ -112,4 +112,3 @@ def save_additional(self, output_dir: Path | str | os.PathLike) -> None: output_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(self.path, output_dir) self.path = output_dir / self.path.name - diff --git a/flood_adapt/object_model/hazard/forcing/wind.py b/flood_adapt/object_model/hazard/forcing/wind.py index 222d8dd73..07f67ac23 100644 --- a/flood_adapt/object_model/hazard/forcing/wind.py +++ b/flood_adapt/object_model/hazard/forcing/wind.py @@ -129,4 +129,3 @@ def save_additional(self, output_dir: Path | str | os.PathLike) -> None: output_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(self.path, output_dir) self.path = output_dir / self.path.name - From f6469b1d1c60922639603655b1b794b24c777765 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Thu, 23 Jan 2025 16:47:57 +0100 Subject: [PATCH 09/10] pre-commit --- flood_adapt/object_model/hazard/interface/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flood_adapt/object_model/hazard/interface/models.py b/flood_adapt/object_model/hazard/interface/models.py index 7ce091444..7b8e24d17 100644 --- a/flood_adapt/object_model/hazard/interface/models.py +++ b/flood_adapt/object_model/hazard/interface/models.py @@ -16,7 +16,6 @@ class TimeModel(BaseModel): end_time: datetime = REFERENCE_TIME + timedelta(days=1) time_step: Optional[timedelta] = None - @field_validator("start_time", "end_time", mode="before") @classmethod def try_parse_datetime(cls, value: str | datetime) -> datetime: From 5f7f102ce96458ffe7f44034d7d1c6f4500f2a31 Mon Sep 17 00:00:00 2001 From: Luuk Blom Date: Fri, 24 Jan 2025 10:10:01 +0100 Subject: [PATCH 10/10] update tests to use class initializers instead of dicts --- flood_adapt/object_model/scenario.py | 4 +- .../test_events/test_offshore.py | 50 +++++++++---------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/flood_adapt/object_model/scenario.py b/flood_adapt/object_model/scenario.py index c8e6f40ee..0a81ff78d 100644 --- a/flood_adapt/object_model/scenario.py +++ b/flood_adapt/object_model/scenario.py @@ -7,7 +7,7 @@ from flood_adapt.object_model.hazard.interface.events import IEvent from flood_adapt.object_model.interface.database_user import DatabaseUser from flood_adapt.object_model.interface.projections import IProjection -from flood_adapt.object_model.interface.scenarios import IScenario +from flood_adapt.object_model.interface.scenarios import IScenario, ScenarioModel from flood_adapt.object_model.interface.strategies import IStrategy from flood_adapt.object_model.utils import finished_file_exists, write_finished_file @@ -15,7 +15,7 @@ class Scenario(IScenario, DatabaseUser): """class holding all information related to a scenario.""" - def __init__(self, data: dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any] | ScenarioModel) -> None: """Create a Scenario object.""" super().__init__(data) self.site_info = self.database.site diff --git a/tests/test_object_model/test_events/test_offshore.py b/tests/test_object_model/test_events/test_offshore.py index 12a9fe115..5c282a313 100644 --- a/tests/test_object_model/test_events/test_offshore.py +++ b/tests/test_object_model/test_events/test_offshore.py @@ -3,7 +3,10 @@ from flood_adapt.adapter.sfincs_offshore import OffshoreSfincsHandler from flood_adapt.dbs_classes.interface.database import IDatabase -from flood_adapt.object_model.hazard.event.historical import HistoricalEvent +from flood_adapt.object_model.hazard.event.historical import ( + HistoricalEvent, + HistoricalEventModel, +) from flood_adapt.object_model.hazard.forcing.discharge import ( DischargeConstant, ) @@ -16,31 +19,27 @@ from flood_adapt.object_model.hazard.forcing.wind import ( WindMeteo, ) -from flood_adapt.object_model.hazard.interface.events import ( - Mode, - Template, -) +from flood_adapt.object_model.hazard.interface.forcing import ForcingType from flood_adapt.object_model.hazard.interface.models import ( TimeModel, ) from flood_adapt.object_model.interface.config.sfincs import RiverModel +from flood_adapt.object_model.interface.scenarios import ScenarioModel from flood_adapt.object_model.io import unit_system as us from flood_adapt.object_model.scenario import Scenario @pytest.fixture() def setup_offshore_scenario(test_db: IDatabase): - event_attrs = ( - { - "name": "test_historical_offshore_meteo", - "time": TimeModel(), - "template": Template.Historical, - "mode": Mode.single_event, - "forcings": { - "WATERLEVEL": [WaterlevelModel()], - "WIND": [WindMeteo()], - "RAINFALL": [RainfallMeteo()], - "DISCHARGE": [ + event = HistoricalEvent( + HistoricalEventModel( + name="test_historical_offshore_meteo", + time=TimeModel(), + forcings={ + ForcingType.WATERLEVEL: [WaterlevelModel()], + ForcingType.WIND: [WindMeteo()], + ForcingType.RAINFALL: [RainfallMeteo()], + ForcingType.DISCHARGE: [ DischargeConstant( river=RiverModel( name="cooper", @@ -57,19 +56,18 @@ def setup_offshore_scenario(test_db: IDatabase): ) ], }, - }, + ) ) - - event = HistoricalEvent.load_dict(event_attrs) test_db.events.save(event) - scenario_attrs = { - "name": "test_scenario", - "event": event.attrs.name, - "projection": "current", - "strategy": "no_measures", - } - scn = Scenario.load_dict(scenario_attrs) + scn = Scenario( + ScenarioModel( + name="test_scenario", + event=event.attrs.name, + projection="current", + strategy="no_measures", + ) + ) test_db.scenarios.save(scn) return test_db, scn, event