Skip to content

Commit

Permalink
cherrypick commits from updateSVN.
Browse files Browse the repository at this point in the history
refactor tests for projection & site
  • Loading branch information
LuukBlom committed Feb 7, 2024
1 parent bfad889 commit a2ac08d
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 101 deletions.
75 changes: 75 additions & 0 deletions flood_adapt/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os
from pathlib import Path

import tomli


def set_database_root(database_root: Path) -> None:
"""
Sets the root directory for the database.
Args:
database_root (str): The new root directory path.
"""
if not Path(database_root).is_dir():
raise ValueError(f"{database_root} is not a directory")
os.environ["DATABASE_ROOT"] = str(database_root)


def set_system_folder(system_folder: Path) -> None:
"""
Sets the system folder path.
Args:
system_folder (str): The new system folder path.
"""
if not Path(system_folder).is_dir():
raise ValueError(f"{system_folder} is not a directory")
os.environ["SYSTEM_FOLDER"] = str(system_folder)


def set_site_name(site_name: str) -> None:
"""
Sets the site_name.
Args:
site_name (str): The new system folder path.
"""
db_root = os.environ.get("DATABASE_ROOT")
full_site_path = Path(db_root, site_name)
if not full_site_path.is_dir():
raise ValueError(f"{full_site_path} is not a directory")
os.environ["SITE_NAME"] = str(site_name)


def parse_config(config_path: Path) -> dict:
with open(config_path, "rb") as f:
config = tomli.load(f)

# Parse the config file
if "database_root" not in config:
raise ValueError(f"database_root not found in {config_path}")
set_database_root(config["database_root"])

if "system_folder" not in config:
raise ValueError(f"system_folder not found in {config_path}")
set_system_folder(config["system_folder"])

if "site_name" not in config:
raise ValueError(f"site_name not found in {config_path}")
set_site_name(config["site_name"])

return config


def main() -> None:
# Get the directory that contains the config.toml file (e.g. the one above this one)
config_dir = Path(__file__).parent.parent

# Get the path to the config.toml file
config_path = config_dir / "config.toml"
parse_config(config_path)


if __name__ == "__main__":
main()
65 changes: 47 additions & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import shutil
import subprocess
from pathlib import Path

import pytest

from flood_adapt.api.startup import read_database
from flood_adapt.config import parse_config, set_site_name


def get_file_structure(path: str) -> list:
Expand Down Expand Up @@ -32,25 +34,52 @@ def remove_files_and_folders(path, file_structure):
print(f"PermissionError: {root}")


@pytest.fixture
def test_db():
"""This fixture is used for testing in general to setup the test database,
perform the test, and clean the database after each test.
It is used by other fixtures to set up and clean the test_database"""
@pytest.fixture(
autouse=True, scope="session"
) # This fixture is only run once per session
def updatedSVN():
parse_config(
"config.toml"
) # Set the database root, system folder, based on the config file
set_site_name("charleston_test") # set the site name to the test database
updateSVN_file_path = Path(__file__).parent / "updateSVN.py"
subprocess.run(
[str(updateSVN_file_path), os.environ["DATABASE_ROOT"]],
shell=True,
capture_output=True,
)

# Get the database file structure before the test
rootPath = Path().absolute() / "tests" / "test_database" # the path to the database
site_name = "Charleston" # the name of the test site

database_path = str(rootPath.joinpath(site_name))
file_structure = get_file_structure(database_path)
dbs = read_database(rootPath, site_name)
def make_db_fixture(scope):
"""
This fixture is used for testing in general.
It functions as follows:
1) Setup database controller
2) Perform all tests in scope
3) Clean the database
Scope can be one of the following: "function", "class", "module", "package", "session"
"""
if scope not in ["function", "class", "module", "package", "session"]:
raise ValueError(f"Invalid fixture scope: {scope}")

# NOTE: to access the contents of this function in the test,
# the first line of your test needs to initialize the yielded variables:
# 'dbs, folders = test_db'
@pytest.fixture(scope=scope)
def _db_fixture():
database_path = os.environ["DATABASE_ROOT"]
site_name = os.environ["SITE_NAME"]
file_structure = get_file_structure(database_path)
dbs = read_database(database_path, site_name)
yield dbs
remove_files_and_folders(database_path, file_structure)

# Run the test
yield dbs
# Remove all files and folders that were not present before the test
remove_files_and_folders(database_path, file_structure)
return _db_fixture


# NOTE: to access the contents the fixtures in the test functions,
# the fixture name needs to be passed as an argument to the test function.
# the first line of your test needs to initialize the yielded variables:
# 'dbs = _db_fixture'
test_db = make_db_fixture("function")
test_db_class = make_db_fixture("class")
test_db_module = make_db_fixture("module")
test_db_package = make_db_fixture("package")
test_db_session = make_db_fixture("session")
1 change: 0 additions & 1 deletion tests/test_api/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
test_database = api_startup.read_database(database_path, test_site_name)


# TODO How to delete the scenario after tests have been run?
@pytest.fixture(scope="session")
def scenario_event():
name = "current_extreme12ft_no_measures"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_object_model/test_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_projection_save_createsFile(test_db, test_dict):
test_db.input_path / "projections" / "new_projection" / "new_projection.toml"
)
test_projection.save(file_path)
assert file_path.exists()
assert file_path.is_file()


def test_projection_loadFile_checkAllAttrs(test_db, test_dict):
Expand Down
153 changes: 89 additions & 64 deletions tests/test_object_model/test_scenarios.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
Expand All @@ -19,73 +17,67 @@
from flood_adapt.object_model.scenario import Scenario
from flood_adapt.object_model.site import Site

