diff --git a/jaxparrow/tools/kinematics.py b/jaxparrow/tools/kinematics.py index 378be53..3968c22 100644 --- a/jaxparrow/tools/kinematics.py +++ b/jaxparrow/tools/kinematics.py @@ -152,7 +152,7 @@ def normalized_relative_vorticity( u : Float[Array, "lat lon"] U component of the velocity field (on the U grid) v : Float[Array, "lat lon"] - V component of the SSC velocity field (on the V grid) + V component of the velocity field (on the V grid) lat_u : Float[Array, "lat lon"] Latitudes of the U grid lon_u : Float[Array, "lat lon"] @@ -200,3 +200,110 @@ def normalized_relative_vorticity( w = sanitize_data(w, jnp.nan, mask) return w + + +def enstrophy( + ua: Float[Array, "lat lon"], + va: Float[Array, "lat lon"], + lat_u: Float[Array, "lat lon"], + lon_u: Float[Array, "lat lon"], + lat_v: Float[Array, "lat lon"], + lon_v: Float[Array, "lat lon"], + mask: Float[Array, "lat lon"] = None, + interpolate: bool = True +) -> Float[Array, "lat lon"]: + """ + Computes the enstrophy (ENS) of an anomaly velocity field, on a C-grid, following NEMO convention [1]_. + + The ``lat_u``, ``lon_u``, ``lat_v``, and ``lon_v`` are expected to follow the NEMO convention [1]_. + If not, the function will return inaccurate results. + + Parameters + ---------- + ua : Float[Array, "lat lon"] + U component of the anomaly velocity field (on the U grid) + va : Float[Array, "lat lon"] + V component of the anomaly velocity field (on the V grid) + lat_u : Float[Array, "lat lon"] + Latitudes of the U grid + lon_u : Float[Array, "lat lon"] + Longitudes of the U grid + lat_v : Float[Array, "lat lon"] + Latitudes of the V grid + lon_v : Float[Array, "lat lon"] + Longitudes of the V grid + mask : Float[Array, "lat lon"], optional + Mask defining the marine area of the spatial domain; `1` or `True` stands for masked (i.e. land) + interpolate : bool, optional + If `True`, the relative normalized vorticity is interpolated from the F grid to the T grid. + If `False`, it remains on the F grid. + + Defaults to `True` + + Returns + ------- + ens : Float[Array, "lat lon"] + The enstrophy, + on the F grid (if ``interpolate=False``), or the T grid (if ``interpolate=True``) + """ + # Compute spatial step + _, dy_u = compute_spatial_step(lat_u, lon_u) + dx_v, _ = compute_spatial_step(lat_v, lon_v) + + # Handle spurious data and apply mask + dy_u = sanitize_data(dy_u, jnp.nan, mask) + dx_v = sanitize_data(dx_v, jnp.nan, mask) + + # Compute the ENS + dua_dy_f = derivative(ua, dy_u, axis=0, padding="right") # (U(j), U(j+1)) -> F(j) + dva_dx_f = derivative(va, dx_v, axis=1, padding="right") # (V(i), V(i+1)) -> F(i) + ens_f = (dva_dx_f - dua_dy_f) / 2 # F(j) + + if interpolate: + ens_u = interpolation(ens_f, axis=0, padding="left") # (F(j), F(j+1)) -> U(j+1) + ens = interpolation(ens_u, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1) + else: + ens = ens_f + + ens = sanitize_data(ens, jnp.nan, mask) + + return ens + + +def eddy_kinetic_energy( + ua: Float[Array, "lat lon"], + va: Float[Array, "lat lon"], + interpolate: bool = True +) -> Float[Array, "lat lon"]: + """ + Computes the Eddy Kinetic Energy (EKE) of an anomaly velocity field, + possibly on a C-grid (following NEMO convention [1]_) if ``interpolate=True``. + + Parameters + ---------- + ua : Float[Array, "lat lon"] + U component of the anomaly velocity field (on the U grid) + va : Float[Array, "lat lon"] + V component of the anomaly velocity field (on the V grid) + interpolate : bool, optional + If `True`, the velocity components are assumed to be located on the U and V grids, + and are interpolated to the T one (following NEMO convention [1]_). + If `False`, the velocity components are assumed to be located on the T grid, and interpolation is not needed. + + Defaults to `True` + + Returns + ------- + eke : Float[Array, "lat lon"] + The Eddy Kinetic Energy on the T grid + """ + if interpolate: + # interpolate to the T point + ua_t = interpolation(ua, axis=1, padding="left") # (U(i), U(i+1)) -> T(i+1) + va_t = interpolation(va, axis=0, padding="left") # (V(j), V(j+1)) -> T(j+1) + else: + ua_t, va_t = ua, va + + eke_t = (ua_t ** 2 + va_t ** 2) / 2 + + return eke_t