Skip to content

Commit

Permalink
Adjoint formulations and hyperparams for DP
Browse files Browse the repository at this point in the history
  • Loading branch information
RDES (DaffyDuck) committed Jun 19, 2023
1 parent 0a3d5a5 commit 2276ab6
Show file tree
Hide file tree
Showing 28 changed files with 1,682 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# %%

"""
Test of the Updec package on the Laplace equation with RBFs
"""

import os
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
import time
Expand Down
File renamed without changes.
File renamed without changes.
202 changes: 202 additions & 0 deletions demos/laplace/03_laplace_with_adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# %%

"""
Control of Laplace equation with Direct Adjoint Looping (DAL)
"""

import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt
from tqdm import tqdm

# from torch.utils.tensorboard import SummaryWriter

from updec import *

#%%

RBF = polyharmonic
MAX_DEGREE = 4


RUN_NAME = "LaplaceAdjoint"
DATAFOLDER = "../data/" + RUN_NAME +"/"
make_dir(DATAFOLDER)
# writer = SummaryWriter("runs/"+RUN_NAME)
KEY = jax.random.PRNGKey(41) ## Use same random points for all iterations

Nx = 35
Ny = Nx
LR = 1e-2
GAMMA = 1 ### LR decay rate
EPOCHS = 500



facet_types={"North":"d", "South":"d", "West":"d", "East":"d"}
train_cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, noise_key=None, support_size="max")

train_cloud.visualize_cloud(s=0.1, title="Training cloud", figsize=(5,4));

#%%

## For the cost function
north_ids = jnp.array(train_cloud.facet_nodes["North"])
xy_north = train_cloud.sorted_nodes[north_ids, :]
x_north = xy_north[:, 0]
q_cost = jax.vmap(lambda x: jnp.cos(2*jnp.pi * x))(x_north)


## Exact solution
def laplace_exact_sol(xy):
PI = jnp.pi
x, y = xy

a = 0.5 * jnp.sin(2*PI*x) * (jnp.exp(2*PI*(y-1)) + jnp.exp(2*PI*(1-y))) / jnp.cosh(2*PI)
b = jnp.cos(2*PI*x) * (jnp.exp(2*PI*y) + jnp.exp(-2*PI*y)) / (4*PI*jnp.cosh(2*PI))

return a+b

def laplace_exact_control(x):
PI = jnp.pi
return (jnp.sin(2*PI*x)/jnp.cosh(2*PI)) + (jnp.cos(2*PI*x)*jnp.tanh(2*PI)/(2*PI))


exact_sol = jax.vmap(laplace_exact_sol)(train_cloud.sorted_nodes)
exact_control = jax.vmap(laplace_exact_control)(x_north)


#%%
@Partial(jax.jit, static_argnums=[2,3])
def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None):
return nodal_laplacian(x, center, rbf, monomial)

@Partial(jax.jit, static_argnums=[2])
def my_rhs_operator(x, centers=None, rbf=None, fields=None):
return 0.


## Boundary conditions for both primal and adjoint problem
d_south = jax.jit(lambda x: jnp.sin(2*jnp.pi * x[0]))
d_east = jax.jit(lambda x: jnp.sinh(2*jnp.pi*x[1]) / (2*jnp.pi * jnp.cosh(2*jnp.pi)))
d_west = d_east


@jax.jit
def direct_simulation(bcn):
sol = pde_solver(diff_operator=my_diff_operator,
rhs_operator = my_rhs_operator,
cloud = train_cloud,
boundary_conditions = {"South":d_south, "West":d_west, "North":bcn, "East":d_east},
rbf=RBF,
max_degree=MAX_DEGREE)
return sol

@jax.jit
def adjoint_problem(u_coefs):
grad_n_y = gradient_vec(xy_north, u_coefs, train_cloud.sorted_nodes, RBF)[...,1]

sol = pde_solver(diff_operator=my_diff_operator,
rhs_operator = my_rhs_operator,
cloud = train_cloud,
boundary_conditions = {"South":d_south, "West":d_west, "North":grad_n_y-q_cost, "East":d_east},
rbf=RBF,
max_degree=MAX_DEGREE)
return sol



