diff --git a/examples/plot_grid_nesting_example.py b/examples/plot_grid_nesting_example.py index ea777828..2b93ca3a 100644 --- a/examples/plot_grid_nesting_example.py +++ b/examples/plot_grid_nesting_example.py @@ -12,13 +12,12 @@ :: root - |---radar_1 - |---radar_2 - |---radar_n - |---inner_nest - |---radar_1 - |---radar_2 - |---radar_m + |---nest_0/radar_1 + |---nest_0/radar_2 + |---nest_0/radar_n + |---nest_1/radar_1 + |---nest_1/radar_2 + |---nest_1/radar_m Each member of this tree is a DataTree itself. PyDDA will know if the DataTree contains data from a radar when the name of the node begins @@ -29,11 +28,11 @@ retrieval, allowing the user to vary the coefficients by grid level. Using :code:`pydda.retrieval.get_dd_wind_field_nested` will allow PyDDA -to perform the retrieval on the outer grids first. It will then -perform on the inner grid levels, using the outer level grid as both the -horizontal boundary conditions and initialization for the retrieval in the inner -nest. Finally, PyDDA will update the winds in the outer grid by nearest- -neighbor interpolation of the finer grid into the overlapping portion between +to perform the retrieval on the 0th grid first. It will then +perform on the subsequent grid levels, using the previous nest as both the +horizontal boundary conditions and initialization for the retrieval in the next +nest. Finally, PyDDA will update the winds in the first grid by nearest- +neighbor interpolation of the latter grid into the overlapping portion between the inner and outer grid level. PyDDA will then return the retrieved wind fields as the "u", "v", and "w" @@ -46,7 +45,7 @@ import pydda import matplotlib.pyplot as plt import warnings -from datatree import DataTree +from xarray import DataTree warnings.filterwarnings("ignore") @@ -84,18 +83,26 @@ ) """ +Enforce equal times for each grid. This is required for the DataTree structure since time is an +inherited dimension. +""" +test_coarse1["time"] = test_coarse0["time"] +test_fine0["time"] = test_coarse0["time"] +test_fine1["time"] = test_coarse1["time"] +""" + Provide the overlying grid structure as specified above. """ tree_dict = { - "/coarse/radar_ktlx": test_coarse0, - "/coarse/radar_kict": test_coarse1, - "/coarse/fine/radar_ktlx": test_fine0, - "/coarse/fine/radar_kict": test_fine1, + "/nest_0/radar_ktlx": test_coarse0, + "/nest_0/radar_kict": test_coarse1, + "/nest_1/radar_ktlx": test_fine0, + "/nest_1/radar_kict": test_fine1, } tree = DataTree.from_dict(tree_dict) -tree["/coarse/"].attrs = kwargs_dict -tree["/coarse/fine"].attrs = kwargs_dict +tree["/nest_0/"].attrs = kwargs_dict +tree["/nest_1/"].attrs = kwargs_dict """ Perform the retrieval @@ -109,7 +116,7 @@ fig, ax = plt.subplots(1, 2, figsize=(10, 5)) pydda.vis.plot_horiz_xsection_quiver( - grid_tree["coarse"], + grid_tree["nest_0"], ax=ax[0], level=5, cmap="ChaseSpectral", @@ -124,7 +131,7 @@ quiverkey_loc="bottom_right", ) pydda.vis.plot_horiz_xsection_quiver( - grid_tree["coarse/fine"], + grid_tree["nest_1"], ax=ax[1], level=5, cmap="ChaseSpectral", @@ -138,3 +145,5 @@ quiver_spacing_y_km=50.0, quiverkey_loc="bottom_right", ) + +plt.show() diff --git a/pydda/retrieval/nesting.py b/pydda/retrieval/nesting.py index 931f03be..d417bdaa 100644 --- a/pydda/retrieval/nesting.py +++ b/pydda/retrieval/nesting.py @@ -22,8 +22,6 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs): - The list of PyART grids for the given level of the grid * - kwargs - The list of key word arguments for input to the :py:func:`pydda.retrieval.get_dd_wind_field` function for the set of grids. - * - children - - The list of trees that are the children of this node. The function will output the same tree, with the list of output grids of each level output to the 'output_grids' member of the tree structure. If *kwargs* is set to None, then the input keyword arguments will be @@ -34,53 +32,89 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs): child_list = list(grid_tree.children.keys()) grid_list = [] rad_names = [] - for child in child_list: - if "radar" in child: - grid_list.append(grid_tree[child].to_dataset()) - rad_names.append(child) + # We are at the parent level, look for nest 0 + if "nest_0" in child_list: + for child in grid_tree["nest_0"].children.keys(): + if "radar_" in child: + grid_list.append(grid_tree["nest_0"][child].to_dataset()) + rad_names.append(child) + tree_attrs = grid_tree["nest_0"].attrs + in_parent = True + else: + tree_attrs = grid_tree.attrs + # We are in nest 1...n + for child in child_list: + if "radar_" in child: + grid_list.append(grid_tree[child].to_dataset()) + rad_names.append(child) + in_parent = False - if len(list(grid_tree.attrs.keys())) == 0 and len(grid_list) > 0: + if len(list(tree_attrs.keys())) == 0 and len(grid_list) > 0: output_grids, output_parameters = get_dd_wind_field(grid_list, **kwargs) elif len(grid_list) > 0: - my_kwargs = grid_tree.attrs + my_kwargs = tree_attrs output_grids, output_parameters = get_dd_wind_field(grid_list, **my_kwargs) output_parameters = output_parameters.__dict__ - grid_tree["weights"] = xr.DataArray( - output_parameters.pop("weights"), dims=("nradar", "z", "y", "x") - ) - grid_tree["bg_weights"] = xr.DataArray( - output_parameters.pop("bg_weights"), dims=("z", "y", "x") - ) - grid_tree["model_weights"] = xr.DataArray( - output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x") - ) - output_parameters.pop("u_model") - output_parameters.pop("v_model") - output_parameters.pop("w_model") - grid_tree["output_parameters"] = xr.DataArray([], attrs=output_parameters) - - grid_tree.__setitem__("u", output_grids[0]["u"]) - grid_tree.__setitem__("v", output_grids[0]["v"]) - grid_tree.__setitem__("w", output_grids[0]["w"]) - grid_tree["u"].attrs = output_grids[0]["u"].attrs - grid_tree["v"].attrs = output_grids[0]["v"].attrs - grid_tree["w"].attrs = output_grids[0]["w"].attrs + if in_parent is True: + grid_tree["nest_0"]["weights"] = xr.DataArray( + output_parameters.pop("weights"), dims=("nradars", "z", "y", "x") + ) + grid_tree["nest_0"]["bg_weights"] = xr.DataArray( + output_parameters.pop("bg_weights"), dims=("z", "y", "x") + ) + grid_tree["nest_0"]["model_weights"] = xr.DataArray( + output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x") + ) + output_parameters.pop("u_model") + output_parameters.pop("v_model") + output_parameters.pop("w_model") + grid_tree["nest_0"]["output_parameters"] = xr.DataArray( + [], attrs=output_parameters + ) + grid_tree["nest_0"].__setitem__("u", output_grids[0]["u"]) + grid_tree["nest_0"].__setitem__("v", output_grids[0]["v"]) + grid_tree["nest_0"].__setitem__("w", output_grids[0]["w"]) + grid_tree["nest_0"]["u"].attrs = output_grids[0]["u"].attrs + grid_tree["nest_0"]["v"].attrs = output_grids[0]["v"].attrs + grid_tree["nest_0"]["w"].attrs = output_grids[0]["w"].attrs + else: + grid_tree["weights"] = xr.DataArray( + output_parameters.pop("weights"), dims=("nradars", "z", "y", "x") + ) + grid_tree["bg_weights"] = xr.DataArray( + output_parameters.pop("bg_weights"), dims=("z", "y", "x") + ) + grid_tree["model_weights"] = xr.DataArray( + output_parameters.pop("model_weights"), dims=("nmodel", "z", "y", "x") + ) + output_parameters.pop("u_model") + output_parameters.pop("v_model") + output_parameters.pop("w_model") + grid_tree["output_parameters"] = xr.DataArray([], attrs=output_parameters) + grid_tree.__setitem__("u", output_grids[0]["u"]) + grid_tree.__setitem__("v", output_grids[0]["v"]) + grid_tree.__setitem__("w", output_grids[0]["w"]) + grid_tree["u"].attrs = output_grids[0]["u"].attrs + grid_tree["v"].attrs = output_grids[0]["v"].attrs + grid_tree["w"].attrs = output_grids[0]["w"].attrs if child_list == []: return grid_tree nests = [] for child in child_list: - if "radar_" not in child: + if "nest_" in child: nests.append(child) nests = sorted(nests) - for child in nests: + for i, child in enumerate(nests): + if i == 0: + continue # Only update child initalization if we are not in parent node if len(grid_list) > 0: - temp_src = grid_tree[rad_names[0]].to_dataset() - temp_src["u"] = grid_tree.ds["u"] - temp_src["v"] = grid_tree.ds["v"] - temp_src["w"] = grid_tree.ds["w"] + temp_src = grid_tree[f"nest_{i-1}"][rad_names[0]].to_dataset() + temp_src["u"] = grid_tree[f"nest_{i-1}"].ds["u"] + temp_src["v"] = grid_tree[f"nest_{i-1}"].ds["v"] + temp_src["w"] = grid_tree[f"nest_{i-1}"].ds["w"] input_grids = make_initialization_from_other_grid( temp_src, grid_tree.children[child][rad_names[0]].to_dataset() ) @@ -145,26 +179,31 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs): grid_tree.children[child]["w"].attrs = temp_tree.ds["w"].attrs # Update parent grids from children + for child in child_list: + if "nest_" in child: + nests.append(child) + nests = sorted(nests) + if len(rad_names) > 0: - for child in nests: - temp_src = grid_tree.children[child][rad_names[0]].to_dataset() - temp_src["u"] = grid_tree.children[child].ds["u"] - temp_src["v"] = grid_tree.children[child].ds["v"] - temp_src["w"] = grid_tree.children[child].ds["w"] - temp_dest = grid_tree[rad_names[0]].to_dataset() - temp_dest["u"] = grid_tree.ds["u"] - temp_dest["v"] = grid_tree.ds["v"] - temp_dest["w"] = grid_tree.ds["w"] + for i, child in enumerate(nests[:-1]): + temp_src = grid_tree.children[nests[i + 1]][rad_names[0]].to_dataset() + temp_src["u"] = grid_tree.children[nests[i + 1]].ds["u"] + temp_src["v"] = grid_tree.children[nests[i + 1]].ds["v"] + temp_src["w"] = grid_tree.children[nests[i + 1]].ds["w"] + temp_dest = grid_tree.children[nests[i]][rad_names[0]].to_dataset() + temp_dest["u"] = grid_tree.children[nests[i]].ds["u"] + temp_dest["v"] = grid_tree.children[nests[i]].ds["v"] + temp_dest["w"] = grid_tree.children[nests[i]].ds["w"] output_grids = make_initialization_from_other_grid( temp_src, temp_dest, ) - grid_tree.__setitem__("u", output_grids["u"]) - grid_tree.__setitem__("v", output_grids["v"]) - grid_tree.__setitem__("w", output_grids["w"]) - grid_tree["u"].attrs = output_grids["u"].attrs - grid_tree["v"].attrs = output_grids["v"].attrs - grid_tree["w"].attrs = output_grids["w"].attrs + grid_tree.children[nests[i]].__setitem__("u", output_grids["u"]) + grid_tree.children[nests[i]].__setitem__("v", output_grids["v"]) + grid_tree.children[nests[i]].__setitem__("w", output_grids["w"]) + grid_tree.children[nests[i]]["u"].attrs = output_grids["u"].attrs + grid_tree.children[nests[i]]["v"].attrs = output_grids["v"].attrs + grid_tree.children[nests[i]]["w"].attrs = output_grids["w"].attrs return grid_tree diff --git a/pydda/tests/test_initialization.py b/pydda/tests/test_initialization.py index 2970cbc6..6baf317d 100644 --- a/pydda/tests/test_initialization.py +++ b/pydda/tests/test_initialization.py @@ -64,7 +64,11 @@ def test_get_iem_data(): Grid = pydda.io.read_from_pyart_grid(Grid) station_obs = pydda.constraints.get_iem_obs(Grid) names = [x["site_id"] for x in station_obs] - assert names == ["P28", "WLD", "WDG", "SWO", "END"] + assert "P28" in names + assert "WLD" in names + assert "WDG" in names + assert "SWO" in names + assert "END" in names def test_hrrr_data(): diff --git a/pydda/tests/test_retrieval.py b/pydda/tests/test_retrieval.py index dcccaf5e..f537bb6c 100644 --- a/pydda/tests/test_retrieval.py +++ b/pydda/tests/test_retrieval.py @@ -289,21 +289,25 @@ def test_nested_retrieval(): engine="scipy", ) + test_coarse1["time"] = test_coarse0["time"] + test_fine0["time"] = test_coarse0["time"] + test_fine1["time"] = test_coarse1["time"] + tree_dict = { - "/coarse/radar_ktlx": test_coarse0, - "/coarse/radar_kict": test_coarse1, - "/coarse/fine/radar_ktlx": test_fine0, - "/coarse/fine/radar_kict": test_fine1, + "/nest_0/radar_ktlx": test_coarse0, + "/nest_0/radar_kict": test_coarse1, + "/nest_1/radar_ktlx": test_fine0, + "/nest_1/radar_kict": test_fine1, } tree = DataTree.from_dict(tree_dict) - tree["/coarse/"].attrs = kwargs_dict - tree["/coarse/fine"].attrs = kwargs_dict + tree["/nest_0/"].attrs = kwargs_dict + tree["/nest_1/"].attrs = kwargs_dict grid_tree = pydda.retrieval.get_dd_wind_field_nested(tree) fig, ax = plt.subplots(1, 2, figsize=(10, 5)) pydda.vis.plot_horiz_xsection_quiver( - grid_tree["coarse"], + grid_tree["nest_0"], ax=ax[0], level=5, cmap="ChaseSpectral", @@ -318,7 +322,7 @@ def test_nested_retrieval(): quiverkey_loc="bottom_right", ) pydda.vis.plot_horiz_xsection_quiver( - grid_tree["coarse/fine"], + grid_tree["nest_1"], ax=ax[1], level=5, cmap="ChaseSpectral",