Issue #1 #20

Merged
merged 5 commits into from
Oct 30, 2023
2 changes: 1 addition & 1 deletion .github/workflows/python-documentation.yml
@@ -20,7 +20,7 @@ jobs:
python-version: "3.9"
- name: Install dependencies
run: |
pip install sphinx sphinx_rtd_theme myst_parser
pip install sphinx myst_parser
- name: Sphinx build
run: |
sphinx-build docs docs/_build
9 changes: 8 additions & 1 deletion README.md
@@ -54,4 +54,11 @@ See **jaxparrow** [API documentation](https://meom-group.github.io/jaxparrow/api

### As an executable

***TBP***
**jaxparrow** is also available from the command line:
```shell
jaxparrow --conf_path conf.yml
```
The YAML configuration file `conf.yml` specifies where the input netCDF files are stored locally and how to retrieve variables and coordinates from them.
It also provides the path of the output netCDF file. Optionally, it can specify which cyclogeostrophic approach to apply, along with its hyperparameters.

An example configuration file detailing all the required and optional entries can be found [here](https://github.com/meom-group/jaxparrow/blob/main/docs/example-conf.yml).
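
As a rough illustration of the required entries described above, the sketch below checks an already-parsed configuration (for instance the dict returned by `yaml.safe_load`) for missing pieces. The helper `find_missing_entries` and the constants are hypothetical names introduced here; they are not part of jaxparrow.

```python
# Hypothetical helper, not part of jaxparrow: checks a parsed conf dict.
REQUIRED_VARIABLES = ("lon_ssh", "lat_ssh", "ssh", "lon_u", "lat_u", "lon_v", "lat_v")
REQUIRED_FIELDS = ("file_path", "var_name")  # "index" is optional


def find_missing_entries(conf: dict) -> list:
    """Return human-readable problems; an empty list means the conf looks complete."""
    problems = []
    for var in REQUIRED_VARIABLES:
        entry = conf.get(var)
        if entry is None:
            problems.append(f"missing entry: {var}")
            continue
        for field in REQUIRED_FIELDS:
            if field not in entry:
                problems.append(f"missing field: {var}.{field}")
    if "out_path" not in conf:
        problems.append("missing entry: out_path")
    return problems


# Deliberately incomplete example: most entries are absent, and ssh lacks var_name
example_conf = {
    "lon_ssh": {"file_path": "ssh.nc", "var_name": "nav_lon"},
    "ssh": {"file_path": "ssh.nc"},
}
for problem in find_missing_entries(example_conf):
    print(problem)
```

Running a check of this kind before opening any netCDF file would let the command line report every missing entry at once, rather than stopping at the first `KeyError`.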
95 changes: 95 additions & 0 deletions docs/example-conf.yml
@@ -0,0 +1,95 @@
# ----------------
# ----------------
# Required entries
# ----------------
# ----------------

# -----------------------
# Input Dataset variables
# -----------------------
# file_path and var_name are required, index is optional

# longitude grid of the SSH
lon_ssh:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_sossheig.nc"
var_name: "nav_lon"

# latitude grid of the SSH
lat_ssh:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_sossheig.nc"
var_name: "nav_lat"

# SSH grid
ssh:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_sossheig.nc"
var_name: "sossheig"
index: [0]

# longitude grid of the U component
lon_u:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_sozocrtx.nc"
var_name: "nav_lon"

# latitude grid of the U component
lat_u:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_sozocrtx.nc"
var_name: "nav_lat"

# longitude grid of the V component
lon_v:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_somecrty.nc"
var_name: "nav_lon"

# latitude grid of the V component
lat_v:
file_path: "notebooks/data/eNATL60MEDWEST-BLB002_y2009m07d01.1h_somecrty.nc"
var_name: "nav_lat"

# ------------------------
# Output Dataset full path
# ------------------------

out_path: "notebooks/data/out.nc"


# ----------------
# ----------------
# Optional entries
# ----------------
# ----------------

# -----------------------
# Input Dataset variables
# -----------------------
# file_path and var_name are required, index is optional
# mask arrays are expected to be boolean or numeric, with True (or 1) for valid cells and False (or 0) for invalid ones.

# mask of the SSH
mask_ssh:
file_path: "notebooks/data/mask_eNATL60MEDWEST_3.6.nc"
var_name: "tmask"
index: [0, 0] # here we select the first time and elevation elements (optional, depends on the data structure)

# mask of the U component
mask_u:
file_path: "notebooks/data/mask_eNATL60MEDWEST_3.6.nc"
var_name: "umask"
index: [0, 0]

# mask of the V component
mask_v:
file_path: "notebooks/data/mask_eNATL60MEDWEST_3.6.nc"
var_name: "vmask"
index: [0, 0]

# ---------------------------------------
# Arguments to the cyclogeostrophy method
# ---------------------------------------
# tune the cyclogeostrophic approach applied and its hyperparameters.
# refer to the documentation for a comprehensive list of the available optional arguments:
# https://meom-group.github.io/jaxparrow/jaxparrow.cyclogeostrophy.html
# if not provided, default values are used.

cyclogeostrophy:
method: "variational"
n_it: 100
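
The `index` entries in the configuration above are positional: in the loader added by this pull request, the i-th value is paired with the i-th dimension of the variable. A minimal sketch of that pairing follows. The dimension names are hypothetical, and the length check is an addition that the loader itself does not perform.

```python
def build_indexer(dims: tuple, index: list) -> dict:
    """Pair the i-th entry of `index` with the i-th dimension name."""
    if len(index) > len(dims):
        # Extra check added for this sketch only
        raise ValueError("more indices than dimensions")
    return {dims[i]: index[i] for i in range(len(index))}


# Hypothetical dimension names for a 4-D mask variable (time, depth, y, x)
mask_dims = ("t", "z", "y", "x")
print(build_indexer(mask_dims, [0, 0]))  # prints {'t': 0, 'z': 0}
```

With `index: [0, 0]`, only the two leading dimensions are reduced, leaving a 2-D `(y, x)` field. A variable whose leading dimension is not the one to be selected would need a different layout or an explicit dimension name, which this configuration format does not appear to support.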
5 changes: 3 additions & 2 deletions jaxparrow/__init__.py
@@ -1,5 +1,6 @@
from .version import __version__
from .__main__ import main
from .cyclogeostrophy import cyclogeostrophy
from .geostrophy import geostrophy
from .version import __version__

__all__ = ["__version__", "cyclogeostrophy", "geostrophy"]
__all__ = ["cyclogeostrophy", "geostrophy"]
150 changes: 134 additions & 16 deletions jaxparrow/__main__.py
@@ -1,30 +1,148 @@
import argparse
from typing import Union
import yaml

import numpy as np
import numpy.ma as ma
import xarray as xr

def do_main(dir_data, name_mask, name_ssh, name_u, name_v, write_dir):
return
from .tools import compute_coriolis_factor, compute_spatial_step
from .cyclogeostrophy import cyclogeostrophy
from .geostrophy import geostrophy


def _read_data(conf_path: str) -> list:
with open(conf_path) as f:
conf = yaml.safe_load(f) # parse conf file

values = []
seen_ds = {}

# variable descriptions to be found in the conf file
variables = ["mask_ssh", "mask_u", "mask_v", "lon_ssh", "lat_ssh", "ssh", "lon_u", "lat_u", "lon_v", "lat_v"]
for var in variables:
try:
conf_entry = conf[var]

# each variable refers to a netCDF file (xarray Dataset)
if conf_entry["file_path"] not in seen_ds:
seen_ds[conf_entry["file_path"]] = xr.open_dataset(conf_entry["file_path"])
ds = seen_ds[conf_entry["file_path"]]
# and to a xarray Dataset variable
val = ds[conf_entry["var_name"]]
# optionally, the user can index the variable (e.g. to extract the observation at a specific time)
if "index" in conf_entry:
idx = conf_entry["index"]
val = val[{val.dims[i]: idx[i] for i in range(len(idx))}]
except KeyError as e:
if "mask" in var: # mask variables are optional
val = None
else:
raise e

values.append(val)

# in addition, the user can provide arguments passed to the cyclogeostrophic method
values.append(conf.get("cyclogeostrophy", {}))
# and the user must provide the full path (including file name and extension) of the output file
values.append(conf["out_path"])

return values
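
One possible concern with `_read_data`, as far as can be told from the diff: the `try` block wraps the whole lookup, so for a mask variable any `KeyError` (a missing `file_path` or `var_name` field, or a variable name absent from the dataset) is treated the same as "no mask provided". The sketch below shows one way to separate the two cases. `lookup_entry` is a hypothetical helper, not code from this pull request.

```python
def lookup_entry(conf: dict, var: str):
    """Return the conf entry for `var`, or None when an optional mask entry is absent."""
    if var not in conf:
        if var.startswith("mask_"):
            return None  # mask entries are optional
        raise KeyError(f"required entry '{var}' is missing from the configuration")
    entry = conf[var]
    # Once an entry is present, a missing field is an error, even for masks
    for field in ("file_path", "var_name"):
        if field not in entry:
            raise KeyError(f"entry '{var}' has no '{field}' field")
    return entry


conf = {"ssh": {"file_path": "ssh.nc", "var_name": "sossheig"}}
print(lookup_entry(conf, "mask_ssh"))  # prints None
print(lookup_entry(conf, "ssh")["var_name"])  # prints sossheig
```

With this split, a misspelled field in a mask entry surfaces as an error instead of silently disabling the mask.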


def _apply_mask(mask_ssh: Union[np.ndarray, None], mask_u: Union[np.ndarray, None], mask_v: Union[np.ndarray, None],
ssh: np.ndarray, lon_ssh: np.ndarray, lat_ssh: np.ndarray,
lon_u: np.ndarray, lat_u: np.ndarray, lon_v: np.ndarray, lat_v: np.ndarray) -> tuple:
def __do_apply(arr: np.ndarray, mask: Union[np.ndarray, None]) -> np.ndarray:
if mask is None:
mask = np.ones_like(arr)
mask = 1 - mask # don't forget to invert the masks (for ma.MaskedArray, True means invalid)
return ma.masked_array(arr, mask)

ssh = __do_apply(ssh, mask_ssh)
lon_ssh = __do_apply(lon_ssh, mask_ssh)
lat_ssh = __do_apply(lat_ssh, mask_ssh)

lon_u = __do_apply(lon_u, mask_u)
lat_u = __do_apply(lat_u, mask_u)

lon_v = __do_apply(lon_v, mask_v)
lat_v = __do_apply(lat_v, mask_v)

return ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v
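
The inversion in `__do_apply` is needed because the configuration uses 1 (or True) for valid cells, whereas `numpy.ma` uses True for masked, i.e. invalid, cells. The sketch below illustrates the same convention using an explicit boolean negation instead of `1 - mask`. `to_masked` is a hypothetical name and the arrays are made-up values.

```python
import numpy as np
import numpy.ma as ma


def to_masked(arr, valid_mask=None):
    """Wrap `arr` in a MaskedArray, given a mask where 1/True marks VALID cells."""
    arr = np.asarray(arr)
    if valid_mask is None:
        # No mask supplied: every cell is considered valid
        invalid = np.zeros(arr.shape, dtype=bool)
    else:
        # numpy.ma expects True for INVALID cells, hence the negation
        invalid = ~np.asarray(valid_mask).astype(bool)
    return ma.masked_array(arr, mask=invalid)


ssh = np.array([[0.1, 0.2], [0.3, 0.4]])  # made-up values
valid = np.array([[1, 1], [0, 1]])        # 0 marks a land (invalid) cell
masked = to_masked(ssh, valid)
print(masked.count())  # number of valid cells: 3
```

Casting to `bool` before negating keeps the behavior identical for boolean and 0/1 numeric masks, which are the two input types the example configuration says are accepted.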


def _compute_spatial_step(lon_ssh: ma.MaskedArray, lat_ssh: ma.MaskedArray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray,
lon_v: ma.MaskedArray, lat_v: ma.MaskedArray) -> tuple:
dx_ssh, dy_ssh = compute_spatial_step(lat_ssh, lon_ssh)
dx_u, dy_u = compute_spatial_step(lat_u, lon_u)
dx_v, dy_v = compute_spatial_step(lat_v, lon_v)

return dx_ssh, dy_ssh, dx_u, dy_u, dx_v, dy_v


def _compute_coriolis_factor(lat_u: ma.MaskedArray, lat_v: ma.MaskedArray) -> tuple:
coriolis_factor_u = compute_coriolis_factor(lat_u)
coriolis_factor_v = compute_coriolis_factor(lat_v)

return coriolis_factor_u, coriolis_factor_v
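
For readers unfamiliar with the quantity computed here: the Coriolis parameter is conventionally defined as f = 2 Ω sin(φ), with Ω the Earth's rotation rate and φ the latitude. The sketch below evaluates that textbook formula. It assumes, without confirmation from this diff, that `compute_coriolis_factor` follows the same definition. The function name and constant are introduced for illustration only.

```python
import math

# Earth's angular speed in rad/s (standard value; assumed, not taken from jaxparrow)
EARTH_ANGULAR_SPEED = 7.292115e-5


def coriolis_parameter(lat_deg: float) -> float:
    """Textbook Coriolis parameter f = 2 * Omega * sin(latitude), in 1/s."""
    return 2.0 * EARTH_ANGULAR_SPEED * math.sin(math.radians(lat_deg))


for lat in (0.0, 30.0, 45.0):
    print(lat, coriolis_parameter(lat))
```

Since f vanishes at the equator, any balance that divides by f becomes ill-conditioned there, so input grids close to the equator deserve extra care.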


def _to_dataset(u_geos: np.ndarray, v_geos: np.ndarray, u_cyclo: np.ndarray, v_cyclo: np.ndarray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray, lon_v: ma.MaskedArray, lat_v: ma.MaskedArray) \
-> xr.Dataset:
ds = xr.Dataset({
"u_geos": (["y", "x"], u_geos),
"v_geos": (["y", "x"], v_geos),
"u_cyclo": (["y", "x"], u_cyclo),
"v_cyclo": (["y", "x"], v_cyclo)
}, coords={
"u_lon": (["y", "x"], lon_u),
"u_lat": (["y", "x"], lat_u),
"v_lon": (["y", "x"], lon_v),
"v_lat": (["y", "x"], lat_v),
})
return ds


def _write_data(u_geos: np.ndarray, v_geos: np.ndarray, u_cyclo: np.ndarray, v_cyclo: np.ndarray,
lon_u: ma.MaskedArray, lat_u: ma.MaskedArray, lon_v: ma.MaskedArray, lat_v: ma.MaskedArray,
out_path: str):
ds = _to_dataset(u_geos, v_geos, u_cyclo, v_cyclo, lon_u, lat_u, lon_v, lat_v)
ds.to_netcdf(out_path)


def _main(conf_path: str):
mask_ssh, mask_u, mask_v, ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v, cyclo_kwargs, out_path = (
_read_data(conf_path))

ssh, lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v = _apply_mask(mask_ssh, mask_u, mask_v, ssh, lon_ssh, lat_ssh,
lon_u, lat_u, lon_v, lat_v)

dx_ssh, dy_ssh, dx_u, dy_u, dx_v, dy_v = _compute_spatial_step(lon_ssh, lat_ssh, lon_u, lat_u, lon_v, lat_v)

coriolis_factor_u, coriolis_factor_v = _compute_coriolis_factor(lat_u, lat_v)

u_geos, v_geos = geostrophy(ssh, dx_ssh, dy_ssh, coriolis_factor_u, coriolis_factor_v)
u_cyclo, v_cyclo = cyclogeostrophy(u_geos, v_geos, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v,
**cyclo_kwargs)

_write_data(u_geos, v_geos, u_cyclo, v_cyclo, lon_u, lat_u, lon_v, lat_v, out_path)
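
The optional `cyclogeostrophy` section of the configuration reaches the solver through `**cyclo_kwargs`, so its keys must match the keyword arguments of `cyclogeostrophy`. The sketch below illustrates the pattern with a hypothetical stand-in function; the real function takes more arguments and its defaults are not reproduced here.

```python
def cyclogeostrophy_stub(u, v, method="variational", n_it=None):
    """Hypothetical stand-in that just reports which options it received."""
    return {"method": method, "n_it": n_it}


conf_with_section = {"cyclogeostrophy": {"method": "penven", "n_it": 100}}
conf_without_section = {"out_path": "out.nc"}

for conf in (conf_with_section, conf_without_section):
    # Missing section -> empty dict -> the function's own defaults apply
    cyclo_kwargs = conf.get("cyclogeostrophy", {})
    print(cyclogeostrophy_stub(None, None, **cyclo_kwargs))
```

One consequence of this design is that a misspelled key in the YAML section raises a `TypeError` at call time, which is a reasonably loud failure mode for a configuration typo.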


def main():
parser = argparse.ArgumentParser(prog="Cyclogeostrophic balance",
description="Computes the inversion of the cyclogeostrophic balance using a "
"variational formulation with gradient descent minimization approach.")

parser.add_argument("--dir_data", default="notebooks/data", type=str, help="data directory")
parser.add_argument("--name_mask", default="mask_eNATL60MEDWEST_3.6.nc", type=str,
help="mask file name")
parser.add_argument("--name_ssh", default="eNATL60MEDWEST-BLB002_y2009m07d01.1h_sossheig.nc",
type=str, help="SSH file name")
parser.add_argument("--name_u", default="eNATL60MEDWEST-BLB002_y2009m07d01.1h_sozocrtx.nc",
type=str, help="u velocity file name")
parser.add_argument("--name_v", default="eNATL60MEDWEST-BLB002_y2009m07d01.1h_somecrty.nc",
type=str, help="v velocity file name")
parser.add_argument("--write_dir", default="notebooks/data", type=str,
help="cyclogeostrophic outputs directory")
"variational or iterative approach.",
epilog="For an example yaml configuration file, see the documentation: "
"https://meom-group.github.io/jaxparrow/description.html#as-an-executable")

parser.add_argument("--conf_path", type=str, help="yaml configuration file path", required=True)

args = parser.parse_args()

do_main(args.dir_data, args.name_mask, args.name_ssh, args.name_u, args.name_v, args.write_dir)
_main(args.conf_path)
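
The new command-line interface reduces to a single required option. The sketch below mirrors it in isolation, so the parsing behavior can be checked without any data files. `build_parser` is a hypothetical helper, and `prog="jaxparrow"` is an assumption made here so that the usage line shows the command name; the pull request sets a different `prog` value.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the single required --conf_path option."""
    parser = argparse.ArgumentParser(
        prog="jaxparrow",  # assumed for this sketch
        description="Computes the inversion of the cyclogeostrophic balance.",
    )
    parser.add_argument("--conf_path", type=str, required=True,
                        help="yaml configuration file path")
    return parser


# Passing an explicit argument list avoids reading sys.argv
args = build_parser().parse_args(["--conf_path", "conf.yml"])
print(args.conf_path)  # prints conf.yml
```

Because the option is marked `required=True`, calling the program without it makes `argparse` print the usage message and exit with a non-zero status.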


if __name__ == "__main__":
14 changes: 7 additions & 7 deletions jaxparrow/cyclogeostrophy.py
@@ -12,18 +12,18 @@

from .tools import tools

#: Default maximum number of iterations for the iterative approach
#: Default maximum number of iterations for the Penven and Ioannou approaches
N_IT_IT = 100
#: Default residual tolerance of the iterative approach
#: Default residual tolerance of the Penven and Ioannou approaches
RES_EPS_IT = 0.0001
#: Default residual value used during the first iteration
#: Default residual value used during the first iteration of the Penven and Ioannou approaches
RES_INIT_IT = "same"
#: Default size of the grid points used to compute the residual in Ioannou's iterative approach
#: Default size of the grid points used to compute the residual in Ioannou's approach
RES_FILTER_SIZE_IT = 3

#: Default maximum number of iterations for the variational approach
#: Default maximum number of iterations for our variational approach
N_IT_VAR = 2000
#: Default learning rate for the gradient descent of the variational approach
#: Default learning rate for the gradient descent of our variational approach
LR_VAR = 0.005

__all__ = ["cyclogeostrophy", "LR_VAR", "N_IT_IT", "N_IT_VAR", "RES_EPS_IT", "RES_INIT_IT", "RES_FILTER_SIZE_IT"]
@@ -60,7 +60,7 @@ def cyclogeostrophy(u_geos: Union[np.ndarray, np.ma.MaskedArray], v_geos: Union[
:type coriolis_factor_u: Union[np.ndarray, np.ma.MaskedArray]
:param coriolis_factor_v: V Coriolis factor
:type coriolis_factor_v: Union[np.ndarray, np.ma.MaskedArray]
:param method: numerical method to use, defaults to "variational"
:param method: estimation method to use, defaults to "variational"
:type method: Literal["variational", "penven", "ioannou"], optional
:param n_it: maximum number of iterations, defaults to N_IT_IT
:type n_it: int, optional
3 changes: 0 additions & 3 deletions jaxparrow/grid/__init__.py

This file was deleted.
