Skip to content

Commit

Permalink
some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Mar 27, 2024
1 parent 00a3c49 commit 7daca19
Showing 1 changed file with 23 additions and 31 deletions.
54 changes: 23 additions & 31 deletions jaxparrow/tools/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ def interpolation(
field : Float[Array, "lat lon"]
Interpolated field
"""
def do_interpolate(field_b, field_f, pad_left):
field_b, field_f = handle_land_boundary(field_b, field_f, pad_left)
return 0.5 * (field_b + field_f)

def axis0(arr, pad_left):
arr_b, arr_f = arr[:-1, :], arr[1:, :]
arr_b, arr_f = handle_land_boundary(arr_b, arr_f, pad_left)
midpoint_values = 0.5 * (arr_b + arr_f)
field_b, field_f = arr[:-1, :], arr[1:, :]
midpoint_values = do_interpolate(field_b, field_f, pad_left)
arr = lax.cond(
pad_left,
lambda operands: operands[0].at[1:, :].set(operands[1]),
Expand All @@ -50,9 +53,8 @@ def axis0(arr, pad_left):
return arr

def axis1(arr, pad_left):
arr_b, arr_f = arr[:, :-1], arr[:, 1:]
arr_b, arr_f = handle_land_boundary(arr_b, arr_f, pad_left)
midpoint_values = 0.5 * (arr_b + arr_f)
field_b, field_f = arr[:, :-1], arr[:, 1:]
midpoint_values = do_interpolate(field_b, field_f, pad_left)
arr = lax.cond(
pad_left,
lambda operands: operands[0].at[:, 1:].set(operands[1]),
Expand Down Expand Up @@ -103,50 +105,40 @@ def derivative(
field : Float[Array, "lat lon"]
Interpolated field
"""
def do_interpolate(field_b, field_f, pad_left):
def do_derivate(field_b, field_f, _dxy, pad_left):
field_b, field_f = handle_land_boundary(field_b, field_f, pad_left)
return field_f - field_b

def axis0(_field, pad_left):
def pad_left_fn(_field_b, _field_f):
midpoint_values = do_interpolate(_field_b, _field_f, True)
return jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge") / dxy

def pad_right_fn(_field_b, _field_f):
midpoint_values = do_interpolate(_field_b, _field_f, False)
return jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge") / dxy
return (field_f - field_b) / _dxy

def axis0(_field, _dxy, pad_left):
field_b, field_f = _field[:-1, :], _field[1:, :]
midpoint_values = do_derivate(field_b, field_f, _dxy, pad_left)

_field = lax.cond(
pad_left,
lambda _operands: pad_left_fn(*_operands), lambda _operands: pad_right_fn(*_operands),
(field_b, field_f)
lambda operand: jnp.pad(operand, pad_width=((1, 0), (0, 0)), mode="edge"),
lambda operand: jnp.pad(operand, pad_width=((0, 1), (0, 0)), mode="edge"),
midpoint_values
)

return _field

def axis1(_field, pad_left):
def pad_left_fn(_field_b, _field_f):
midpoint_values = do_interpolate(_field_b, _field_f, True)
return jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge") / dxy

def pad_right_fn(_field_b, _field_f):
midpoint_values = do_interpolate(_field_b, _field_f, False)
return jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge") / dxy

def axis1(_field, _dxy, pad_left):
field_b, field_f = _field[:, :-1], _field[:, 1:]
midpoint_values = do_derivate(field_b, field_f, _dxy, pad_left)

_field = lax.cond(
pad_left,
lambda _operands: pad_left_fn(*_operands), lambda _operands: pad_right_fn(*_operands),
(field_b, field_f)
lambda operand: jnp.pad(operand, pad_width=((0, 0), (1, 0)), mode="edge"),
lambda operand: jnp.pad(operand, pad_width=((0, 0), (0, 1)), mode="edge"),
midpoint_values
)

return _field

field = lax.cond(
axis == 0,
lambda operands: axis0(*operands), lambda operands: axis1(*operands),
(field, padding == "left")
(field, dxy, padding == "left")
)

return field

0 comments on commit 7daca19

Please sign in to comment.