test_database = Path().absolute() / "tests" / "test_database"


def test_scenario_class(test_db):
scenario_toml = (
test_database
/ "charleston"
/ "input"
@pytest.fixture(autouse=True)
def test_tomls(test_db):
toml_files = [
test_db.input_path
/ "scenarios"
/ "all_projections_extreme12ft_strategy_comb"
/ "all_projections_extreme12ft_strategy_comb.toml"
)
assert scenario_toml.is_file()
/ "all_projections_extreme12ft_strategy_comb.toml",
test_db.input_path
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml",
]
return toml_files


@pytest.fixture(autouse=True)
def test_scenarios(test_db, test_tomls):
test_scenarios = {
toml_file.name: Scenario.load_file(toml_file) for toml_file in test_tomls
}
return test_scenarios


def test_initObjectModel_validInput(test_db, test_scenarios):
test_scenario = test_scenarios["all_projections_extreme12ft_strategy_comb.toml"]

scenario = Scenario.load_file(scenario_toml)
scenario.init_object_model()
test_scenario.init_object_model()

assert isinstance(scenario.site_info, Site)
assert isinstance(scenario.direct_impacts, DirectImpacts)
assert isinstance(test_scenario.site_info, Site)
assert isinstance(test_scenario.direct_impacts, DirectImpacts)
assert isinstance(
scenario.direct_impacts.socio_economic_change, SocioEconomicChange
test_scenario.direct_impacts.socio_economic_change, SocioEconomicChange
)
assert isinstance(scenario.direct_impacts.impact_strategy, ImpactStrategy)
assert isinstance(scenario.direct_impacts.hazard, Hazard)
assert isinstance(scenario.direct_impacts.hazard.hazard_strategy, HazardStrategy)
assert isinstance(test_scenario.direct_impacts.impact_strategy, ImpactStrategy)
assert isinstance(test_scenario.direct_impacts.hazard, Hazard)
assert isinstance(
scenario.direct_impacts.hazard.physical_projection, PhysicalProjection
test_scenario.direct_impacts.hazard.hazard_strategy, HazardStrategy
)
assert isinstance(scenario.direct_impacts.hazard.event_list[0], Synthetic)


def test_hazard_load():
test_toml = (
test_database
/ "charleston"
/ "input"
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml"
assert isinstance(
test_scenario.direct_impacts.hazard.physical_projection, PhysicalProjection
)
assert isinstance(test_scenario.direct_impacts.hazard.event_list[0], Synthetic)


assert test_toml.is_file()
scenario = Scenario.load_file(test_toml)
scenario.init_object_model()
def test_hazard_load(test_db, test_scenarios):
test_scenario = test_scenarios["current_extreme12ft_no_measures.toml"]

hazard = scenario.direct_impacts.hazard
test_scenario.init_object_model()
hazard = test_scenario.direct_impacts.hazard

assert hazard.event_list[0].attrs.timing == "idealized"
assert isinstance(hazard.event_list[0].attrs.tide, TideModel)


def test_scs_rainfall(test_db):
test_toml = (
test_database
/ "charleston"
/ "input"
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml"
)

assert test_toml.is_file()
test_scenario = test_scenarios["current_extreme12ft_no_measures.toml"]

scenario = Scenario.load_file(test_toml)
scenario.init_object_model()
test_scenario.init_object_model()

hazard = scenario.direct_impacts.hazard
hazard = test_scenario.direct_impacts.hazard

hazard.event.attrs.rainfall = RainfallModel(
source="shape",
Expand Down Expand Up @@ -117,20 +109,53 @@ def test_scs_rainfall(test_db):
assert np.abs(cum_rainfall_ts - cum_rainfall_toml) < 0.01


@pytest.mark.skip(reason="No metric file to read from")
def test_infographic(test_db):
test_toml = (
test_database
/ "charleston"
/ "input"
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml"
)
class Test_scenario_run:
@pytest.fixture(scope="class")
def test_scenario_before_after_run(self, test_db_class):
test_scenario_copy = (
test_db_class.input_path
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml"
)
test_scenario_to_run_path = (
test_db_class.input_path
/ "scenarios"
/ "test_run_scenario"
/ "test_run_scenario.toml"
)

assert test_toml.is_file()
test_scenario_copy = Scenario.load_file(test_scenario_copy)
assert test_scenario_copy.has_run is False

# use event template to get the associated Event child class
test_scenario = Scenario.load_file(test_toml)
test_scenario.init_object_model()
test_scenario.infographic()
test_scenario_copy.save(test_scenario_to_run_path)
test_scenario_to_run = Scenario.load_file(test_scenario_to_run_path)

test_scenario_to_run.run()

return test_scenario_copy, test_scenario_to_run

def test_run_notRunYet(self, test_scenario_before_after_run):
before_run, after_run = test_scenario_before_after_run

assert before_run.has_run is False
assert before_run.direct_impacts.hazard.event_list[0].results is None
assert before_run.direct_impacts.impact_strategy.results is None
assert before_run.direct_impacts.socio_economic_change.results is None
assert before_run.direct_impacts.results is None
assert before_run.direct_impacts.hazard.has_run is False

def test_infographic(test_db):
test_toml = (
test_db.input_path
/ "scenarios"
/ "current_extreme12ft_no_measures"
/ "current_extreme12ft_no_measures.toml"
)

assert test_toml.is_file()

# use event template to get the associated Event child class
test_scenario = Scenario.load_file(test_toml)
test_scenario.init_object_model()
test_scenario.infographic()
Loading

0 comments on commit a2ac08d

Please sign in to comment.