Skip to content

Commit

Permalink
Restructure nesting to conform to new xarray DataTree.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcjackson committed Dec 6, 2024
1 parent 09ebb06 commit fd796cd
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 79 deletions.
51 changes: 30 additions & 21 deletions examples/plot_grid_nesting_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
::
root
|---radar_1
|---radar_2
|---radar_n
|---inner_nest
|---radar_1
|---radar_2
|---radar_m
|---nest_0/radar_1
|---nest_0/radar_2
|---nest_0/radar_n
|---nest_1/radar_1
|---nest_1/radar_2
|---nest_1/radar_m
Each member of this tree is a DataTree itself. PyDDA will know if the
DataTree contains data from a radar when the name of the node begins
Expand All @@ -29,11 +28,11 @@
retrieval, allowing the user to vary the coefficients by grid level.
Using :code:`pydda.retrieval.get_dd_wind_field_nested` will allow PyDDA
to perform the retrieval on the outer grids first. It will then
perform on the inner grid levels, using the outer level grid as both the
horizontal boundary conditions and initialization for the retrieval in the inner
nest. Finally, PyDDA will update the winds in the outer grid by nearest-
neighbor interpolation of the finer grid into the overlapping portion between
to perform the retrieval on the 0th grid first. It will then
perform on the subsequent grid levels, using the previous nest as both the
horizontal boundary conditions and initialization for the retrieval in the next
nest. Finally, PyDDA will update the winds in the first grid by nearest-
neighbor interpolation of the latter grid into the overlapping portion between
the inner and outer grid level.
PyDDA will then return the retrieved wind fields as the "u", "v", and "w"
Expand All @@ -46,7 +45,7 @@
import pydda
import matplotlib.pyplot as plt
import warnings
from datatree import DataTree
from xarray import DataTree

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -84,18 +83,26 @@
)

"""
Enforce equal times for each grid. This is required for the DataTree structure since time is an
inherited dimension.
"""
test_coarse1["time"] = test_coarse0["time"]
test_fine0["time"] = test_coarse0["time"]
test_fine1["time"] = test_coarse1["time"]
"""
Provide the overlying grid structure as specified above.
"""
tree_dict = {
"/coarse/radar_ktlx": test_coarse0,
"/coarse/radar_kict": test_coarse1,
"/coarse/fine/radar_ktlx": test_fine0,
"/coarse/fine/radar_kict": test_fine1,
"/nest_0/radar_ktlx": test_coarse0,
"/nest_0/radar_kict": test_coarse1,
"/nest_1/radar_ktlx": test_fine0,
"/nest_1/radar_kict": test_fine1,
}

tree = DataTree.from_dict(tree_dict)
tree["/coarse/"].attrs = kwargs_dict
tree["/coarse/fine"].attrs = kwargs_dict
tree["/nest_0/"].attrs = kwargs_dict
tree["/nest_1/"].attrs = kwargs_dict