@jax.jit ################ TODO TODO TODO don't jitt compile this, jitt the PDE solver instead !!!!
def loss_fn(u_coeffs):
grad_n_y = gradient_vec(xy_north, u_coeffs, train_cloud.sorted_nodes, RBF)[...,1]

loss_cost = (grad_n_y - q_cost)**2
return jnp.trapz(loss_cost, x=x_north)

def grad_loss_fn(lambda_coeffs):
return gradient_vec(xy_north, lambda_coeffs, train_cloud.sorted_nodes, RBF)[...,1]


# %%

### Optimisation start ###

optimal_bcn = jnp.zeros((north_ids.shape[0]))
history_cost = []
north_mse = []

scheduler = optax.piecewise_constant_schedule(init_value=LR,
boundaries_and_scales={int(EPOCHS*0.4):0.1, int(EPOCHS*0.8):0.1})
optimiser = optax.adam(learning_rate=scheduler)
opt_state = optimiser.init(optimal_bcn)

### Optimsation start ###
for step in range(1, EPOCHS+1):

u = direct_simulation(optimal_bcn)
lamb = adjoint_problem(u.coeffs)

loss = loss_fn(u.coeffs)
grad = grad_loss_fn(lamb.coeffs)

updates, opt_state = optimiser.update(grad, opt_state, optimal_bcn)
optimal_bcn = optax.apply_updates(optimal_bcn, updates)

# writer.add_scalar('loss', float(loss), step)

north_error = jnp.mean((optimal_bcn-exact_control)**2)
history_cost.append(loss)
north_mse.append(north_error)

if step<=3 or step%100==0:
print("Epoch: %-5d InitLR: %.4f Loss: %.8f TestMSE: %.6f" % (step, LR, loss, north_error))


### Visualisation at north
ax = plot(x_north, exact_control, "-", label="Analytical", x_label=r"$x$", figsize=(5,3), ylim=(-.2,.2))
plot(x_north, optimal_bcn, "--", label="DAL", ax=ax, title=f"Optimised north solution / MSE = {north_error:.4f}");
# plt.savefig(DATAFOLDER+"bcn_"+str(step)+".png", transparent=True)


ax = plot(history_cost, label='Cost objective', x_label='epochs', title="Loss", y_scale="log");
plot(north_mse, label='Test Error at North', x_label='epochs', title="Loss", y_scale="log", ax=ax);


# %%

############# Just for fun ########## TODO do this outside the loop

optimal_conditions = {"South":d_south, "West":d_west, "North":optimal_bcn, "East":d_east}
sol = pde_solver(diff_operator=my_diff_operator,
rhs_operator = my_rhs_operator,
cloud = train_cloud,
boundary_conditions = optimal_conditions,
rbf=RBF,
max_degree=MAX_DEGREE)
# optimal_error = jnp.mean((exact_sol-sol.vals)**2)

# print("calculated sol = ", sol.vals)

### Optional visualisation of whole solution
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(6*2,5))
train_cloud.visualize_field(sol.vals, cmap="jet", projection="2d", title="Optimized solution", ax=ax1, vmin=-1, vmax=1)
# test_cloud.visualize_field(exact_sol, cmap="jet", projection="3d", title="Analytical solution", ax=ax2, vmin=-1, vmax=1)
train_cloud.visualize_field(jnp.abs(sol.vals-exact_sol), cmap="magma", projection="2d", title="Absolute error", ax=ax2, vmin=0, vmax=1);
plt.savefig(DATAFOLDER+"solution_"+str(step)+".png", transparent=True)

############# fun ends ##########



## Write to tensorboard
# hparams_dict = {"learning_rate":LR, "nb_epochs":EPOCHS, "rbf":RBF.__name__, "max_degree":MAX_DEGREE, "nb_nodes":cloud.N, "support_size":cloud.support_size}
# metrics_dict = {"metrics/mse_error_north":float(north_error)}
# writer.add_hparams(hparams_dict, metrics_dict, run_name="hp_params")
# writer.add_figure("plots", fig)
# writer.flush()
# writer.close()


# %%
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# %%

"""
Control of Laplace equation with differentiable physics
"""

import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt
from tqdm import tqdm
Expand All @@ -16,7 +22,7 @@


