Skip to content

Commit

Permalink
Grid Ops and SW refactor (#48)
Browse files Browse the repository at this point in the history
* updated notebooks. working.

* Updated SW method with API1
  • Loading branch information
jejjohnson authored Aug 1, 2023
1 parent 18b71f6 commit c654c25
Show file tree
Hide file tree
Showing 10 changed files with 5,983 additions and 1,565 deletions.
6 changes: 3 additions & 3 deletions _toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ chapters:
- title: Shallow Water Model
sections:
- file: content/sw/sw
- file: content/sw/sw_linear_api1
- file: content/sw/sw_linear_api2
- file: content/sw/sw_nonlinear
- file: content/sw/sw_linear_rossby_api1
- file: content/sw/sw_linear_jet_api1
- file: content/sw/sw_nonlinear_jet_api1
1,543 changes: 0 additions & 1,543 deletions content/sw/sw_linear_api1.ipynb

This file was deleted.

1,579 changes: 1,579 additions & 0 deletions content/sw/sw_linear_jet_api1.ipynb

Large diffs are not rendered by default.

1,996 changes: 1,996 additions & 0 deletions content/sw/sw_linear_rossby_api1.ipynb

Large diffs are not rendered by default.

38 changes: 19 additions & 19 deletions content/sw/sw_nonlinear.ipynb

Large diffs are not rendered by default.

1,718 changes: 1,718 additions & 0 deletions content/sw/sw_nonlinear_jet_api1.ipynb

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions jaxsw/_src/models/sw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing as tp
from jaxsw._src.domain.base import Domain
import jax.numpy as jnp
from jaxtyping import Array


class Params(tp.NamedTuple):
domain: Domain
depth: float
gravity: float
coriolis_f0: float # or ARRAY
coriolis_beta: float # or ARRAY

@property
def phase_speed(self):
return jnp.sqrt(self.gravity * self.depth)

def rossby_radius(self, domain):
return self.phase_speed / self.coriolis_param(domain).mean()
# return self.phase_speed / self.coriolis_f0

def coriolis_param(self, domain):
return self.coriolis_f0 + domain.grid[..., 1] * self.coriolis_beta

def lateral_viscosity(self, domain):
return 1e-3 * self.coriolis_f0 * domain.dx[0] ** 2


class State(tp.NamedTuple):
u: Array
v: Array
h: Array

@classmethod
def init_state(cls, params, init_h=None, init_v=None, init_u=None):
h = init_h(params) if init_h is not None else State.zero_init(params.domain)
v = init_v(params) if init_v is not None else State.zero_init(params.domain)
u = init_u(params) if init_u is not None else State.zero_init(params.domain)
return cls(u=u, v=v, h=h)

@staticmethod
def zero_init(domain):
return jnp.zeros_like(domain.grid[..., 0])
146 changes: 146 additions & 0 deletions jaxsw/_src/models/sw/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import jax.numpy as jnp
from jaxsw._src.models.pde import DynamicalSystem
from jaxsw._src.operators.functional import grid as F_grid
import equinox as eqx
from jaxtyping import Array
from jaxsw._src.models.sw import Params, State


def enforce_boundaries(u: Array, component: str = "h", periodic: bool = False):
if periodic:
u = u.at[0, :].set(u[-2, :])
u = u.at[-1, :].set(u[1, :])
if component == "h":
return u
elif component == "u":
return u.at[-2, :].set(jnp.asarray(0.0))
elif component == "v":
return u.at[:, -2].set(jnp.asarray(0.0))
else:
msg = f"Unrecognized component: {component}"
msg += "\nNeeds to be h, u, or v"
raise ValueError(msg)


class LinearShallowWater2D(DynamicalSystem):
@staticmethod
def boundary_f(state: State, component: str = "h"):
if component == "h":
return state
elif component == "u":
u = state.u.at[-2, :].set(jnp.asarray(0.0))
return eqx.tree_at(lambda x: x.u, state, u)
elif component == "v":
v = state.v.at[:, -2].set(jnp.asarray(0.0))
return eqx.tree_at(lambda x: x.v, state, v)
else:
msg = f"Unrecognized component: {component}"
msg += "\nNeeds to be h, u, or v"
raise ValueError(msg)

@staticmethod
def equation_of_motion(t: float, state: State, args) -> State:
"""2D Linear Shallow Water Equations
Equation:
∂h/∂t + H (∂u/∂x + ∂v/∂y) = 0
∂u/∂t - fv = - g ∂h/∂x - ku
∂v/∂t + fu = - g ∂h/∂y - kv
"""

# apply boundary conditions
h: Array = enforce_boundaries(state.h, "h")
u = enforce_boundaries(state.u, "u")
v = enforce_boundaries(state.v, "v")
# update state
state = eqx.tree_at(lambda x: x.u, state, u)
state = eqx.tree_at(lambda x: x.v, state, v)
state = eqx.tree_at(lambda x: x.h, state, h)

# apply RHS
h_rhs, u_rhs, v_rhs = equation_of_motion(state, args)

# update state
state = eqx.tree_at(lambda x: x.u, state, u_rhs)
state = eqx.tree_at(lambda x: x.v, state, v_rhs)
state = eqx.tree_at(lambda x: x.h, state, h_rhs)

return state


def equation_of_motion(state: State, params: Params):
h, u, v = state.h, state.u, state.v

domain = params.domain

# enforce boundaries
h = enforce_boundaries(h, "h")
v = enforce_boundaries(v, "v")
u = enforce_boundaries(u, "u")

# pad boundaries with edge values
h_node = jnp.pad(h[1:-1, 1:-1], 1, "edge")
h_node = enforce_boundaries(h_node, "h")

# PLANETARY VORTICITY
planetary_vort = params.coriolis_param(domain)[1:-1, 1:-1]

# ################################
# HEIGHT Equation
# ∂h/∂t = - H (∂u/∂x + ∂v/∂y)
# ################################

# finite difference
# u --> h | top edge --> cell node | right edge --> cell center
# [Nx+2,Ny+2] --> [Nx+1,Ny+2] --> [Nx,Ny]
du_dx = F_grid.difference(
u, step_size=domain.dx[0], axis=0, accuracy=1, method="left"
)
du_dx = du_dx[:-1, 1:-1]

# v --> h | right edge --> cell node | top edge --> cell center
# [Nx+2,Ny+2] --> [Nx+2,Ny+1] --> [Nx,Ny]
dv_dy = F_grid.difference(
v, step_size=domain.dx[1], axis=1, accuracy=1, method="right"
)
dv_dy = dv_dy[1:-1, :-1]

# print("H_RHS")
h_rhs = jnp.zeros_like(h)
h_rhs = h_rhs.at[1:-1, 1:-1].set(-params.depth * (du_dx + dv_dy))

# #############################
# U VELOCITY
# ∂u/∂t = fv - g ∂h/∂x
# #############################
# [Nx+2,Ny+2] --> [Nx+1,Ny+1] --> [Nx,Ny]
v_on_u = planetary_vort * F_grid.interp_center(v)[1:, :-1]

# H --> U
# [Nx+2,Ny+2] --> [Nx+1,Ny+2] --> [Nx,Ny]
dhdx_on_u = (
-params.gravity
* F_grid.difference(h, axis=0, step_size=domain.dx[0], method="right")[1:, 1:-1]
)

u_rhs = jnp.zeros_like(h)
u_rhs = u_rhs.at[1:-1, 1:-1].set(v_on_u + dhdx_on_u)

# #############################
# V - VELOCITY
# ∂v/∂t = - fu - g ∂h/∂y
# #############################
# [Nx+2,Ny+2] --> [Nx+1,Ny+1] --> [Nx,Ny]
u_on_v = -planetary_vort * F_grid.interp_center(u)[:-1, 1:]

# H --> U
# [Nx+2,Ny+2] --> [Nx+2,Ny+1] --> [Nx,Ny]
dhdy_on_v = (
-params.gravity
* F_grid.difference(h, axis=1, step_size=domain.dx[1], method="right")[1:-1, 1:]
)

v_rhs = jnp.zeros_like(h)
v_rhs = v_rhs.at[1:-1, 1:-1].set(u_on_v + dhdy_on_v)

return h_rhs, u_rhs, v_rhs
Loading

0 comments on commit c654c25

Please sign in to comment.