diff --git a/jaxparrow/tools/operators.py b/jaxparrow/tools/operators.py index da2a52a..7b5a851 100644 --- a/jaxparrow/tools/operators.py +++ b/jaxparrow/tools/operators.py @@ -3,6 +3,8 @@ import jax.numpy as jnp from jaxtyping import Array, Float +from .sanitize import handle_land_boundary + def interpolation( field: Float[Array, "lat lon"], @@ -35,13 +37,17 @@ def interpolation( Interpolated field """ if axis == 0: - midpoint_values = 0.5 * (field[:-1, :] + field[1:, :]) + field_b, field_f = field[:-1, :], field[1:, :] + field_b, field_f = handle_land_boundary(field_b, field_f) + midpoint_values = 0.5 * (field_b + field_f) if padding == "left": field = field.at[1:, :].set(midpoint_values) else: # padding == "right" field = field.at[:-1, :].set(midpoint_values) else: # axis == 1 - midpoint_values = 0.5 * (field[:, :-1] + field[:, 1:]) + field_b, field_f = field[:, :-1], field[:, 1:] + field_b, field_f = handle_land_boundary(field_b, field_f) + midpoint_values = 0.5 * (field_b + field_f) if padding == "left": field = field.at[:, 1:].set(midpoint_values) else: @@ -83,16 +89,18 @@ def derivative( Interpolated field """ if axis == 0: - midpoint_values = field[1:, :] - field[:-1, :] + field_b, field_f = field[:-1, :], field[1:, :] if padding == "left": pad_width = ((1, 0), (0, 0)) else: # padding == "right" pad_width = ((0, 1), (0, 0)) else: # axis == 1 - midpoint_values = field[:, 1:] - field[:, :-1] + field_b, field_f = field[:, :-1], field[:, 1:] if padding == "left": pad_width = ((0, 0), (1, 0)) else: pad_width = ((0, 0), (0, 1)) + field_b, field_f = handle_land_boundary(field_b, field_f) + midpoint_values = field_f - field_b field = jnp.pad(midpoint_values, pad_width=pad_width, mode="edge") / dxy return field diff --git a/jaxparrow/tools/sanitize.py b/jaxparrow/tools/sanitize.py index 72d2f59..d41b11b 100644 --- a/jaxparrow/tools/sanitize.py +++ b/jaxparrow/tools/sanitize.py @@ -58,10 +58,39 @@ def init_mask( Initialized (if needed) mask """ if mask is None: - mask = jnp.isnan(field) + mask = jnp.isfinite(field) return mask +def handle_land_boundary( + field1: Float[Array, "lat lon"], + field2: Float[Array, "lat lon"] +) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]: + """ + Replaces the non-finite values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise. + + It allows to introduce less non-finite values when applying grid operators. + In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field. + + Parameters + ---------- + field1 : Float[Array, "lat lon"] + A field + field2 : Float[Array, "lat lon"] + Another field + + Returns + ------- + field1 : Float[Array, "lat lon"] + A field whose non-finite values have been replaced with the ones from ``field2`` + field2 : Float[Array, "lat lon"] + A field whose non-finite values have been replaced with the ones from ``field1`` + """ + field1 = jnp.where(jnp.isfinite(field1), field1, field2) + field2 = jnp.where(jnp.isfinite(field2), field2, field1) + return field1, field2 + + def sanitize_grid_np( lat: Float[Array, "lat lon"], lon: Float[Array, "lat lon"],