RUN_NAME = "LaplaceDiffPhys"
DATAFOLDER = "./data/" + RUN_NAME +"/"
DATAFOLDER = "../data/" + RUN_NAME +"/"
make_dir(DATAFOLDER)
# writer = SummaryWriter("runs/"+RUN_NAME)
KEY = jax.random.PRNGKey(41) ## Use same random points for all iterations
Expand Down Expand Up @@ -102,15 +108,22 @@ def loss_fn(bcn):
history_cost = []
north_mse = []

scheduler = optax.piecewise_constant_schedule(init_value=LR,
boundaries_and_scales={int(EPOCHS*0.4):0.1, int(EPOCHS*0.8):0.1})
optimiser = optax.adam(learning_rate=scheduler)
opt_state = optimiser.init(optimal_bcn)

for step in range(1, EPOCHS+1):

### Optimsation start ###
loss, grad = grad_loss_fn(optimal_bcn)
# print("calculated grad = ", grad)
learning_rate = LR * (GAMMA**step)

optimal_bcn = optimal_bcn - grad * learning_rate
# learning_rate = LR * (GAMMA**step)
# optimal_bcn = optimal_bcn - grad * learning_rate

updates, opt_state = optimiser.update(grad, opt_state, optimal_bcn)
optimal_bcn = optax.apply_updates(optimal_bcn, updates)

# writer.add_scalar('loss', float(loss), step)

Expand All @@ -119,17 +132,17 @@ def loss_fn(bcn):
north_mse.append(north_error)

if step<=3 or step%100==0:
print("Epoch: %-5d LR: %.4f Loss: %.8f TestMSE: %.6f" % (step, learning_rate, loss, north_error))
print("Epoch: %-5d InitLR: %.4f Loss: %.8f TestError: %.6f" % (step, LR, loss, north_error))


### Visualisation at north
ax = plot(x_north, exact_control, "-", label="Ideal/Analytical", x_label=r"$x$", figsize=(5,3), ylim=(-.2,.2))
plot(x_north, optimal_bcn, "--", label="Differentiable Physics", ax=ax, title=f"Optimised north solution / MSE = {north_error:.4f}");
ax = plot(x_north, exact_control, "-", label="Analytical", x_label=r"$x$", figsize=(5,3), ylim=(-.2,.2))
plot(x_north, optimal_bcn, "--", label="Diff. Physics", ax=ax, title=f"Optimised north solution / MSE = {north_error:.4f}");
# plt.savefig(DATAFOLDER+"bcn_"+str(step)+".png", transparent=True)


ax = plot(history_cost, label='Cost objective', x_label='epochs', title="Loss", y_scale="log");
plot(north_mse, label='Test MSE', x_label='epochs', title="Loss", y_scale="log", ax=ax);
plot(north_mse, label='Test Error at North', x_label='epochs', title="Loss", y_scale="log", ax=ax);


# %%
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#%%

"""
Control of Laplace equation with PINNs (Preliminary step)
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#%%
"""
Control of Laplace equation with PINNs (Step 1)
"""

import jax
import jax.numpy as jnp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#%%
"""
Control of Laplace equation with PINNs (Step 2)
"""

import jax
import jax.numpy as jnp
Expand Down
4 changes: 2 additions & 2 deletions demos/meshes/channel_blowing_suction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import sys


lc = 0.3
ref_io = 4 ## Refinement factor to account for Infow/Outflow
lc = 0.3 ## TODO Set this to 4 !
ref_io = 8 ## Refinement factor to account for Infow/Outflow
ref_bs = 8 ## Refinement factor to account for Blowing/Suction
box_half_length = 0.003
Lx = 1.5
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@

# %%

"""
Control of Navier Stokes equation with DAL (Direct fomulation)
"""

import jax
import jax.numpy as jnp
from jax.tree_util import Partial
Expand Down Expand Up @@ -27,7 +32,7 @@
Re = 100
Pa = 0.

NB_ITER = 15
NB_ITER = 3


# %%
Expand Down Expand Up @@ -151,7 +156,8 @@ def simulate_forward_navier_stokes(cloud_vel,
# max_degree=MAX_DEGREE)


for i in tqdm(range(NB_ITER)):
# for i in tqdm(range(NB_ITER), disable=True):
for i in range(NB_ITER):
# print("Starting iteration %d" % i)

p = interpolate_field(p_, cloud_phi, cloud_vel)
Expand Down
Loading

0 comments on commit 2276ab6

Please sign in to comment.