From e9055524bc80147aee144b60302ed771ec60056d Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Tue, 16 Apr 2024 08:56:18 +0200 Subject: [PATCH 1/2] improve the jit graph operations --- jaxparrow/cyclogeostrophy.py | 58 +++++++++++++---------------------- jaxparrow/tools/kinematics.py | 11 ++++--- jaxparrow/tools/operators.py | 6 ++-- jaxparrow/tools/sanitize.py | 9 ++++-- tests/test_velocities.py | 11 ++++--- 5 files changed, 44 insertions(+), 51 deletions(-) diff --git a/jaxparrow/cyclogeostrophy.py b/jaxparrow/cyclogeostrophy.py index 9b8ec77..3eddb53 100644 --- a/jaxparrow/cyclogeostrophy.py +++ b/jaxparrow/cyclogeostrophy.py @@ -1,6 +1,5 @@ from collections.abc import Callable from functools import partial -import numbers from typing import Literal, Union from jax import jit, lax, value_and_grad @@ -41,7 +40,6 @@ def cyclogeostrophy( optim: Union[optax.GradientTransformation, str] = "sgd", optim_kwargs: dict = None, res_eps: float = RES_EPS_IT, - res_init: Union[float, Literal["same"]] = RES_INIT_IT, use_res_filter: bool = False, res_filter_size: int = RES_FILTER_SIZE_IT, return_geos: bool = False, @@ -90,12 +88,6 @@ def cyclogeostrophy( When residuals are smaller, the iterative approach considers local convergence to cyclogeostrophy. Defaults to ``RES_EPS_IT`` - res_init : Union[float | Literal["same"]], optional - Residual initial value of the iterative approach. - When residuals are larger at the first iteration, - the iterative approach considers local divergence to cyclogeostrophy. - - If equals to `same` (default) absolute values of the geostrophic velocities are used use_res_filter : bool, optional Use of a convolution filter for the iterative approach when computing the residuals [3]_ or not [2]_. @@ -159,11 +151,22 @@ def cyclogeostrophy( coriolis_factor_v = sanitize.sanitize_data(coriolis_factor_v, jnp.nan, mask) if method == "variational": + if n_it is None: + n_it = N_IT_VAR + if isinstance(optim, str): + if optim_kwargs is None: + optim_kwargs = {"learning_rate": LR_VAR} + optim = getattr(optax, optim)(**optim_kwargs) + elif not isinstance(optim, optax.GradientTransformation): + raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such " + "an optimizer.") res = _variational(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask, - n_it, optim, optim_kwargs, return_losses) + n_it, optim, return_losses) elif method == "iterative": + if n_it is None: + n_it = N_IT_IT res = _iterative(u_geos_u, v_geos_v, dx_u, dx_v, dy_u, dy_v, coriolis_factor_u, coriolis_factor_v, mask, - n_it, res_eps, res_init, use_res_filter, res_filter_size, return_losses) + n_it, res_eps, use_res_filter, res_filter_size, return_losses) else: raise ValueError("method should be one of [\"variational\", \"iterative\"]") @@ -217,6 +220,7 @@ def _it_step( # compute dist to u_cyclo and v_cyclo res_np1 = jnp.abs(u_np1 - u_cyclo) + jnp.abs(v_np1 - v_cyclo) + res_np1 = sanitize.sanitize_data(res_np1, 0., mask) res_np1 = lax.cond( use_res_filter, # apply filter lambda operands: jsp.signal.convolve(operands[0], operands[1], mode="same", method="fft") / operands[2], @@ -248,7 +252,7 @@ def _it_step( return u_cyclo, v_cyclo, mask_it, res_n, losses, i -@partial(jit, static_argnames=("n_it", "res_init", "res_filter_size")) +@partial(jit, static_argnames=("n_it", "res_filter_size")) def _iterative( u_geos_u: Float[Array, "lat lon"], v_geos_v: Float[Array, "lat lon"], @@ -259,22 +263,12 @@ def _iterative( coriolis_factor_u: Float[Array, "lat lon"], coriolis_factor_v: Float[Array, "lat lon"], mask: Float[Array, "lat lon"], - n_it: Union[int, None], + n_it: int, res_eps: float, - res_init: Union[float, str], use_res_filter: bool, res_filter_size: int, return_losses: bool ) -> [Float[Array, "lat lon"], ...]: - if n_it is None: - n_it = N_IT_IT - if res_init == "same": - res_n = jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)) - elif isinstance(res_init, numbers.Number): - res_n = res_init * jnp.ones_like(u_geos_u) - else: - raise ValueError("res_init should be equal to \"same\" or be a number.") - # used if applying a filter when computing stopping criteria res_filter = jnp.ones((res_filter_size, res_filter_size)) res_weights = jsp.signal.convolve(jnp.ones_like(u_geos_u), res_filter, mode="same", method="fft") @@ -294,7 +288,8 @@ def step_fn(pytree): u_cyclo, v_cyclo, _, _, losses, _ = lax.while_loop( # noqa lambda args: (args[-1] < n_it) | jnp.any(args[2] != 1), step_fn, - (u_geos_u, v_geos_v, mask.astype(int), res_n, jnp.ones(n_it) * jnp.nan, 0) + (u_geos_u, v_geos_v, mask.astype(int), jnp.maximum(jnp.abs(u_geos_u), jnp.abs(v_geos_v)), + jnp.ones(n_it) * jnp.nan, 0) ) return u_cyclo, v_cyclo, losses @@ -377,7 +372,7 @@ def step_fn(pytree): return u_cyclo_u, v_cyclo_v, losses -@partial(jit, static_argnames=("n_it", "optim", "optim_kwargs")) +@partial(jit, static_argnames=("n_it", "optim")) def _variational( u_geos_u: Float[Array, "lat lon"], v_geos_v: Float[Array, "lat lon"], @@ -388,21 +383,10 @@ def _variational( coriolis_factor_u: Float[Array, "lat lon"], coriolis_factor_v: Float[Array, "lat lon"], mask: Float[Array, "lat lon"], - n_it: Union[int, None], - optim: Union[optax.GradientTransformation, str], - optim_kwargs: Union[dict, None], + n_it: int, + optim: optax.GradientTransformation, return_losses: bool ) -> [Float[Array, "lat lon"], ...]: - if n_it is None: - n_it = N_IT_VAR - if isinstance(optim, str): - if optim_kwargs is None: - optim_kwargs = {"learning_rate": LR_VAR} - optim = getattr(optax, optim)(**optim_kwargs) - elif not isinstance(optim, optax.GradientTransformation): - raise TypeError("optim should be an optax.GradientTransformation optimizer, or a string referring to such an " - "optimizer.") - # define loss partial: freeze constant over iterations loss_fn = partial( _var_loss_fn, diff --git a/jaxparrow/tools/kinematics.py b/jaxparrow/tools/kinematics.py index 15f5a59..6c56fcc 100644 --- a/jaxparrow/tools/kinematics.py +++ b/jaxparrow/tools/kinematics.py @@ -3,7 +3,7 @@ from .geometry import compute_spatial_step, compute_coriolis_factor from .operators import derivative, interpolation -from .sanitize import sanitize_data +from .sanitize import init_mask, sanitize_data def advection( @@ -175,6 +175,9 @@ def normalized_relative_vorticity( The normalised relative vorticity, on the F grid (if ``interpolate=False``), or the T grid (if ``interpolate=True``) """ + # Make sure the mask is initialized + mask = init_mask(u, mask) + # Compute spatial step and Coriolis factor _, dy_u = compute_spatial_step(lat_u, lon_u) dx_v, _ = compute_spatial_step(lat_v, lon_v) @@ -202,13 +205,13 @@ def normalized_relative_vorticity( return w -def eddy_kinetic_energy( +def kinetic_energy( u: Float[Array, "lat lon"], v: Float[Array, "lat lon"], interpolate: bool = True ) -> Float[Array, "lat lon"]: """ - Computes the Eddy Kinetic Energy (EKE) of a velocity field, + Computes the Kinetic Energy (KE) of a velocity field, possibly on a C-grid (following NEMO convention [1]_) if ``interpolate=True``. Parameters @@ -227,7 +230,7 @@ def eddy_kinetic_energy( Returns ------- eke : Float[Array, "lat lon"] - The Eddy Kinetic Energy on the T grid + The Kinetic Energy on the T grid """ if interpolate: # interpolate to the T point diff --git a/jaxparrow/tools/operators.py b/jaxparrow/tools/operators.py index 8989d56..90269da 100644 --- a/jaxparrow/tools/operators.py +++ b/jaxparrow/tools/operators.py @@ -105,13 +105,13 @@ def derivative( field : Float[Array, "lat lon"] Interpolated field """ - def do_derivate(field_b, field_f, pad_left): + def do_differentiate(field_b, field_f, pad_left): field_b, field_f = handle_land_boundary(field_b, field_f, pad_left) return field_f - field_b def axis0(_field, pad_left): field_b, field_f = _field[:-1, :], _field[1:, :] - midpoint_values = do_derivate(field_b, field_f, pad_left) + midpoint_values = do_differentiate(field_b, field_f, pad_left) _field = lax.cond( pad_left, @@ -124,7 +124,7 @@ def axis0(_field, pad_left): def axis1(_field, pad_left): field_b, field_f = _field[:, :-1], _field[:, 1:] - midpoint_values = do_derivate(field_b, field_f, pad_left) + midpoint_values = do_differentiate(field_b, field_f, pad_left) _field = lax.cond( pad_left, diff --git a/jaxparrow/tools/sanitize.py b/jaxparrow/tools/sanitize.py index 2c319d1..eab7916 100644 --- a/jaxparrow/tools/sanitize.py +++ b/jaxparrow/tools/sanitize.py @@ -68,7 +68,7 @@ def handle_land_boundary( Replaces the non-finite values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise. It allows to introduce less non-finite values when applying grid operators. - In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field. + In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field (along one of the axes). Parameters ---------- @@ -76,6 +76,9 @@ def handle_land_boundary( A field field2 : Float[Array, "lat lon"] Another field + pad_left : bool + If `True`, apply padding in the `left` direction (i.e. `West` or `South`) ; + if `False`, apply padding in the `right` direction (i.e. `East` or `North`). Returns ------- @@ -102,8 +105,8 @@ def sanitize_grid_np( Sanitizes (unstructured) grids by interpolated and extrapolated `nan` or masked values to avoid spurious (`0`, `nan`, `inf`) spatial steps and Coriolis factors. - Helper function written using ``numpy`` and ``scipy``, and as such not used internally, - because incompatible with ``jax.vmap``. + Helper function written using pure ``numpy`` and ``scipy``, and as such not used internally, + because incompatible with ``jax.vmap`` and likes. Should be used before calling ``jaxparrow.geostrophy`` or ``jaxparrow.cyclogeostrophy`` in case of suspicious latitudes or longitudes T grids. diff --git a/tests/test_velocities.py b/tests/test_velocities.py index c16e6c9..1b814a6 100644 --- a/tests/test_velocities.py +++ b/tests/test_velocities.py @@ -1,4 +1,6 @@ -from jaxparrow.cyclogeostrophy import _iterative, _variational +import optax + +from jaxparrow.cyclogeostrophy import _iterative, _variational, LR_VAR from jaxparrow.geostrophy import _geostrophy from jaxparrow.tools.operators import interpolation from jaxparrow.tools.sanitize import init_mask @@ -23,7 +25,7 @@ def test_cyclogeostrophy_penven(self): u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, self.dXY, self.dXY, self.dXY, self.dXY, self.coriolis_factor, self.coriolis_factor, mask, - 20, 0.01, "same", False, 3, False) + 20, 0.01, False, 3, False) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left") cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .0035 @@ -36,7 +38,7 @@ def test_cyclogeostrophy_ioannou(self): u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, self.dXY, self.dXY, self.dXY, self.dXY, self.coriolis_factor, self.coriolis_factor, mask, - 20, 0.01, "same", True, 3, False) + 20, 0.01, True, 3, False) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left") cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .0035 @@ -46,10 +48,11 @@ def test_cyclogeostrophy_variational(self): mask = init_mask(self.u_geos) u_geos_u = interpolation(self.u_geos, axis=1, padding="right") v_geos_v = interpolation(self.v_geos, axis=0, padding="right") + optim = optax.sgd(learning_rate=LR_VAR) u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, self.dXY, self.dXY, self.dXY, self.dXY, self.coriolis_factor, self.coriolis_factor, mask, - 20, "sgd", None, False) + 20, optim, False) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") v_cyclo_est_t = interpolation(v_cyclo_est, axis=0, padding="left") cyclo_rmse = self.compute_rmse(self.u_cyclo, self.v_cyclo, u_cyclo_est_t, v_cyclo_est_t) # around .0035 From a6768a9b9a57912d73b614424db786d625c52a0f Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Tue, 16 Apr 2024 08:59:21 +0200 Subject: [PATCH 2/2] update notebooks accordingly --- notebooks/gaussian_eddy.ipynb | 122 ++++++++++++----------- notebooks/gaussian_eddy/gaussian_eddy.md | 10 +- 2 files changed, 68 insertions(+), 64 deletions(-) diff --git a/notebooks/gaussian_eddy.ipynb b/notebooks/gaussian_eddy.ipynb index c56d227..98f1c52 100644 --- a/notebooks/gaussian_eddy.ipynb +++ b/notebooks/gaussian_eddy.ipynb @@ -7,8 +7,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:18.803393Z", - "start_time": "2024-03-27T09:48:17.490282Z" + "end_time": "2024-04-16T06:06:20.795629Z", + "start_time": "2024-04-16T06:06:19.516749Z" } }, "outputs": [], @@ -18,8 +18,9 @@ "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "import optax\n", "\n", - "from jaxparrow.cyclogeostrophy import _iterative, _variational\n", + "from jaxparrow.cyclogeostrophy import _iterative, _variational, LR_VAR\n", "from jaxparrow.geostrophy import _geostrophy\n", "from jaxparrow.tools.kinematics import magnitude\n", "from jaxparrow.tools.operators import interpolation\n", @@ -54,8 +55,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:18.816049Z", - "start_time": "2024-03-27T09:48:18.803541Z" + "end_time": "2024-04-16T06:06:20.809311Z", + "start_time": "2024-04-16T06:06:20.796162Z" } }, "outputs": [], @@ -85,8 +86,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:19.037726Z", - "start_time": "2024-03-27T09:48:18.816067Z" + "end_time": "2024-04-16T06:06:21.038434Z", + "start_time": "2024-04-16T06:06:20.809370Z" } }, "outputs": [], @@ -116,8 +117,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:19.357171Z", - "start_time": "2024-03-27T09:48:19.039226Z" + "end_time": "2024-04-16T06:06:21.344013Z", + "start_time": "2024-04-16T06:06:21.039655Z" } }, "outputs": [ @@ -148,8 +149,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:19.632375Z", - "start_time": "2024-03-27T09:48:19.355709Z" + "end_time": "2024-04-16T06:06:21.560665Z", + "start_time": "2024-04-16T06:06:21.336989Z" } }, "outputs": [ @@ -204,8 +205,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:19.729186Z", - "start_time": "2024-03-27T09:48:19.629885Z" + "end_time": "2024-04-16T06:06:21.569698Z", + "start_time": "2024-04-16T06:06:21.551351Z" } }, "outputs": [], @@ -220,8 +221,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:20.076605Z", - "start_time": "2024-03-27T09:48:19.655143Z" + "end_time": "2024-04-16T06:06:21.910989Z", + "start_time": "2024-04-16T06:06:21.570352Z" } }, "outputs": [ @@ -256,8 +257,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:20.387878Z", - "start_time": "2024-03-27T09:48:20.075007Z" + "end_time": "2024-04-16T06:06:22.161062Z", + "start_time": "2024-04-16T06:06:21.909967Z" } }, "outputs": [ @@ -300,8 +301,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:20.607746Z", - "start_time": "2024-03-27T09:48:20.388923Z" + "end_time": "2024-04-16T06:06:22.449759Z", + "start_time": "2024-04-16T06:06:22.156730Z" } }, "outputs": [], @@ -321,8 +322,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:20.880799Z", - "start_time": "2024-03-27T09:48:20.611195Z" + "end_time": "2024-04-16T06:06:22.744903Z", + "start_time": "2024-04-16T06:06:22.452556Z" } }, "outputs": [ @@ -357,8 +358,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:20.975547Z", - "start_time": "2024-03-27T09:48:20.882108Z" + "end_time": "2024-04-16T06:06:22.849715Z", + "start_time": "2024-04-16T06:06:22.739457Z" } }, "outputs": [ @@ -389,8 +390,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:21.015802Z", - "start_time": "2024-03-27T09:48:20.974942Z" + "end_time": "2024-04-16T06:06:22.917070Z", + "start_time": "2024-04-16T06:06:22.849692Z" } }, "outputs": [ @@ -439,8 +440,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:21.031821Z", - "start_time": "2024-03-27T09:48:21.017074Z" + "end_time": "2024-04-16T06:06:22.963314Z", + "start_time": "2024-04-16T06:06:22.905593Z" } }, "outputs": [], @@ -455,8 +456,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:21.319274Z", - "start_time": "2024-03-27T09:48:21.034151Z" + "end_time": "2024-04-16T06:06:23.318Z", + "start_time": "2024-04-16T06:06:22.929184Z" } }, "outputs": [ @@ -491,8 +492,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:21.470194Z", - "start_time": "2024-03-27T09:48:21.320931Z" + "end_time": "2024-04-16T06:06:23.426827Z", + "start_time": "2024-04-16T06:06:23.314041Z" } }, "outputs": [ @@ -535,8 +536,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:21.529543Z", - "start_time": "2024-03-27T09:48:21.417223Z" + "end_time": "2024-04-16T06:06:23.539343Z", + "start_time": "2024-04-16T06:06:23.425423Z" } }, "outputs": [], @@ -563,15 +564,16 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:22.586107Z", - "start_time": "2024-03-27T09:48:21.529896Z" + "end_time": "2024-04-16T06:06:24.603106Z", + "start_time": "2024-04-16T06:06:23.540016Z" } }, "outputs": [], "source": [ + "optim = optax.sgd(learning_rate=LR_VAR)\n", "u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n", " coriolis_factor, coriolis_factor, mask,\n", - " n_it=20, optim=\"sgd\", optim_kwargs=None,\n", + " n_it=20, optim=optim,\n", " return_losses=False)\n", "\n", "u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding=\"left\")\n", @@ -587,8 +589,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:22.914879Z", - "start_time": "2024-03-27T09:48:22.587554Z" + "end_time": "2024-04-16T06:06:24.874994Z", + "start_time": "2024-04-16T06:06:24.604585Z" } }, "outputs": [ @@ -623,8 +625,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.006573Z", - "start_time": "2024-03-27T09:48:22.915936Z" + "end_time": "2024-04-16T06:06:24.968533Z", + "start_time": "2024-04-16T06:06:24.873642Z" } }, "outputs": [ @@ -655,8 +657,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.024525Z", - "start_time": "2024-03-27T09:48:23.006882Z" + "end_time": "2024-04-16T06:06:24.986644Z", + "start_time": "2024-04-16T06:06:24.967457Z" } }, "outputs": [ @@ -704,15 +706,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.378719Z", - "start_time": "2024-03-27T09:48:23.026195Z" + "end_time": "2024-04-16T06:06:25.444250Z", + "start_time": "2024-04-16T06:06:24.987409Z" } }, "outputs": [], "source": [ "u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n", " coriolis_factor, coriolis_factor, mask,\n", - " n_it=20, res_eps=0.01, res_init=\"same\", \n", + " n_it=20, res_eps=0.01,\n", " use_res_filter=True, res_filter_size=3, \n", " return_losses=False)\n", "\n", @@ -729,8 +731,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.650239Z", - "start_time": "2024-03-27T09:48:23.380245Z" + "end_time": "2024-04-16T06:06:25.717737Z", + "start_time": "2024-04-16T06:06:25.445894Z" } }, "outputs": [ @@ -765,8 +767,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.794821Z", - "start_time": "2024-03-27T09:48:23.646712Z" + "end_time": "2024-04-16T06:06:25.811113Z", + "start_time": "2024-04-16T06:06:25.716342Z" } }, "outputs": [ @@ -797,8 +799,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:23.795617Z", - "start_time": "2024-03-27T09:48:23.743304Z" + "end_time": "2024-04-16T06:06:25.828628Z", + "start_time": "2024-04-16T06:06:25.810917Z" } }, "outputs": [ @@ -834,15 +836,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:24.088571Z", - "start_time": "2024-03-27T09:48:23.759301Z" + "end_time": "2024-04-16T06:06:26.166281Z", + "start_time": "2024-04-16T06:06:25.829283Z" } }, "outputs": [], "source": [ "u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY,\n", " coriolis_factor, coriolis_factor, mask,\n", - " n_it=20, res_eps=0.01, res_init=\"same\", \n", + " n_it=20, res_eps=0.01, \n", " use_res_filter=False, res_filter_size=1, \n", " return_losses=False)\n", "\n", @@ -859,8 +861,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:24.400204Z", - "start_time": "2024-03-27T09:48:24.090075Z" + "end_time": "2024-04-16T06:06:26.436757Z", + "start_time": "2024-04-16T06:06:26.167302Z" } }, "outputs": [ @@ -895,8 +897,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:24.450133Z", - "start_time": "2024-03-27T09:48:24.360068Z" + "end_time": "2024-04-16T06:06:26.527336Z", + "start_time": "2024-04-16T06:06:26.435532Z" } }, "outputs": [ @@ -927,8 +929,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-27T09:48:24.476603Z", - "start_time": "2024-03-27T09:48:24.457198Z" + "end_time": "2024-04-16T06:06:26.557442Z", + "start_time": "2024-04-16T06:06:26.532119Z" } }, "outputs": [ diff --git a/notebooks/gaussian_eddy/gaussian_eddy.md b/notebooks/gaussian_eddy/gaussian_eddy.md index 4ed01bb..09e4dfc 100644 --- a/notebooks/gaussian_eddy/gaussian_eddy.md +++ b/notebooks/gaussian_eddy/gaussian_eddy.md @@ -4,8 +4,9 @@ import sys import matplotlib.pyplot as plt import numpy as np +import optax -from jaxparrow.cyclogeostrophy import _iterative, _variational +from jaxparrow.cyclogeostrophy import _iterative, _variational, LR_VAR from jaxparrow.geostrophy import _geostrophy from jaxparrow.tools.kinematics import magnitude from jaxparrow.tools.operators import interpolation @@ -273,9 +274,10 @@ mask = init_mask(u_geos_t) ```python +optim = optax.sgd(learning_rate=LR_VAR) u_cyclo_est, v_cyclo_est, _ = _variational(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY, coriolis_factor, coriolis_factor, mask, - n_it=20, optim="sgd", optim_kwargs=None, + n_it=20, optim=optim, return_losses=False) u_cyclo_est_t = interpolation(u_cyclo_est, axis=1, padding="left") @@ -348,7 +350,7 @@ Use of a convolution filter when computing the residuals. ```python u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY, coriolis_factor, coriolis_factor, mask, - n_it=20, res_eps=0.01, res_init="same", + n_it=20, res_eps=0.01, use_res_filter=True, res_filter_size=3, return_losses=False) @@ -418,7 +420,7 @@ No convolution filter, original approach. ```python u_cyclo_est, v_cyclo_est, _ = _iterative(u_geos_u, v_geos_v, dXY, dXY, dXY, dXY, coriolis_factor, coriolis_factor, mask, - n_it=20, res_eps=0.01, res_init="same", + n_it=20, res_eps=0.01, use_res_filter=False, res_filter_size=1, return_losses=False)