diff --git a/continuous_integration/environment-actions.yml b/continuous_integration/environment-actions.yml index b5678f33..dd282668 100644 --- a/continuous_integration/environment-actions.yml +++ b/continuous_integration/environment-actions.yml @@ -20,5 +20,6 @@ dependencies: - distributed - cmweather - jax + - jaxopt - tensorflow>=2.6 - tensorflow-probability diff --git a/pydda/__init__.py b/pydda/__init__.py index 35c3360b..02f276fd 100644 --- a/pydda/__init__.py +++ b/pydda/__init__.py @@ -19,10 +19,11 @@ print("Detecting Jax...") try: import jax + import jaxopt print("Jax engine enabled!") except ImportError: - print("Jax is not installed on your system, unable to use Jax engine.") + print("Jax/JaxOpt are not installed on your system, unable to use Jax engine.") print("Detecting TensorFlow...") try: diff --git a/pydda/cost_functions/__init__.py b/pydda/cost_functions/__init__.py index 80550734..79174927 100644 --- a/pydda/cost_functions/__init__.py +++ b/pydda/cost_functions/__init__.py @@ -95,4 +95,4 @@ from ._cost_functions_numpy import calculate_model_cost from ._cost_functions_numpy import calculate_model_gradient from ._cost_functions_numpy import calculate_point_cost, calculate_point_gradient -from .cost_functions import J_function, grad_J +from .cost_functions import J_function, grad_J, grad_jax, J_function_jax diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index 94901a63..5dde4c90 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -5,9 +5,6 @@ import jax import jax.numpy as jnp - from jax import jit - from jax import float0 - JAX_AVAILABLE = True except ImportError: JAX_AVAILABLE = False @@ -143,7 +140,7 @@ def calculate_grad_radial_vel( if upper_bc is True: p_z1 = p_z1.at[-1, :, :].set(0) y = jnp.stack((p_x1, p_y1, p_z1), axis=0) - return np.copy(y.flatten()) + return y.flatten() def calculate_smoothness_cost(u, v, w, dx, dy, dz, 
Cx=1e-5, Cy=1e-5, Cz=1e-5): @@ -217,7 +214,7 @@ def calculate_smoothness_cost(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5): ) ** 2 ) - return np.asanyarray(jnp.sum(x_term + y_term + z_term)) + return jnp.sum(x_term + y_term + z_term) def calculate_smoothness_gradient( @@ -256,31 +253,95 @@ def calculate_smoothness_gradient( y: float array value of gradient of smoothness cost function """ - du = np.zeros(w.shape) - dv = np.zeros(w.shape) - dw = np.zeros(w.shape) - grad_u = np.zeros(w.shape) - grad_v = np.zeros(w.shape) - grad_w = np.zeros(w.shape) - scipy.ndimage.laplace(u, du, mode="wrap") - scipy.ndimage.laplace(v, dv, mode="wrap") - scipy.ndimage.laplace(w, dw, mode="wrap") - du = du / dx - dv = dv / dy - dw = dw / dz - scipy.ndimage.laplace(du, grad_u, mode="wrap") - scipy.ndimage.laplace(dv, grad_v, mode="wrap") - scipy.ndimage.laplace(dw, grad_w, mode="wrap") - - grad_u = grad_u / dx - grad_v = grad_v / dy - grad_w = grad_w / dz + dudx = jnp.gradient(u, dx, axis=2) + dudy = jnp.gradient(u, dy, axis=1) + dudz = jnp.gradient(u, dz, axis=0) + dvdx = jnp.gradient(v, dx, axis=2) + dvdy = jnp.gradient(v, dy, axis=1) + dvdz = jnp.gradient(v, dz, axis=0) + dwdx = jnp.gradient(w, dx, axis=2) + dwdy = jnp.gradient(w, dy, axis=1) + dwdz = jnp.gradient(w, dz, axis=0) + + x_term = ( + Cx + * ( + jnp.gradient(dudx, dx, axis=2) + + jnp.gradient(dvdx, dx, axis=1) + + jnp.gradient(dwdx, dx, axis=2) + ) + ** 2 + ) + y_term = ( + Cy + * ( + jnp.gradient(dudy, dy, axis=2) + + jnp.gradient(dvdy, dy, axis=1) + + jnp.gradient(dwdy, dy, axis=2) + ) + ** 2 + ) + z_term = ( + Cz + * ( + jnp.gradient(dudz, dz, axis=2) + + jnp.gradient(dvdz, dz, axis=1) + + jnp.gradient(dwdz, dz, axis=2) + ) + ** 2 + ) + + du = x_term / dx + dv = y_term / dy + dw = z_term / dz + dudx = jnp.gradient(du, dx, axis=2) + dudy = jnp.gradient(du, dy, axis=1) + dudz = jnp.gradient(du, dz, axis=0) + dvdx = jnp.gradient(dv, dx, axis=2) + dvdy = jnp.gradient(dv, dy, axis=1) + dvdz = jnp.gradient(dv, dz, axis=0) 
+ dwdx = jnp.gradient(dw, dx, axis=2) + dwdy = jnp.gradient(dw, dy, axis=1) + dwdz = jnp.gradient(dw, dz, axis=0) + + x_term = ( + Cx + * ( + jnp.gradient(dudx, dx, axis=2) + + jnp.gradient(dvdx, dx, axis=1) + + jnp.gradient(dwdx, dx, axis=2) + ) + ** 2 + ) + y_term = ( + Cy + * ( + jnp.gradient(dudy, dy, axis=2) + + jnp.gradient(dvdy, dy, axis=1) + + jnp.gradient(dwdy, dy, axis=2) + ) + ** 2 + ) + z_term = ( + Cz + * ( + jnp.gradient(dudz, dz, axis=2) + + jnp.gradient(dvdz, dz, axis=1) + + jnp.gradient(dwdz, dz, axis=2) + ) + ** 2 + ) + + grad_u = x_term / dx + grad_v = y_term / dy + grad_w = z_term / dz + # Impermeability condition - grad_w[0, :, :] = 0 + grad_w = grad_w.at[0, :, :].set(0) if upper_bc is True: - grad_w[-1, :, :] = 0 - y = np.stack([grad_u * Cx * 2, grad_v * Cy * 2, grad_w * Cz * 2], axis=0) - y = np.nan_to_num(y) + grad_w = grad_w.at[-1, :, :].set(0) + y = jnp.stack([grad_u * Cx * 2, grad_v * Cy * 2, grad_w * Cz * 2], axis=0) + return y.flatten() @@ -494,7 +555,7 @@ def calculate_mass_continuity_gradient( if upper_bc is True: grad_w = grad_w.at[-1, :, :].set(0) y = jnp.stack([grad_u, grad_v, grad_w], axis=0) - return y.flatten().copy() + return y.flatten() def calculate_background_cost(u, v, w, weights, u_back, v_back, Cb=0.01): @@ -697,7 +758,7 @@ def calculate_vertical_vorticity_gradient( w_grad.at[0, :, :].set(0) if upper_bc is True: w_grad.at[-1, :, :].set(0) - y = np.stack([u_grad, v_grad, w_grad], axis=0) + y = jnp.stack([u_grad, v_grad, w_grad], axis=0) return y.flatten().copy() @@ -781,5 +842,5 @@ def calculate_model_gradient(u, v, w, weights, u_model, v_model, w_model, coeff= calculate_model_cost, u, v, w, weights, u_model, v_model, w_model, coeff ) u_grad, v_grad, w_grad, _, _, _, _, _ = fun_vjp(1.0) - y = np.stack([u_grad, v_grad, w_grad], axis=0) + y = jnp.stack([u_grad, v_grad, w_grad], axis=0) return y.flatten().copy() diff --git a/pydda/cost_functions/cost_functions.py b/pydda/cost_functions/cost_functions.py index 159eeb08..1db3bab2 100644 ---
a/pydda/cost_functions/cost_functions.py +++ b/pydda/cost_functions/cost_functions.py @@ -12,8 +12,8 @@ from jax.config import config config.update("jax_enable_x64", True) - from jax import float0 import jax.numpy as jnp + import jax JAX_AVAILABLE = True except ImportError: @@ -279,119 +279,7 @@ def J_function(winds, parameters): else: Jpoint = 0 elif parameters.engine == "jax": - if not JAX_AVAILABLE: - raise ImportError("Jax is needed in order to use the Jax-based PyDDA!") - - winds = np.reshape( - winds, - ( - 3, - parameters.grid_shape[0], - parameters.grid_shape[1], - parameters.grid_shape[2], - ), - ) - # Had to change to float because Jax returns device array (use np.float_()) - Jvel = _cost_functions_jax.calculate_radial_vel_cost_function( - parameters.vrs, - parameters.azs, - parameters.els, - winds[0], - winds[1], - winds[2], - parameters.wts, - rmsVr=parameters.rmsVr, - weights=parameters.weights, - coeff=parameters.Co, - ) - # print("apples Jvel", Jvel) - - if parameters.Cm > 0: - # Had to change to float because Jax returns device array (use np.float_()) - Jmass = _cost_functions_jax.calculate_mass_continuity( - winds[0], - winds[1], - winds[2], - parameters.z, - parameters.dx, - parameters.dy, - parameters.dz, - coeff=parameters.Cm, - ) - else: - Jmass = 0 - - if parameters.Cx > 0 or parameters.Cy > 0 or parameters.Cz > 0: - Jsmooth = _cost_functions_jax.calculate_smoothness_cost( - winds[0], - winds[1], - winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - Cx=parameters.Cx, - Cy=parameters.Cy, - Cz=parameters.Cz, - ) - else: - Jsmooth = 0 - - if parameters.Cb > 0: - Jbackground = _cost_functions_jax.calculate_background_cost( - winds[0], - winds[1], - winds[2], - parameters.bg_weights, - parameters.u_back, - parameters.v_back, - parameters.Cb, - ) - else: - Jbackground = 0 - - if parameters.Cv > 0: - # Had to change to float because Jax returns device array (use np.float_()) - Jvorticity = 
_cost_functions_jax.calculate_vertical_vorticity_cost( - winds[0], - winds[1], - winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - parameters.Ut, - parameters.Vt, - coeff=parameters.Cv, - ) - else: - Jvorticity = 0 - - if parameters.Cmod > 0: - Jmod = _cost_functions_jax.calculate_model_cost( - winds[0], - winds[1], - winds[2], - parameters.model_weights, - parameters.u_model, - parameters.v_model, - parameters.w_model, - coeff=parameters.Cmod, - ) - else: - Jmod = 0 - - if parameters.Cpoint > 0: - Jpoint = _cost_functions_jax.calculate_point_cost( - winds[0], - winds[1], - parameters.x, - parameters.y, - parameters.z, - parameters.point_list, - Cp=parameters.Cpoint, - roi=parameters.roi, - ) - else: - Jpoint = 0 + return J_function_jax(winds, vars(parameters)) if parameters.Nfeval % 10 == 0: print( @@ -660,105 +548,7 @@ def grad_J(winds, parameters): upper_bc=parameters.upper_bc, ) elif parameters.engine == "jax": - winds = jnp.reshape( - winds, - ( - 3, - parameters.grid_shape[0], - parameters.grid_shape[1], - parameters.grid_shape[2], - ), - ) - grad = _cost_functions_jax.calculate_grad_radial_vel( - parameters.vrs, - parameters.els, - parameters.azs, - winds[0], - winds[1], - winds[2], - parameters.wts, - parameters.weights, - parameters.rmsVr, - coeff=parameters.Co, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cm > 0: - grad += _cost_functions_jax.calculate_mass_continuity_gradient( - winds[0], - winds[1], - winds[2], - parameters.z, - parameters.dx, - parameters.dy, - parameters.dz, - coeff=parameters.Cm, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cx > 0 or parameters.Cy > 0 or parameters.Cz > 0: - grad += _cost_functions_jax.calculate_smoothness_gradient( - winds[0], - winds[1], - winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - Cx=parameters.Cx, - Cy=parameters.Cy, - Cz=parameters.Cz, - upper_bc=parameters.upper_bc, - ) - - if parameters.Cb > 0: - grad += _cost_functions_jax.calculate_background_gradient( - winds[0],
- winds[1], - winds[2], - parameters.bg_weights, - parameters.u_back, - parameters.v_back, - parameters.Cb, - ) - - if parameters.Cv > 0: - grad += _cost_functions_jax.calculate_vertical_vorticity_gradient( - winds[0], - winds[1], - winds[2], - parameters.dx, - parameters.dy, - parameters.dz, - parameters.Ut, - parameters.Vt, - coeff=parameters.Cv, - upper_bc=parameters.upper_bc, - ).numpy() - - if parameters.Cmod > 0: - grad += _cost_functions_jax.calculate_model_gradient( - winds[0], - winds[1], - winds[2], - parameters.model_weights, - parameters.u_model, - parameters.v_model, - parameters.w_model, - coeff=parameters.Cmod, - ) - - if parameters.Cpoint > 0: - grad += _cost_functions_jax.calculate_point_gradient( - winds[0], - winds[1], - parameters.x, - parameters.y, - parameters.z, - parameters.point_list, - Cp=parameters.Cpoint, - roi=parameters.roi, - upper_bc=parameters.upper_bc, - ) + return grad_jax(winds, vars(parameters)) if parameters.Nfeval % 10 == 0: print("The gradient of the cost functions is", str(np.linalg.norm(grad, 2))) @@ -814,3 +604,223 @@ def calculate_fall_speed(grid, refl_field=None, frz=4500.0): print(fallspeed.max()) del A, B, rho return np.ma.masked_invalid(fallspeed) + + +def J_function_jax(winds, parameters): + if not JAX_AVAILABLE: + raise ImportError("Jax is needed in order to use the Jax-based PyDDA!") + + winds = jnp.reshape( + winds, + ( + 3, + parameters["grid_shape"][0], + parameters["grid_shape"][1], + parameters["grid_shape"][2], + ), + ) + # Had to change to float because Jax returns device array (use np.float_()) + Jvel = _cost_functions_jax.calculate_radial_vel_cost_function( + parameters["vrs"], + parameters["azs"], + parameters["els"], + winds[0], + winds[1], + winds[2], + parameters["wts"], + rmsVr=parameters["rmsVr"], + weights=parameters["weights"], + coeff=parameters["Co"], + ) + + if parameters["Cm"] > 0: + # Had to change to float because Jax returns device array (use np.float_()) + Jmass =
_cost_functions_jax.calculate_mass_continuity( + winds[0], + winds[1], + winds[2], + parameters["z"], + parameters["dx"], + parameters["dy"], + parameters["dz"], + coeff=parameters["Cm"], + ) + else: + Jmass = 0 + + if parameters["Cx"] > 0 or parameters["Cy"] > 0 or parameters["Cz"] > 0: + Jsmooth = _cost_functions_jax.calculate_smoothness_cost( + winds[0], + winds[1], + winds[2], + parameters["dx"], + parameters["dy"], + parameters["dz"], + Cx=parameters["Cx"], + Cy=parameters["Cy"], + Cz=parameters["Cz"], + ) + else: + Jsmooth = 0 + + if parameters["Cb"] > 0: + Jbackground = _cost_functions_jax.calculate_background_cost( + winds[0], + winds[1], + winds[2], + parameters["bg_weights"], + parameters["u_back"], + parameters["v_back"], + parameters["Cb"], + ) + else: + Jbackground = 0 + + if parameters["Cv"] > 0: + # Had to change to float because Jax returns device array (use np.float_()) + Jvorticity = _cost_functions_jax.calculate_vertical_vorticity_cost( + winds[0], + winds[1], + winds[2], + parameters["dx"], + parameters["dy"], + parameters["dz"], + parameters["Ut"], + parameters["Vt"], + coeff=parameters["Cv"], + ) + else: + Jvorticity = 0 + + if parameters["Cmod"] > 0: + Jmod = _cost_functions_jax.calculate_model_cost( + winds[0], + winds[1], + winds[2], + parameters["model_weights"], + parameters["u_model"], + parameters["v_model"], + parameters["w_model"], + coeff=parameters["Cmod"], + ) + else: + Jmod = 0 + + if parameters["Cpoint"] > 0: + Jpoint = _cost_functions_jax.calculate_point_cost( + winds[0], + winds[1], + parameters["x"], + parameters["y"], + parameters["z"], + parameters["point_list"], + Cp=parameters["Cpoint"], + roi=parameters["roi"], + ) + else: + Jpoint = 0 + + return Jvel + Jsmooth + Jmass + Jmod + Jpoint + Jvorticity + Jbackground + + +def grad_jax(winds, parameters): + winds = jnp.reshape( + winds, + ( + 3, + parameters["grid_shape"][0], + parameters["grid_shape"][1], + parameters["grid_shape"][2], + ), + ) + grad = 
_cost_functions_jax.calculate_grad_radial_vel( + parameters["vrs"], + parameters["els"], + parameters["azs"], + winds[0], + winds[1], + winds[2], + parameters["wts"], + parameters["weights"], + parameters["rmsVr"], + coeff=parameters["Co"], + upper_bc=parameters["upper_bc"], + ) + + if parameters["Cm"] > 0: + grad += _cost_functions_jax.calculate_mass_continuity_gradient( + winds[0], + winds[1], + winds[2], + parameters["z"], + parameters["dx"], + parameters["dy"], + parameters["dz"], + coeff=parameters["Cm"], + upper_bc=parameters["upper_bc"], + ) + + if parameters["Cx"] > 0 or parameters["Cy"] > 0 or parameters["Cz"] > 0: + grad += _cost_functions_jax.calculate_smoothness_gradient( + winds[0], + winds[1], + winds[2], + parameters["dx"], + parameters["dy"], + parameters["dz"], + Cx=parameters["Cx"], + Cy=parameters["Cy"], + Cz=parameters["Cz"], + upper_bc=parameters["upper_bc"], + ) + + if parameters["Cb"] > 0: + grad += _cost_functions_jax.calculate_background_gradient( + winds[0], + winds[1], + winds[2], + parameters["bg_weights"], + parameters["u_back"], + parameters["v_back"], + parameters["Cb"], + ) + + if parameters["Cv"] > 0: + grad += _cost_functions_jax.calculate_vertical_vorticity_gradient( + winds[0], + winds[1], + winds[2], + parameters["dx"], + parameters["dy"], + parameters["dz"], + parameters["Ut"], + parameters["Vt"], + coeff=parameters["Cv"], + upper_bc=parameters["upper_bc"], + ) + + if parameters["Cmod"] > 0: + grad += _cost_functions_jax.calculate_model_gradient( + winds[0], + winds[1], + winds[2], + parameters["model_weights"], + parameters["u_model"], + parameters["v_model"], + parameters["w_model"], + coeff=parameters["Cmod"], + ) + + if parameters["Cpoint"] > 0: + grad += _cost_functions_jax.calculate_point_gradient( + winds[0], + winds[1], + parameters["x"], + parameters["y"], + parameters["z"], + parameters["point_list"], + Cp=parameters["Cpoint"], + roi=parameters["roi"], + upper_bc=parameters["upper_bc"], + ) + return grad diff
--git a/pydda/retrieval/wind_retrieve.py b/pydda/retrieval/wind_retrieve.py index f8ff04fc..8d3973f6 100644 --- a/pydda/retrieval/wind_retrieve.py +++ b/pydda/retrieval/wind_retrieve.py @@ -26,13 +26,21 @@ try: import jax.numpy as jnp + import jax + import jaxopt JAX_AVAILABLE = True except ImportError: JAX_AVAILABLE = False # imports changed to local import path to run on computer -from ..cost_functions import J_function, grad_J, calculate_fall_speed +from ..cost_functions import ( + J_function, + grad_J, + calculate_fall_speed, + grad_jax, + J_function_jax, +) from copy import deepcopy from .angles import add_azimuth_as_field, add_elevation_as_field @@ -598,19 +606,44 @@ def _vert_velocity_callback(x): return False parameters.print_out = False + if engine.lower() == "scipy": + winds = fmin_l_bfgs_b( + J_function, + winds, + args=(parameters,), + maxiter=max_iterations, + pgtol=tolerance, + bounds=bounds, + fprime=grad_J, + disp=0, + iprint=-1, + callback=_vert_velocity_callback, + ) + else: - winds = fmin_l_bfgs_b( - J_function, - winds, - args=(parameters,), - maxiter=max_iterations, - pgtol=tolerance, - bounds=bounds, - fprime=grad_J, - disp=0, - iprint=-1, - callback=_vert_velocity_callback, - ) + def loss_and_gradient(x): + x_loss = J_function_jax(x["winds"], vars(parameters)) + x_grad = {} + x_grad["winds"] = grad_jax(x["winds"], vars(parameters)) + return x_loss, x_grad + + bounds = ( + {"winds": -100 * jnp.ones(winds.shape)}, + {"winds": 100 * jnp.ones(winds.shape)}, + ) + winds = jnp.array(winds) + solver = jaxopt.LBFGSB( + loss_and_gradient, + True, + has_aux=False, + maxiter=max_iterations, + tol=tolerance, + jit=True, + implicit_diff=False, + ) + winds = {"winds": winds} + winds, state = solver.run(winds, bounds=bounds) + winds = [np.asanyarray(winds["winds"])] winds = np.reshape( winds[0],