"""
Perform the retrieval
Expand All @@ -109,7 +116,7 @@

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
pydda.vis.plot_horiz_xsection_quiver(
grid_tree["coarse"],
grid_tree["nest_0"],
ax=ax[0],
level=5,
cmap="ChaseSpectral",
Expand All @@ -124,7 +131,7 @@
quiverkey_loc="bottom_right",
)
pydda.vis.plot_horiz_xsection_quiver(
grid_tree["coarse/fine"],
grid_tree["nest_1"],
ax=ax[1],
level=5,
cmap="ChaseSpectral",
Expand All @@ -138,3 +145,5 @@
quiver_spacing_y_km=50.0,
quiverkey_loc="bottom_right",
)

plt.show()
137 changes: 88 additions & 49 deletions pydda/retrieval/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs):
- The list of PyART grids for the given level of the grid
* - kwargs
- The list of key word arguments for input to the :py:func:`pydda.retrieval.get_dd_wind_field` function for the set of grids.
* - children
- The list of trees that are the children of this node.
The function will output the same tree, with the list of output grids of each level output to the 'output_grids'
member of the tree structure. If *kwargs* is set to None, then the input keyword arguments will be
Expand All @@ -34,53 +32,89 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs):
child_list = list(grid_tree.children.keys())
grid_list = []
rad_names = []
for child in child_list:
if "radar" in child:
grid_list.append(grid_tree[child].to_dataset())
rad_names.append(child)
# We are at the parent level, look for nest 0
if "nest_0" in child_list:
for child in grid_tree["nest_0"].children.keys():
if "radar_" in child:
grid_list.append(grid_tree["nest_0"][child].to_dataset())
rad_names.append(child)
tree_attrs = grid_tree["nest_0"].attrs
in_parent = True
else:
tree_attrs = grid_tree.attrs
# We are in nest 1...n
for child in child_list:
if "radar_" in child:
grid_list.append(grid_tree[child].to_dataset())
rad_names.append(child)
in_parent = False

if len(list(grid_tree.attrs.keys())) == 0 and len(grid_list) > 0:
if len(list(tree_attrs.keys())) == 0 and len(grid_list) > 0:
output_grids, output_parameters = get_dd_wind_field(grid_list, **kwargs)
elif len(grid_list) > 0:
my_kwargs = grid_tree.attrs
my_kwargs = tree_attrs
output_grids, output_parameters = get_dd_wind_field(grid_list, **my_kwargs)
output_parameters = output_parameters.__dict__
grid_tree["weights"] = xr.DataArray(
output_parameters.pop("weights"), dims=("nradar", "z", "y", "x")
)
grid_tree["bg_weights"] = xr.DataArray(
output_parameters.pop("bg_weights"), dims=("z", "y", "x")
)
grid_tree["model_weights"] = xr.DataArray(
output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x")
)
output_parameters.pop("u_model")
output_parameters.pop("v_model")
output_parameters.pop("w_model")
grid_tree["output_parameters"] = xr.DataArray([], attrs=output_parameters)

grid_tree.__setitem__("u", output_grids[0]["u"])
grid_tree.__setitem__("v", output_grids[0]["v"])
grid_tree.__setitem__("w", output_grids[0]["w"])
grid_tree["u"].attrs = output_grids[0]["u"].attrs
grid_tree["v"].attrs = output_grids[0]["v"].attrs
grid_tree["w"].attrs = output_grids[0]["w"].attrs
if in_parent is True:
grid_tree["nest_0"]["weights"] = xr.DataArray(
output_parameters.pop("weights"), dims=("nradars", "z", "y", "x")
)
grid_tree["nest_0"]["bg_weights"] = xr.DataArray(
output_parameters.pop("bg_weights"), dims=("z", "y", "x")
)
grid_tree["nest_0"]["model_weights"] = xr.DataArray(
output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x")
)
output_parameters.pop("u_model")
output_parameters.pop("v_model")
output_parameters.pop("w_model")
grid_tree["nest_0"]["output_parameters"] = xr.DataArray(
[], attrs=output_parameters
)
grid_tree["nest_0"].__setitem__("u", output_grids[0]["u"])
grid_tree["nest_0"].__setitem__("v", output_grids[0]["v"])
grid_tree["nest_0"].__setitem__("w", output_grids[0]["w"])
grid_tree["nest_0"]["u"].attrs = output_grids[0]["u"].attrs
grid_tree["nest_0"]["v"].attrs = output_grids[0]["v"].attrs
grid_tree["nest_0"]["w"].attrs = output_grids[0]["w"].attrs
else:
grid_tree["weights"] = xr.DataArray(
output_parameters.pop("weights"), dims=("nradars", "z", "y", "x")
)
grid_tree["bg_weights"] = xr.DataArray(
output_parameters.pop("bg_weights"), dims=("z", "y", "x")
)
grid_tree["model_weights"] = xr.DataArray(
output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x")
)
output_parameters.pop("u_model")
output_parameters.pop("v_model")
output_parameters.pop("w_model")
grid_tree["output_parameters"] = xr.DataArray([], attrs=output_parameters)
grid_tree.__setitem__("u", output_grids[0]["u"])
grid_tree.__setitem__("v", output_grids[0]["v"])
grid_tree.__setitem__("w", output_grids[0]["w"])
grid_tree["u"].attrs = output_grids[0]["u"].attrs
grid_tree["v"].attrs = output_grids[0]["v"].attrs
grid_tree["w"].attrs = output_grids[0]["w"].attrs

if child_list == []:
return grid_tree

nests = []
for child in child_list:
if "radar_" not in child:
if "nest_" in child:
nests.append(child)
nests = sorted(nests)
for child in nests:
for i, child in enumerate(nests):
if i == 0:
continue
# Only update child initalization if we are not in parent node
if len(grid_list) > 0:
temp_src = grid_tree[rad_names[0]].to_dataset()
temp_src["u"] = grid_tree.ds["u"]
temp_src["v"] = grid_tree.ds["v"]
temp_src["w"] = grid_tree.ds["w"]
temp_src = grid_tree[f"nest_{i-1}"][rad_names[0]].to_dataset()
temp_src["u"] = grid_tree[f"nest_{i-1}"].ds["u"]
temp_src["v"] = grid_tree[f"nest_{i-1}"].ds["v"]
temp_src["w"] = grid_tree[f"nest_{i-1}"].ds["w"]
input_grids = make_initialization_from_other_grid(
temp_src, grid_tree.children[child][rad_names[0]].to_dataset()
)
Expand Down Expand Up @@ -145,26 +179,31 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs):
grid_tree.children[child]["w"].attrs = temp_tree.ds["w"].attrs

# Update parent grids from children
for child in child_list:
if "nest_" in child:
nests.append(child)
nests = sorted(nests)

if len(rad_names) > 0:
for child in nests:
temp_src = grid_tree.children[child][rad_names[0]].to_dataset()
temp_src["u"] = grid_tree.children[child].ds["u"]
temp_src["v"] = grid_tree.children[child].ds["v"]
temp_src["w"] = grid_tree.children[child].ds["w"]
temp_dest = grid_tree[rad_names[0]].to_dataset()
temp_dest["u"] = grid_tree.ds["u"]
temp_dest["v"] = grid_tree.ds["v"]
temp_dest["w"] = grid_tree.ds["w"]
for i, child in enumerate(nests[:-1]):
temp_src = grid_tree.children[nests[i + 1]][rad_names[0]].to_dataset()
temp_src["u"] = grid_tree.children[nests[i + 1]].ds["u"]
temp_src["v"] = grid_tree.children[nests[i + 1]].ds["v"]
temp_src["w"] = grid_tree.children[nests[i + 1]].ds["w"]
temp_dest = grid_tree.children[nests[i]][rad_names[0]].to_dataset()
temp_dest["u"] = grid_tree.children[nests[i]].ds["u"]
temp_dest["v"] = grid_tree.children[nests[i]].ds["v"]
temp_dest["w"] = grid_tree.children[nests[i]].ds["w"]
output_grids = make_initialization_from_other_grid(
temp_src,
temp_dest,
)
grid_tree.__setitem__("u", output_grids["u"])
grid_tree.__setitem__("v", output_grids["v"])
grid_tree.__setitem__("w", output_grids["w"])
grid_tree["u"].attrs = output_grids["u"].attrs
grid_tree["v"].attrs = output_grids["v"].attrs
grid_tree["w"].attrs = output_grids["w"].attrs
grid_tree.children[nests[i]].__setitem__("u", output_grids["u"])
grid_tree.children[nests[i]].__setitem__("v", output_grids["v"])
grid_tree.children[nests[i]].__setitem__("w", output_grids["w"])
grid_tree.children[nests[i]]["u"].attrs = output_grids["u"].attrs
grid_tree.children[nests[i]]["v"].attrs = output_grids["v"].attrs
grid_tree.children[nests[i]]["w"].attrs = output_grids["w"].attrs
return grid_tree


Expand Down
6 changes: 5 additions & 1 deletion pydda/tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def test_get_iem_data():
Grid = pydda.io.read_from_pyart_grid(Grid)
station_obs = pydda.constraints.get_iem_obs(Grid)
names = [x["site_id"] for x in station_obs]
assert names == ["P28", "WLD", "WDG", "SWO", "END"]
assert "P28" in names
assert "WLD" in names
assert "WDG" in names
assert "SWO" in names
assert "END" in names


def test_hrrr_data():
Expand Down
20 changes: 12 additions & 8 deletions pydda/tests/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,21 +289,25 @@ def test_nested_retrieval():
engine="scipy",
)

test_coarse1["time"] = test_coarse0["time"]
test_fine0["time"] = test_coarse0["time"]
test_fine1["time"] = test_coarse1["time"]

tree_dict = {
"/coarse/radar_ktlx": test_coarse0,
"/coarse/radar_kict": test_coarse1,
"/coarse/fine/radar_ktlx": test_fine0,
"/coarse/fine/radar_kict": test_fine1,
"/nest_0/radar_ktlx": test_coarse0,
"/nest_0/radar_kict": test_coarse1,
"/nest_1/radar_ktlx": test_fine0,
"/nest_1/radar_kict": test_fine1,
}

tree = DataTree.from_dict(tree_dict)
tree["/coarse/"].attrs = kwargs_dict
tree["/coarse/fine"].attrs = kwargs_dict
tree["/nest_0/"].attrs = kwargs_dict
tree["/nest_1/"].attrs = kwargs_dict

grid_tree = pydda.retrieval.get_dd_wind_field_nested(tree)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
pydda.vis.plot_horiz_xsection_quiver(
grid_tree["coarse"],
grid_tree["nest_0"],
ax=ax[0],
level=5,
cmap="ChaseSpectral",
Expand All @@ -318,7 +322,7 @@ def test_nested_retrieval():
quiverkey_loc="bottom_right",
)
pydda.vis.plot_horiz_xsection_quiver(
grid_tree["coarse/fine"],
grid_tree["nest_1"],
ax=ax[1],
level=5,
cmap="ChaseSpectral",
Expand Down

0 comments on commit fd796cd

Please sign in to comment.