Skip to content

Commit 9edaa49

Browse files
Merge pull request #67 from oceanmodeling/feature/enforce_input_version
Check and enforce input version
2 parents 1efecce + edb27e5 commit 9edaa49

File tree

7 files changed

+261
-4
lines changed

7 files changed

+261
-4
lines changed

.github/workflows/tests.yml

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
paths:
8+
- '**.py'
9+
- '.github/workflows/tests.yml'
10+
- 'pyproject.toml'
11+
pull_request:
12+
branches:
13+
- main
14+
15+
jobs:
16+
test:
17+
name: test
18+
runs-on: ${{ matrix.os }}
19+
strategy:
20+
matrix:
21+
os: [ ubuntu-latest ]
22+
python-version: [ '3.9', '3.10', '3.11' ]
23+
steps:
24+
- name: clone repository
25+
uses: actions/checkout@v4
26+
- name: conda virtual environment
27+
uses: mamba-org/setup-micromamba@v1
28+
with:
29+
init-shell: bash
30+
environment-file: environment.yml
31+
- name: install the package
32+
run: pip install ".[dev]"
33+
shell: micromamba-shell {0}
34+
- name: run tests
35+
run: pytest
36+
shell: micromamba-shell {0}

pyproject.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ description = "A set of scripts to generate probabilistic storm surge results!"
2121

2222
license = {file = "LICENSE"}
2323

24-
requires-python = ">= 3.8, < 3.12"
24+
requires-python = ">= 3.9, < 3.12"
2525

