Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve the jit graph operations #61

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 21 additions & 37 deletions jaxparrow/cyclogeostrophy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import partial
import numbers
from typing import Literal, Union

from jax import jit, lax, value_and_grad
Expand Down Expand Up @@ -41,7 +40,6 @@ def cyclogeostrophy(
optim: Union[optax.GradientTransformation, str] = "sgd",
optim_kwargs: dict = None,
res_eps: float = RES_EPS_IT,
res_init: Union[float, Literal["same"]] = RES_INIT_IT,
use_res_filter: bool = False,
res_filter_size: int = RES_FILTER_SIZE_IT,
return_geos: bool = False,
Expand Down Expand Up @@ -90,12 +88,6 @@ def cyclogeostrophy(
When residuals are smaller, the iterative approach considers local convergence to cyclogeostrophy.

Defaults to ``RES_EPS_IT``
res_init : Union[float | Literal["same"]], optional
Residual initial value of the iterative approach.
When residuals are larger at the first iteration,
the iterative approach considers local divergence to cyclogeostrophy.

If equals to `same` (default) absolute values of the geostrophic velocities are used
use_res_filter : bool, optional
Use of a convolution filter for the iterative approach when computing the residuals [3]_ or not [2]_.

Expand Down Expand Up @@ -159,11 +151,22 @@ def cyclogeostrophy(
coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, mask)

if method == "variational":
if n_it is None:
n_it = N_IT_VAR
if isinstance(optim, str):
if optim_kwargs is None:
optim_kwargs = {"learning_rate": LR_VAR}
optim = getattr(optax, optim)(**optim_kwargs)
elif not isinstance(optim, optax.GradientTransformation):
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such "
"an optimizer.")
res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
n_it, optim, optim_kwargs, return_losses)
n_it, optim, return_losses)
elif method == "iterative":
if n_it is None:
n_it = N_IT_IT
res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask,
n_it, res_eps, res_init, use_res_filter, res_filter_size, return_losses)
n_it, res_eps, use_res_filter, res_filter_size, return_losses)
else:
raise ValueError("method should be one of [\"variational\", \"iterative\"]")

Expand Down Expand Up @@ -217,6 +220,7 @@ def _it_step(

# compute dist to u_cyclo and v_cyclo
res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo)
res_np1 = sanitize.sanitize_data(res_np1, 0., mask)
res_np1 = lax.cond(
use_res_filter, # apply filter
lambda operands: jsp.signal.convolve(operands[0], operands[1], mode="same", method="fft") / operands[2],
Expand Down Expand Up @@ -248,7 +252,7 @@ def _it_step(
return u_cyclo, v_cyclo, mask_it, res_n, losses, i


@partial(jit, static_argnames=("n_it", "res_init", "res_filter_size"))
@partial(jit, static_argnames=("n_it", "res_filter_size"))
def _iterative(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -259,22 +263,12 @@ def _iterative(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: Union[int, None],
n_it: int,
res_eps: float,
res_init: Union[float, str],
use_res_filter: bool,
res_filter_size: int,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if n_it is None:
n_it = N_IT_IT
if res_init == "same":
res_n = jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v))
elif isinstance(res_init, numbers.Number):
res_n = res_init * jnp.ones_like(u_geos_u)
else:
raise ValueError("res_init should be equal to \"same\" or be a number.")

# used if applying a filter when computing stopping criteria
res_filter = jnp.ones((res_filter_size, res_filter_size))
res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft")
Expand All @@ -294,7 +288,8 @@ def step_fn(pytree):
u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa
lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1),
step_fn,
(u_geos_u, v_geos_v, mask.astype(int), res_n, jnp.ones(n_it) * jnp.nan, 0)
(u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)),
jnp.ones(n_it) * jnp.nan, 0)
)

return u_cyclo, v_cyclo, losses
Expand Down Expand Up @@ -377,7 +372,7 @@ def step_fn(pytree):
return u_cyclo_u, v_cyclo_v, losses


@partial(jit, static_argnames=("n_it", "optim", "optim_kwargs"))
@partial(jit, static_argnames=("n_it", "optim"))
def _variational(
u_geos_u: Float[Array, "lat lon"],
v_geos_v: Float[Array, "lat lon"],
Expand All @@ -388,21 +383,10 @@ def _variational(
coriolis_factor_u: Float[Array, "lat lon"],
coriolis_factor_v: Float[Array, "lat lon"],
mask: Float[Array, "lat lon"],
n_it: Union[int, None],
optim: Union[optax.GradientTransformation, str],
optim_kwargs: Union[dict, None],
n_it: int,
optim: optax.GradientTransformation,
return_losses: bool
) -> [Float[Array, "lat lon"], ...]:
if n_it is None:
n_it = N_IT_VAR
if isinstance(optim, str):
if optim_kwargs is None:
optim_kwargs = {"learning_rate": LR_VAR}
optim = getattr(optax, optim)(**optim_kwargs)
elif not isinstance(optim, optax.GradientTransformation):
raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such an "
"optimizer.")

