Skip to content

Commit

Permalink
Merge pull request #112 from rcjackson/jax_fix
Browse files Browse the repository at this point in the history
JaxOpt for Jax engine
  • Loading branch information
rcjackson authored Oct 4, 2023
2 parents b1fb285 + 5bdd314 commit 50b531a
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 259 deletions.
1 change: 1 addition & 0 deletions continuous_integration/environment-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ dependencies:
- distributed
- cmweather
- jax
- jaxopt
- tensorflow>=2.6
- tensorflow-probability
3 changes: 2 additions & 1 deletion pydda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
print("Detecting Jax...")
try:
import jax
import jaxopt

print("Jax engine enabled!")
except ImportError:
print("Jax is not installed on your system, unable to use Jax engine.")
print("Jax/JaxOpt are not installed on your system, unable to use Jax engine.")

print("Detecting TensorFlow...")
try:
Expand Down
2 changes: 1 addition & 1 deletion pydda/cost_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@
from ._cost_functions_numpy import calculate_model_cost
from ._cost_functions_numpy import calculate_model_gradient
from ._cost_functions_numpy import calculate_point_cost, calculate_point_gradient
from .cost_functions import J_function, grad_J
from .cost_functions import J_function, grad_J, grad_jax, J_function_jax
123 changes: 92 additions & 31 deletions pydda/cost_functions/_cost_functions_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import jax
import jax.numpy as jnp

from jax import jit
from jax import float0

JAX_AVAILABLE = True
except ImportError:
JAX_AVAILABLE = False
Expand Down Expand Up @@ -143,7 +140,7 @@ def calculate_grad_radial_vel(
if upper_bc is True:
p_z1 = p_z1.at[-1, :, :].set(0)
y = jnp.stack((p_x1, p_y1, p_z1), axis=0)
return np.copy(y.flatten())
return y.flatten()


def calculate_smoothness_cost(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5):
Expand Down Expand Up @@ -217,7 +214,7 @@ def calculate_smoothness_cost(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5):
)
** 2
)
return np.asanyarray(jnp.sum(x_term + y_term + z_term))
return jnp.sum(x_term + y_term + z_term)


def calculate_smoothness_gradient(
Expand Down Expand Up @@ -256,31 +253,95 @@ def calculate_smoothness_gradient(
y: float array
value of gradient of smoothness cost function
"""
du = np.zeros(w.shape)
dv = np.zeros(w.shape)
dw = np.zeros(w.shape)
grad_u = np.zeros(w.shape)
grad_v = np.zeros(w.shape)
grad_w = np.zeros(w.shape)
scipy.ndimage.laplace(u, du, mode="wrap")
scipy.ndimage.laplace(v, dv, mode="wrap")
scipy.ndimage.laplace(w, dw, mode="wrap")
du = du / dx
dv = dv / dy
dw = dw / dz
scipy.ndimage.laplace(du, grad_u, mode="wrap")
scipy.ndimage.laplace(dv, grad_v, mode="wrap")
scipy.ndimage.laplace(dw, grad_w, mode="wrap")

grad_u = grad_u / dx
grad_v = grad_v / dy
grad_w = grad_w / dz
dudx = jnp.gradient(u, dx, axis=2)
dudy = jnp.gradient(u, dy, axis=1)
dudz = jnp.gradient(u, dz, axis=0)
dvdx = jnp.gradient(v, dx, axis=2)
dvdy = jnp.gradient(v, dy, axis=1)
dvdz = jnp.gradient(v, dz, axis=0)
dwdx = jnp.gradient(w, dx, axis=2)
dwdy = jnp.gradient(w, dy, axis=1)
dwdz = jnp.gradient(w, dz, axis=0)

x_term = (
Cx
* (
jnp.gradient(dudx, dx, axis=2)
+ jnp.gradient(dvdx, dx, axis=1)
+ jnp.gradient(dwdx, dx, axis=2)
)
** 2
)
y_term = (
Cy
* (
jnp.gradient(dudy, dy, axis=2)
+ jnp.gradient(dvdy, dy, axis=1)
+ jnp.gradient(dwdy, dy, axis=2)
)
** 2
)
z_term = (
Cz
* (
jnp.gradient(dudz, dz, axis=2)
+ jnp.gradient(dvdz, dz, axis=1)
+ jnp.gradient(dwdz, dz, axis=2)
)
** 2
)

du = x_term / dx
dv = y_term / dy
dw = z_term / dz
dudx = jnp.gradient(du, dx, axis=2)
dudy = jnp.gradient(du, dy, axis=1)
dudz = jnp.gradient(du, dz, axis=0)
dvdx = jnp.gradient(dv, dx, axis=2)
dvdy = jnp.gradient(dv, dy, axis=1)
dvdz = jnp.gradient(dv, dz, axis=0)
dwdx = jnp.gradient(dw, dx, axis=2)
dwdy = jnp.gradient(dw, dy, axis=1)
dwdz = jnp.gradient(dw, dz, axis=0)

x_term = (
Cx
* (
jnp.gradient(dudx, dx, axis=2)
+ jnp.gradient(dvdx, dx, axis=1)
+ jnp.gradient(dwdx, dx, axis=2)
)
** 2
)
y_term = (
Cy
* (
jnp.gradient(dudy, dy, axis=2)
+ jnp.gradient(dvdy, dy, axis=1)
+ jnp.gradient(dwdy, dy, axis=2)
)
** 2
)
z_term = (
Cz
* (
jnp.gradient(dudz, dz, axis=2)
+ jnp.gradient(dvdz, dz, axis=1)
+ jnp.gradient(dwdz, dz, axis=2)
)
** 2
)

grad_u = x_term / dx
grad_v = y_term / dy
grad_w = z_term / dz

# Impermeability condition
grad_w[0, :, :] = 0
grad_w.at[0, :, :].set(0)
if upper_bc is True:
grad_w[-1, :, :] = 0
y = np.stack([grad_u * Cx * 2, grad_v * Cy * 2, grad_w * Cz * 2], axis=0)
y = np.nan_to_num(y)
grad_w.at[-1, :, :].set(0)
y = jnp.stack([grad_u * Cx * 2, grad_v * Cy * 2, grad_w * Cz * 2], axis=0)

return y.flatten()


Expand Down Expand Up @@ -494,7 +555,7 @@ def calculate_mass_continuity_gradient(
if upper_bc is True:
grad_w = grad_w.at[-1, :, :].set(0)
y = jnp.stack([grad_u, grad_v, grad_w], axis=0)
return y.flatten().copy()
return y.flatten()


def calculate_background_cost(u, v, w, weights, u_back, v_back, Cb=0.01):
Expand Down Expand Up @@ -697,7 +758,7 @@ def calculate_vertical_vorticity_gradient(
w_grad.at[0, :, :].set(0)
if upper_bc is True:
w_grad.at[-1, :, :].set(0)
y = np.stack([u_grad, v_grad, w_grad], axis=0)
y = jnp.stack([u_grad, v_grad, w_grad], axis=0)
return y.flatten().copy()


Expand Down Expand Up @@ -781,5 +842,5 @@ def calculate_model_gradient(u, v, w, weights, u_model, v_model, w_model, coeff=
calculate_model_cost, u, v, w, weights, u_model, v_model, w_model, coeff
)
u_grad, v_grad, w_grad, _, _, _, _, _ = fun_vjp(1.0)
y = np.stack([u_grad, v_grad, w_grad], axis=0)
y = jnp.stack([u_grad, v_grad, w_grad], axis=0)
return y.flatten().copy()
Loading

0 comments on commit 50b531a

Please sign in to comment.