2626
dependencies = [
2727
"cartopy",
@@ -45,6 +45,7 @@ dependencies = [
4545
"numpy",
4646
"numba",
4747
"ocsmesh==1.5.3",
48+
"packaging",
4849
"pandas",
4950
"pyarrow",
5051
"pygeos",
@@ -65,6 +66,11 @@ dependencies = [
6566
"xarray",
6667
]
6768

69+
[project.optional-dependencies]
70+
dev = [
71+
"pytest"
72+
]
73+
6874
[tool.setuptools_scm]
6975
version_file = "stormworkflow/_version.py"
7076

stormworkflow/main.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,74 @@
22
import logging
33
import os
44
import shlex
5+
import warnings
56
from importlib.resources import files
67
from argparse import ArgumentParser
78
from pathlib import Path
89

9-
import stormworkflow
1010
import yaml
11+
from packaging.version import Version
1112
try:
1213
from yaml import CLoader as Loader, CDumper as Dumper
1314
except ImportError:
1415
from yaml import Loader, Dumper
1516

17+
import stormworkflow
1618

1719
_logger = logging.getLogger(__file__)
1820

21+
CUR_INPUT_VER = Version('0.0.2')
22+
23+
24+
def _handle_input_v0_0_1_to_v0_0_2(inout_conf):
25+
26+
ver = Version(inout_conf['input_version'])
27+
28+
# Only update config if specified version matches the assumed one
29+
if ver != Version('0.0.1'):
30+
return ver
31+
32+
33+
_logger.info(
34+
"Adding perturbation variables for persistent RMW perturbation"
35+
)
36+
inout_conf['perturb_vars'] = [
37+
'cross_track',
38+
'along_track',
39+
'radius_of_maximum_winds_persistent',
40+
'max_sustained_wind_speed',
41+
]
42+
43+
return Version('0.0.2')
44+
45+
46+
def handle_input_version(inout_conf):
47+
48+
if 'input_version' not in inout_conf:
49+
ver = CUR_INPUT_VER
50+
warnings.warn(
51+
f"`input_version` is NOT specified in `input.yaml`; assuming {ver}"
52+
)
53+
inout_conf['input_version'] = str(ver)
54+
return
55+
56+
ver = Version(inout_conf['input_version'])
57+
58+
if ver > CUR_INPUT_VER:
59+
raise ValueError(
60+
f"Input version not supported! Max version supported is {CUR_INPUT_VER}"
61+
)
62+
63+
ver = _handle_input_v0_0_1_to_v0_0_2(inout_conf)
64+
65+
if ver != CUR_INPUT_VER:
66+
raise ValueError(
67+
f"Could NOT update input to the latest version! Updated to {ver}"
68+
)
69+
70+
inout_conf['input_version'] = str(ver)
71+
72+
1973
def main():
2074

2175
parser = ArgumentParser()
@@ -28,12 +82,17 @@ def main():
2882

2983
infile = args.configuration
3084
if infile is None:
31-
_logger.warn('No input configuration provided, using reference file!')
85+
warnings.warn(
86+
'No input configuration provided, using reference file!'
87+
)
3288
infile = refs.joinpath('input.yaml')
3389

3490
with open(infile, 'r') as yfile:
3591
conf = yaml.load(yfile, Loader=Loader)
3692

93+
handle_input_version(conf)
94+
# TODO: Write out the updated config as a yaml file
95+
3796
wf = scripts.joinpath('workflow.sh')
3897

3998
run_env = os.environ.copy()

stormworkflow/scripts/workflow.sh

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,15 @@ function init {
4848
done
4949

5050
logfile=$run_dir/versions.info
51+
version $logfile stormworkflow
5152
version $logfile stormevents
5253
version $logfile ensembleperturbation
54+
version $logfile coupledmodeldriver
55+
version $logfile pyschism
5356
version $logfile ocsmesh
5457
echo "SCHISM: see solver.version each outputs dir" >> $logfile
5558

56-
cp $input_file $run_dir/input.yaml
59+
cp $input_file $run_dir/input_asis.yaml
5760

5861
echo $run_dir
5962
}

tests/data/refs/input_v0.0.1.yaml

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
---
2+
input_version: 0.0.1
3+
4+
storm: "florence"
5+
year: 2018
6+
suffix: ""
7+
subset_mesh: 1
8+
hr_prelandfall: -1
9+
past_forecast: 1
10+
hydrology: 0
11+
use_wwm: 0
12+
pahm_model: "gahm"
13+
num_perturb: 2
14+
sample_rule: "korobov"
15+
spinup_exec: "pschism_PAHM_TVD-VL"
16+
hotstart_exec: "pschism_PAHM_TVD-VL"
17+
18+
hpc_solver_nnodes: 3
19+
hpc_solver_ntasks: 108
20+
hpc_account: ""
21+
hpc_partition: ""
22+
23+
RUN_OUT: ""
24+
L_NWM_DATASET: ""
25+
L_TPXO_DATASET: ""
26+
L_LEADTIMES_DATASET: ""
27+
L_TRACK_DIR: ""
28+
L_DEM_HI: ""
29+
L_DEM_LO: ""
30+
L_MESH_HI: ""
31+
L_MESH_LO: ""
32+
L_SHP_DIR: ""
33+
34+
TMPDIR: "/tmp"
35+
PATH_APPEND: ""
36+
37+
L_SOLVE_MODULES:
38+
- "intel/2022.1.2"
39+
- "impi/2022.1.2"
40+
- "netcdf"

tests/data/refs/input_v0.0.2.yaml

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
---
2+
input_version: 0.0.2
3+
4+
storm: "florence"
5+
year: 2018
6+
suffix: ""
7+
subset_mesh: 1
8+
hr_prelandfall: -1
9+
past_forecast: 1
10+
hydrology: 0
11+
use_wwm: 0
12+
pahm_model: "gahm"
13+
num_perturb: 2
14+
sample_rule: "korobov"
15+
spinup_exec: "pschism_PAHM_TVD-VL"
16+
hotstart_exec: "pschism_PAHM_TVD-VL"
17+
perturb_vars:
18+
- 'cross_track'
19+
- 'along_track'
20+
# - 'radius_of_maximum_winds'
21+
- 'radius_of_maximum_winds_persistent'
22+
- 'max_sustained_wind_speed'
23+
24+
hpc_solver_nnodes: 3
25+
hpc_solver_ntasks: 108
26+
hpc_account: ""
27+
hpc_partition: ""
28+
29+
RUN_OUT: ""
30+
L_NWM_DATASET: ""
31+
L_TPXO_DATASET: ""
32+
L_LEADTIMES_DATASET: ""
33+
L_TRACK_DIR: ""
34+
L_DEM_HI: ""
35+
L_DEM_LO: ""
36+
L_MESH_HI: ""
37+
L_MESH_LO: ""
38+
L_SHP_DIR: ""
39+
40+
TMPDIR: "/tmp"
41+
PATH_APPEND: ""
42+
43+
L_SOLVE_MODULES:
44+
- "intel/2022.1.2"
45+
- "impi/2022.1.2"
46+
- "netcdf"

tests/test_input_version.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from copy import deepcopy
2+
from importlib.resources import files
3+
4+
import pytest
5+
import yaml
6+
from packaging.version import Version
7+
from yaml import Loader, Dumper
8+
9+
from stormworkflow.main import handle_input_version, CUR_INPUT_VER
10+
11+
12+
refs = files('tests.data.refs')
13+
input_v0_0_1 = refs.joinpath('input_v0.0.1.yaml')
14+
input_v0_0_2 = refs.joinpath('input_v0.0.2.yaml')
15+
16+
17+
def read_conf(infile):
18+
with open(infile, 'r') as yfile:
19+
conf = yaml.load(yfile, Loader=Loader)
20+
return conf
21+
22+
23+
@pytest.fixture
24+
def conf_v0_0_1():
25+
return read_conf(input_v0_0_1)
26+
27+
28+
@pytest.fixture
29+
def conf_v0_0_2():
30+
return read_conf(input_v0_0_2)
31+
32+
33+
@pytest.fixture
34+
def conf_latest(conf_v0_0_2):
35+
return conf_v0_0_2
36+
37+
38+
def test_no_version_specified(conf_latest):
39+
conf_latest.pop('input_version')
40+
with pytest.warns(UserWarning):
41+
handle_input_version(conf_latest)
42+
43+
assert conf_latest['input_version'] == str(CUR_INPUT_VER)
44+
45+
46+
def test_invalid_version_specified(conf_latest):
47+
48+
invalid_1 = deepcopy(conf_latest)
49+
invalid_1['input_version'] = (
50+
f'{CUR_INPUT_VER.major}.{CUR_INPUT_VER.minor}.{CUR_INPUT_VER.micro + 1}'
51+
)
52+
with pytest.raises(ValueError) as e:
53+
handle_input_version(invalid_1)
54+
55+
assert "max" in str(e.value).lower()
56+
57+
58+
invalid_2 = deepcopy(conf_latest)
59+
invalid_2['input_version'] = 'a.b.c'
60+
with pytest.raises(ValueError) as e:
61+
handle_input_version(invalid_2)
62+
assert "invalid version" in str(e.value).lower()
63+
64+
65+
def test_v0_0_1_to_v0_0_2(conf_v0_0_1, conf_v0_0_2):
66+
handle_input_version(conf_v0_0_1)
67+
assert conf_v0_0_2 == conf_v0_0_1

0 commit comments

Comments
 (0)