# define loss partial: freeze constant over iterations
loss_fn = partial(
_var_loss_fn,
Expand Down
11 changes: 7 additions & 4 deletions jaxparrow/tools/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .geometry import compute_spatial_step, compute_coriolis_factor
from .operators import derivative, interpolation
from .sanitize import sanitize_data
from .sanitize import init_mask, sanitize_data


def advection(
Expand Down Expand Up @@ -175,6 +175,9 @@ def normalized_relative_vorticity(
The normalised relative vorticity,
on the F grid (if ``interpolate=False``), or the T grid (if ``interpolate=True``)
"""
# Make sure the mask is initialized
mask = init_mask(u, mask)

# Compute spatial step and Coriolis factor
_, dy_u = compute_spatial_step(lat_u, lon_u)
dx_v, _ = compute_spatial_step(lat_v, lon_v)
Expand Down Expand Up @@ -202,13 +205,13 @@ def normalized_relative_vorticity(
return w


def eddy_kinetic_energy(
def kinetic_energy(
u: Float[Array, "lat lon"],
v: Float[Array, "lat lon"],
interpolate: bool = True
) -> Float[Array, "lat lon"]:
"""
Computes the Eddy Kinetic Energy (EKE) of a velocity field,
Computes the Kinetic Energy (KE) of a velocity field,
possibly on a C-grid (following NEMO convention [1]_) if ``interpolate=True``.

Parameters
Expand All @@ -227,7 +230,7 @@ def eddy_kinetic_energy(
Returns
-------
eke : Float[Array, "lat lon"]
The Eddy Kinetic Energy on the T grid
The Kinetic Energy on the T grid
"""
if interpolate:
# interpolate to the T point
Expand Down
6 changes: 3 additions & 3 deletions jaxparrow/tools/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def derivative(
field : Float[Array, "lat lon"]
Interpolated field
"""
def do_derivate(field_b, field_f, pad_left):
def do_differentiate(field_b, field_f, pad_left):
field_b, field_f = handle_land_boundary(field_b, field_f, pad_left)
return field_f - field_b

def axis0(_field, pad_left):
field_b, field_f = _field[:-1, :], _field[1:, :]
midpoint_values = do_derivate(field_b, field_f, pad_left)
midpoint_values = do_differentiate(field_b, field_f, pad_left)

_field = lax.cond(
pad_left,
Expand All @@ -124,7 +124,7 @@ def axis0(_field, pad_left):

def axis1(_field, pad_left):
field_b, field_f = _field[:, :-1], _field[:, 1:]
midpoint_values = do_derivate(field_b, field_f, pad_left)
midpoint_values = do_differentiate(field_b, field_f, pad_left)

_field = lax.cond(
pad_left,
Expand Down
9 changes: 6 additions & 3 deletions jaxparrow/tools/sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ def handle_land_boundary(
Replaces the non-finite values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise.

It allows to introduce less non-finite values when applying grid operators.
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field.
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field (along one of the axes).

Parameters
----------
field1 : Float[Array, "lat lon"]
A field
field2 : Float[Array, "lat lon"]
Another field
pad_left : bool
If `True`, apply padding in the `left` direction (i.e. `West` or `South`) ;
if `False`, apply padding in the `right` direction (i.e. `East` or `North`).

Returns
-------
Expand All @@ -102,8 +105,8 @@ def sanitize_grid_np(
Sanitizes (unstructured) grids by interpolated and extrapolated `nan` or masked values to avoid spurious
(`0`, `nan`, `inf`) spatial steps and Coriolis factors.

Helper function written using ``numpy`` and ``scipy``, and as such not used internally,
because incompatible with ``jax.vmap``.
Helper function written using pure ``numpy`` and ``scipy``, and as such not used internally,
because incompatible with ``jax.vmap`` and likes.
Should be used before calling ``jaxparrow.geostrophy`` or ``jaxparrow.cyclogeostrophy``
in case of suspicious latitudes or longitudes T grids.

Expand Down
Loading
Loading