diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index dca9abea..c3757da0 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11"] os: [macOS, ubuntu] inlcude: - os: macos-latest diff --git a/REQUIREMENTS.txt b/REQUIREMENTS.txt index cb8a06ce..5902dc2f 100644 --- a/REQUIREMENTS.txt +++ b/REQUIREMENTS.txt @@ -9,4 +9,4 @@ pooch cmweather cdsapi xarray -datatree +xarray-datatree diff --git a/continuous_integration/environment-actions.yml b/continuous_integration/environment-actions.yml index dd282668..f8e6c5d9 100644 --- a/continuous_integration/environment-actions.yml +++ b/continuous_integration/environment-actions.yml @@ -23,3 +23,4 @@ dependencies: - jaxopt - tensorflow>=2.6 - tensorflow-probability + - xarray-datatree diff --git a/doc/environment_docs.yml b/doc/environment_docs.yml index 11b63090..f0df2aab 100644 --- a/doc/environment_docs.yml +++ b/doc/environment_docs.yml @@ -27,3 +27,4 @@ dependencies: - sphinx-gallery - sphinx-copybutton - sphinx-design + - xarray-datatree diff --git a/doc/source/contributors_guide/index.rst b/doc/source/contributors_guide/index.rst index 91dbe4c4..4581e833 100644 --- a/doc/source/contributors_guide/index.rst +++ b/doc/source/contributors_guide/index.rst @@ -48,6 +48,7 @@ Examples of unacceptable behavior by participants include: advances Trolling, insulting/derogatory comments, and personal or political attacks + Public or private harassment Publishing others' private information, such as a physical or electronic diff --git a/examples/README.txt b/examples/README.txt index ecfe8377..dc4bc490 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,8 +1,6 @@ PyDDA Example Gallery ==================== -Different examples are given on how to retrieve winds using HRRR and radar data. - -Example grid data files for Hurricane Florence are available at: - -https://drive.google.com/drive/folders/1pcQxWRJV78xuJePTZnlXPPpMe1qut0ie +In this section, we show different examples on: + * How to use HRRR to initalize your wind retrieval + * How to adjust the variational retrieval parameters diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index 5dde4c90..aa54a033 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -389,9 +389,8 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): ), jnp.abs(z - the_point["z"]) < roi, ) - J += jnp.sum( - ((u[the_box] - the_point["u"]) ** 2 + (v[the_box] - the_point["v"]) ** 2) - ) + the_box = jnp.where(the_box, 1.0, 0.0) + J += jnp.sum(((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2) * the_box) return J * Cp diff --git a/pydda/cost_functions/cost_functions.py b/pydda/cost_functions/cost_functions.py index a74230b9..b537bbdc 100644 --- a/pydda/cost_functions/cost_functions.py +++ b/pydda/cost_functions/cost_functions.py @@ -9,9 +9,6 @@ TENSORFLOW_AVAILABLE = False try: - from jax.config import config - - config.update("jax_enable_x64", True) import jax.numpy as jnp JAX_AVAILABLE = True @@ -858,7 +855,6 @@ def grad_jax(winds, parameters): parameters.point_list, Cp=parameters.Cpoint, roi=parameters.roi, - upper_bc=parameters.upper_bc, ) return grad diff --git a/pydda/io/__init__.py b/pydda/io/__init__.py index fb87d127..17467697 100644 --- a/pydda/io/__init__.py +++ b/pydda/io/__init__.py @@ -12,6 +12,7 @@ read_grid read_from_pyart_grid + read_hpl """ from .read_grid import read_grid, read_from_pyart_grid diff --git a/pydda/io/read_grid.py b/pydda/io/read_grid.py index ed2b37ae..d4378c89 100644 --- a/pydda/io/read_grid.py +++ b/pydda/io/read_grid.py @@ -1,5 +1,4 @@ import xarray as xr -import xradar as xd import numpy as np from glob import glob diff --git a/pydda/retrieval/nesting.py b/pydda/retrieval/nesting.py index 23874e3d..13a24e06 100644 --- a/pydda/retrieval/nesting.py +++ b/pydda/retrieval/nesting.py @@ -11,14 +11,23 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs): """ Does a wind retrieval over nested grids. The nested grids are created using PyART's :func:`pyart.map.grid_from_radars` function and then placed into a tree structure using - dictionaries. Each node of the tree has three parameters: - 'input_grids': The list of PyART grids for the given level of the grid - 'kwargs': The list of key word arguments for input to the get_dd_wind_field function for the set of grids. - If this is None, then the default keyword arguments are carried from the keyword arguments of this function. - 'children': The list of trees that are the children of this node. + :func:`dataTree`s. Each node of the tree has three parameters: + .. list-table:: Title + :widths: 25 100 + :header-rows: 1 + + * - Dictionary key + - Description + * - input_grids + - The list of PyART grids for the given level of the grid + * - kwargs + - The list of key word arguments for input to the :py:func:`pydda.retrieval.get_dd_wind_field` function for the set of grids. + * - children + - The list of trees that are the children of this node. The function will output the same tree, with the list of output grids of each level output to the 'output_grids' - member of the tree structure. + member of the tree structure. If *kwargs* is set to None, then the input keyword arguments will be + used throughout the retrieval. """ # Look for radars in current level diff --git a/pydda/retrieval/wind_retrieve.py b/pydda/retrieval/wind_retrieve.py index f88cf108..3a755208 100644 --- a/pydda/retrieval/wind_retrieve.py +++ b/pydda/retrieval/wind_retrieve.py @@ -1326,7 +1326,7 @@ def get_dd_wind_field( Using Tensorflow or Jax expands PyDDA's capabiability to take advantage of GPU-based systems. In addition, these two implementations use automatic differentation to calculate the gradient of the cost function in order to optimize the gradient calculation. - TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-basedengine. + TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-based engine. The latest version of Jax is required for the Jax-based engine. points: None or list of dicts Point observations as returned by :func:`pydda.constraints.get_iem_obs`. Set @@ -1413,9 +1413,9 @@ def get_dd_wind_field( The list of fields in the first grid in Grids that contain the custom data interpolated to the Grid's grid specification. Helper functions to create such gridded fields for HRRR and NetCDF WRF data exist - in ::pydda.constraints::. PyDDA will look for fields named U_(model - field name), V_(model field name), and W_(model field name). For - example, if you have U_hrrr, V_hrrr, and W_hrrr, then specify ["hrrr"] + in :py:func:`pydda.constraints`. PyDDA will look for fields named *U_(model + field name)*, *V_(model field name)*, and *W_(model field name)*. For + example, if you have *U_hrrr*, *V_hrrr*, and *W_hrrr*, then specify *["hrrr"]* into model_fields. output_cost_functions: bool Set to True to output the value of each cost function every @@ -1429,9 +1429,9 @@ def get_dd_wind_field( wind_tol: float Stop iterations after maximum change in winds is less than this value. tolerance: float - Tolerance for L2 norm of gradient before stopping. + Tolerance for :math:`L_{2}` norm of gradient before stopping. max_wind_magnitude: float - Constrain the optimization to have :math:`|u|, :math:`|w|`, and :math:`|w| < x` m/s. + Constrain the optimization to have :math:`|u|`, :math:`|w|`, and :math:`|w| < x` m/s. Returns ======= diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 2b96537b..ec490984 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -501,8 +501,6 @@ def test_vert_vorticity_tf(): def test_point_cost(): u = 1 * np.ones((10, 10, 10)) v = 1 * np.ones((10, 10, 10)) - 0 * np.ones((10, 10, 10)) - x = np.linspace(-10, 10, 10) y = np.linspace(-10, 10, 10) z = np.linspace(-10, 10, 10) @@ -556,8 +554,6 @@ def test_point_cost(): def test_point_cost_jax(): u = 1 * np.ones((10, 10, 10)) v = 1 * np.ones((10, 10, 10)) - 0 * np.ones((10, 10, 10)) - x = np.linspace(-10, 10, 10) y = np.linspace(-10, 10, 10) z = np.linspace(-10, 10, 10)