6
6
import xarray as xr
7
7
8
8
from polaris import Step
9
- from polaris .ocean .resolution import resolution_to_subdir
9
+ from polaris .ocean .convergence import (
10
+ get_resolution_for_task ,
11
+ get_timestep_for_task ,
12
+ )
10
13
from polaris .ocean .tasks .inertial_gravity_wave .exact_solution import (
11
14
ExactSolution ,
12
15
)
@@ -20,10 +23,31 @@ class Viz(Step):
20
23
21
24
Attributes
22
25
----------
23
- resolutions : list of float
24
- The resolutions of the meshes that have been run
26
+ dependencies_dict : dict of dict of polaris.Steps
27
+ The dependencies of this step must be given as separate keys in the
28
+ dict:
29
+
30
+ mesh : dict of polaris.Steps
31
+ Keys of the dict correspond to `refinement_factors`
32
+ Values of the dict are polaris.Steps, which must have the
33
+ attribute `path`, the path to `base_mesh.nc` of that
34
+ resolution
35
+ init : dict of polaris.Steps
36
+ Keys of the dict correspond to `refinement_factors`
37
+ Values of the dict are polaris.Steps, which must have the
38
+ attribute `path`, the path to `initial_state.nc` of that
39
+ resolution
40
+ forward : dict of polaris.Steps
41
+ Keys of the dict correspond to `refinement_factors`
42
+ Values of the dict are polaris.Steps, which must have the
43
+ attribute `path`, the path to `forward.nc` of that
44
+ resolution
45
+
46
+ refinement : str, optional
47
+ Refinement type. One of 'space', 'time' or 'both' indicating both
48
+ space and time
25
49
"""
26
- def __init__ (self , component , resolutions , taskdir ):
50
+ def __init__ (self , component , taskdir , dependencies , refinement = 'both' ):
27
51
"""
28
52
Create the step
29
53
@@ -32,37 +56,81 @@ def __init__(self, component, resolutions, taskdir):
32
56
component : polaris.Component
33
57
The component the step belongs to
34
58
35
- resolutions : list of float
36
- The resolutions of the meshes that have been run
37
-
38
59
taskdir : str
39
60
The subdirectory that the task belongs to
61
+
62
+ dependencies : dict of dict of polaris.Steps
63
+ The dependencies of this step must be given as separate keys in the
64
+ dict:
65
+
66
+ mesh : dict of polaris.Steps
67
+ Keys of the dict correspond to `refinement_factors`
68
+ Values of the dict are polaris.Steps, which must have the
69
+ attribute `path`, the path to `base_mesh.nc` of that
70
+ resolution
71
+ init : dict of polaris.Steps
72
+ Keys of the dict correspond to `refinement_factors`
73
+ Values of the dict are polaris.Steps, which must have the
74
+ attribute `path`, the path to `initial_state.nc` of that
75
+ resolution
76
+ forward : dict of polaris.Steps
77
+ Keys of the dict correspond to `refinement_factors`
78
+ Values of the dict are polaris.Steps, which must have the
79
+ attribute `path`, the path to `forward.nc` of that
80
+ resolution
81
+
82
+ refinement : str, optional
83
+ Refinement type. One of 'space', 'time' or 'both' indicating both
84
+ space and time
40
85
"""
41
86
super ().__init__ (component = component , name = 'viz' , indir = taskdir )
42
- self .resolutions = resolutions
87
+ self .dependencies_dict = dependencies
88
+ self .refinement = refinement
43
89
44
- for resolution in resolutions :
45
- mesh_name = resolution_to_subdir (resolution )
90
+ self .add_output_file ('comparison.png' )
91
+
92
+ def setup (self ):
93
+ """
94
+ Add input files based on resolutions, which may have been changed by
95
+ user config options
96
+ """
97
+ super ().setup ()
98
+ config = self .config
99
+ dependencies = self .dependencies_dict
100
+ if self .refinement == 'time' :
101
+ option = 'refinement_factors_time'
102
+ else :
103
+ option = 'refinement_factors_space'
104
+ refinement_factors = config .getlist ('convergence' , option ,
105
+ dtype = float )
106
+ for refinement_factor in refinement_factors :
107
+ base_mesh = dependencies ['mesh' ][refinement_factor ]
108
+ init = dependencies ['init' ][refinement_factor ]
109
+ forward = dependencies ['forward' ][refinement_factor ]
46
110
self .add_input_file (
47
- filename = f'mesh_ { mesh_name } .nc' ,
48
- target = f'../init/ { mesh_name } /culled_mesh .nc' )
111
+ filename = f'mesh_r { refinement_factor :02g } .nc' ,
112
+ work_dir_target = f'{ base_mesh . path } /base_mesh .nc' )
49
113
self .add_input_file (
50
- filename = f'init_ { mesh_name } .nc' ,
51
- target = f'../init/ { mesh_name } /initial_state.nc' )
114
+ filename = f'init_r { refinement_factor :02g } .nc' ,
115
+ work_dir_target = f'{ init . path } /initial_state.nc' )
52
116
self .add_input_file (
53
- filename = f'output_{ mesh_name } .nc' ,
54
- target = f'../forward/{ mesh_name } /output.nc' )
55
-
56
- self .add_output_file ('comparison.png' )
117
+ filename = f'output_r{ refinement_factor :02g} .nc' ,
118
+ work_dir_target = f'{ forward .path } /output.nc' )
57
119
58
120
def run (self ):
59
121
"""
60
122
Run this step of the test case
61
123
"""
62
124
plt .switch_backend ('Agg' )
63
125
config = self .config
64
- resolutions = self .resolutions
65
- nres = len (resolutions )
126
+
127
+ if self .refinement == 'time' :
128
+ option = 'refinement_factors_time'
129
+ else :
130
+ option = 'refinement_factors_space'
131
+ refinement_factors = config .getlist ('convergence' , option ,
132
+ dtype = float )
133
+ nres = len (refinement_factors )
66
134
67
135
section = config ['inertial_gravity_wave' ]
68
136
eta0 = section .getfloat ('ssh_amplitude' )
@@ -71,11 +139,12 @@ def run(self):
71
139
fig , axes = plt .subplots (nrows = nres , ncols = 3 , figsize = (12 , 2 * nres ))
72
140
rmse = []
73
141
error_range = None
74
- for i , res in enumerate (resolutions ):
75
- mesh_name = resolution_to_subdir (res )
76
- ds_mesh = xr .open_dataset (f'mesh_{ mesh_name } .nc' )
77
- ds_init = xr .open_dataset (f'init_{ mesh_name } .nc' )
78
- ds = xr .open_dataset (f'output_{ mesh_name } .nc' )
142
+ for i , refinement_factor in enumerate (refinement_factors ):
143
+ resolution = get_resolution_for_task (
144
+ config , refinement_factor , refinement = self .refinement )
145
+ ds_mesh = xr .open_dataset (f'mesh_r{ refinement_factor :02g} .nc' )
146
+ ds_init = xr .open_dataset (f'init_r{ refinement_factor :02g} .nc' )
147
+ ds = xr .open_dataset (f'output_r{ refinement_factor :02g} .nc' )
79
148
exact = ExactSolution (ds_init , config )
80
149
81
150
t0 = datetime .datetime .strptime (ds .xtime .values [0 ].decode (),
@@ -112,8 +181,13 @@ def run(self):
112
181
axes [0 , 2 ].set_title ('Error (Numerical - Analytical)' )
113
182
114
183
pad = 5
115
- for ax , res in zip (axes [:, 0 ], resolutions ):
116
- ax .annotate (f'{ res } km' , xy = (0 , 0.5 ),
184
+ for ax , refinement_factor in zip (axes [:, 0 ], refinement_factors ):
185
+ timestep , _ = get_timestep_for_task (
186
+ config , refinement_factor , refinement = self .refinement )
187
+ resolution = get_resolution_for_task (
188
+ config , refinement_factor , refinement = self .refinement )
189
+
190
+ ax .annotate (f'{ resolution } km\n { timestep } s' , xy = (0 , 0.5 ),
117
191
xytext = (- ax .yaxis .labelpad - pad , 0 ),
118
192
xycoords = ax .yaxis .label , textcoords = 'offset points' ,
119
193
size = 'large' , ha = 'right' , va = 'center' )
0 commit comments