diff --git a/demos/Advection/01_adv_diff_periodic.py b/demos/Advection/01_adv_diff_periodic.py index 27395ac..c7443d1 100644 --- a/demos/Advection/01_adv_diff_periodic.py +++ b/demos/Advection/01_adv_diff_periodic.py @@ -34,12 +34,12 @@ MAX_DEGREE = 0 DT = 1e-4 -NB_TIMESTEPS = 50 +NB_TIMESTEPS = 100 PLOT_EVERY = 10 ## Diffusive constant K = 0.08 -VEL = jnp.array([100.0, 0.0]) +VEL = jnp.array([-100.0, 0.0]) Nx = 25 Ny = 25 @@ -77,12 +77,19 @@ def my_rhs_operator(x, centers=None, rbf=None, fields=None): ## u0 is zero everywhere except at a point in the middle -u0 = jnp.zeros(cloud.N) -source_id = int(cloud.N*0.71) -source_neighbors = jnp.array(cloud.local_supports[source_id][:cloud.N//40]) -# source_id = 0 -# source_neighbors = jnp.array(cloud.local_supports[source_id][:1]) -u0 = u0.at[source_neighbors].set(0.95) +# u0 = jnp.zeros(cloud.N) +# source_id = int(cloud.N*0.71) +# source_neighbors = jnp.array(cloud.local_supports[source_id][:cloud.N//40]) +# # source_id = 0 +# # source_neighbors = jnp.array(cloud.local_supports[source_id][:1]) +# u0 = u0.at[source_neighbors].set(0.95) + +def gaussian(x, y, x0, y0, sigma): + return jnp.exp(-((x-x0)**2 + (y-y0)**2) / (2*sigma**2)) +xy = cloud.sorted_nodes +u0 = gaussian(xy[:,0], xy[:,1], 0.75, 0.5, 1/10) + + ## Begin timestepping for 100 steps diff --git a/demos/Gray-Scott/00_gray-scott_with_rbf.py b/demos/Gray-Scott/gray-scott.py similarity index 100% rename from demos/Gray-Scott/00_gray-scott_with_rbf.py rename to demos/Gray-Scott/gray-scott.py diff --git a/demos/Wave/00_wave.py b/demos/Wave/00_wave.py new file mode 100644 index 0000000..26677e6 --- /dev/null +++ b/demos/Wave/00_wave.py @@ -0,0 +1,135 @@ +# %% +%load_ext autoreload +%autoreload 2 + +# %% + +""" +Test of the Updec package on the Wave equation with RBFs: +PDE here: https://en.wikipedia.org/wiki/Convection%E2%80%93diffusion_equation +""" + +import time + +import jax +import jax.numpy as jnp + +# jax.config.update('jax_platform_name', 'cpu') +jax.config.update("jax_enable_x64", True) + +from updec import * +# key = jax.random.PRNGKey(13) +key = None + +# from torch.utils.tensorboard import SummaryWriter +jax.numpy.set_printoptions(precision=2) + +RUN_NAME = "TempFolder" +DATAFOLDER = "./data/" + RUN_NAME +"/" +# DATAFOLDER = "demos/Advection/data/"+RUN_NAME+"/" +make_dir(DATAFOLDER) + +RBF = partial(polyharmonic, a=3) +# RBF = gaussian +MAX_DEGREE = 2 + +DT = 5e-4 +NB_TIMESTEPS = 500 +PLOT_EVERY = 5 + +## Diffusive constant +C = 1. + +Nx = 25 +Ny = 25 +SUPPORT_SIZE = "max" + +# facet_types={"South":"d", "North":"d", "West":"d", "East":"d"} +facet_types={"South":"n", "North":"n", "West":"n", "East":"n"} +cloud = SquareCloud(Nx=Nx, Ny=Ny, facet_types=facet_types, noise_key=key, support_size=SUPPORT_SIZE) + +# cloud.visualize_cloud(s=0.1, figsize=(7,6)); + +# cloud.facet_types +# cloud.facet_nodes +# print("Local supports:", cloud.local_supports[0]) +print(cloud.Np) +# print(jnp.flip(cloud.global_indices.T, axis=0)) +# cloud.print_global_indices() +# print(cloud.sorted_nodes) +# cloud.sorted_outward_normals +# cloud.outward_normals + +# %% + +def my_diff_operator(x, center=None, rbf=None, monomial=None, fields=None): + val = nodal_value(x, center, rbf, monomial) + lap = nodal_laplacian(x, center, rbf, monomial) + return (val/(DT**2)) + C*lap + +def my_rhs_operator(x, centers=None, rbf=None, fields=None): + u_prev = value(x, fields[:,0], centers, rbf) + u_prev_prev = value(x, fields[:,1], centers, rbf) + return (2*u_prev - u_prev_prev)/(DT**2) + +d_zero = lambda x: 0. +boundary_conditions = {"South":d_zero, "West":d_zero, "North":d_zero, "East":d_zero} + +## Uo is a 2D gaussian centered at the middle of the domain +def gaussian(x, y, x0, y0, sigma): + return jnp.exp(-((x-x0)**2 + (y-y0)**2) / (2*sigma**2)) +xy = cloud.sorted_nodes +u0 = gaussian(xy[:,0], xy[:,1], 0.85, 0.85, 1/10) + +## Begin timestepping for 100 steps +cloud.visualize_field(u0, cmap="coolwarm", title=f"Step {0}", vmin=0, vmax=1, figsize=(6,6),colorbar=False); + + +# %% +ulist = [u0, 1*DT + u0] + +start = time.time() + +for i in range(1, NB_TIMESTEPS+1): + uprev = ulist[-1] + uprevprev = ulist[-2] + + ufield = pde_solver_jit(diff_operator=my_diff_operator, + rhs_operator = my_rhs_operator, + rhs_args=[uprev, uprevprev], + cloud = cloud, + boundary_conditions = boundary_conditions, + rbf=RBF, + max_degree=MAX_DEGREE,) + ulist.append(ufield.vals) + + # if i<=3 or i%PLOT_EVERY==0: + # print(f"Step {i}") + # # plt.cla() + # # cloud.visualize_field(ulist[-1], cmap="coolwarm", projection="3d", title=f"Step {i}") + # ax, _ = cloud.visualize_field(ulist[-1], cmap="coolwarm", title=f"Step {i}", vmin=None, vmax=None, figsize=(6,6),colorbar=False) + # # plt.draw() + # plt.show() + + +walltime = time.time() - start + +minutes = walltime // 60 % 60 +seconds = walltime % 60 +print(f"Walltime: {minutes} minutes {seconds:.2f} seconds") + + + +# %% + +## Clip ulist arrays between -1 and 1 +ulist = [jnp.clip(u, -1, 1) for u in ulist] + +filename = DATAFOLDER + "wave.gif" +cloud.animate_fields([ulist], cmaps="coolwarm", filename=filename, figsize=(7.5,6), titles=["Wave with RBFs"]); + + + +# %% + +# ulist diff --git a/demos/Wave/data/TempFolder/wave.gif b/demos/Wave/data/TempFolder/wave.gif new file mode 100644 index 0000000..32af951 Binary files /dev/null and b/demos/Wave/data/TempFolder/wave.gif differ diff --git a/updec/assembly.py b/updec/assembly.py index 7840fde..cbbc4ad 100644 --- a/updec/assembly.py +++ b/updec/assembly.py @@ -272,7 +272,7 @@ def bdPhi_r_body_func(i, bdPhi): # print("BdPhi:", bdPhi) def bdPhi_pd_body_func(i, vals): - bdPhi, nb_conds = vals + bdPhi, nb_conds, jump_points = vals support1 = cloud.sorted_local_supports[i] support2 = cloud.sorted_local_supports[i+nb_conds] @@ -280,56 +280,112 @@ def bdPhi_pd_body_func(i, vals): vals1 = rbf_vec(nodes[i], nodes[support]) vals2 = rbf_vec(nodes[i+nb_conds], nodes[support]) + # jax.debug.print("Nodes 1 and 2: \n {} {} \n", nodes[i], nodes[i+nb_conds]) - return bdPhi.at[i-Ni, support].set(vals1-vals2), nb_conds + return bdPhi.at[i-Ni-jump_points, support].set(vals1-vals2), nb_conds, jump_points Np_ = Ni+Nd+Nn+Nr + jump_points = 0 for nb_p_points in Np: nb_conds = nb_p_points//2 - bdPhi, _ = jax.lax.fori_loop(Np_, Np_+nb_conds, bdPhi_pd_body_func, (bdPhi, nb_conds)) - # print("BdPhi:", bdPhi) - Np_ += nb_conds - - + bdPhi, _, _ = jax.lax.fori_loop(Np_, Np_+nb_conds, bdPhi_pd_body_func, (bdPhi, nb_conds, jump_points)) + Np_ += nb_p_points + jump_points += nb_conds + # print("BdPhi:", bdPhi) def bdPhi_pn_body_func(i, vals): - bdPhi, nb_conds, jump_normals = vals + bdPhi, nb_conds, jump_points = vals - support1 = cloud.sorted_local_supports[i+jump_normals-sum(Np)//2] - support2 = cloud.sorted_local_supports[i+nb_conds+jump_normals-sum(Np)//2] + support1 = cloud.sorted_local_supports[i] + support2 = cloud.sorted_local_supports[i+nb_conds] support = jnp.concatenate((support1, support2)) - grads1 = jnp.nan_to_num(grad_rbf_vec(nodes[i+jump_normals-sum(Np)//2], nodes[support]), neginf=0., posinf=0.) - grads2 = jnp.nan_to_num(grad_rbf_vec(nodes[i+nb_conds+jump_normals-sum(Np)//2], nodes[support]), neginf=0., posinf=0.) + grads1 = jnp.nan_to_num(grad_rbf_vec(nodes[i], nodes[support]), neginf=0., posinf=0.) + grads2 = jnp.nan_to_num(grad_rbf_vec(nodes[i+nb_conds], nodes[support]), neginf=0., posinf=0.) + + # jax.debug.print("Nodes 1 and 2: \n {} {} \n", nodes[i+jump_normals-sum(Np)//2], nodes[i+nb_conds+jump_normals-sum(Np)//2]) # grads1 = grad_rbf_vec(nodes[i], nodes[support]) # grads2 = grad_rbf_vec(nodes[i+nb_conds], nodes[support]) if hasattr(cloud, "sorted_outward_normals"): - normals1 = cloud.sorted_outward_normals[i-Ni-Nd+jump_normals-sum(Np)//2] - normals2 = cloud.sorted_outward_normals[i-Ni-Nd+jump_normals-sum(Np)//2+nb_conds] + normals1 = cloud.sorted_outward_normals[i-Ni-Nd] + normals2 = cloud.sorted_outward_normals[i-Ni-Nd+nb_conds] # normals1 = cloud.outward_normals[i] # normals2 = cloud.outward_normals[i+nb_conds] + # jax.debug.print("Normals 1 and 2: {} {} \n", normals1, normals2) else: normals1 = jnp.zeros((DIM,)) normals2 = jnp.zeros((DIM,)) + # print("Shapes before dot:", grads1.shape, normals1.shape, grads2.shape, normals2.shape) diff_grads = jnp.dot(grads1, normals1) - jnp.dot(grads2, -normals2) + # jax.debug.print("Dots 1 and 2: {}\n {} \n", i-Ni-jump_points+sum(Np)//2, jnp.dot(grads1, normals1)-jnp.dot(grads2, -normals2)) + + # print("SHupport shape", support.shape) + # jax.debug.print("Support: \n {} \n", support) + # diff_grads = jnp.zeros_like(jnp.dot(grads1, normals1)) - return bdPhi.at[i-Ni, support].set(diff_grads), nb_conds, jump_normals + # diff_grads = diff_grads / 10. + # diff_grads = jnp.clip(diff_grads, -2., 2.) + # jax.debug.print("Current positions to fill: {} \n {} \n", i-Ni-jump_normals+sum(Np)//2) + return bdPhi.at[i-Ni-jump_points+sum(Np)//2, support].set(diff_grads), nb_conds, jump_points + - Np_ = Ni+Nd+Nn+Nr + sum(Np)//2 - jump_normals = 0 + Np_ = Ni+Nd+Nn+Nr + jump_points = 0 for nb_p_points in Np: nb_conds = nb_p_points//2 - bdPhi, _, _ = jax.lax.fori_loop(Np_, Np_+nb_conds, bdPhi_pn_body_func, (bdPhi, nb_conds, jump_normals)) - Np_ += nb_conds - jump_normals += nb_p_points + bdPhi, _, _ = jax.lax.fori_loop(Np_, Np_+nb_conds, bdPhi_pn_body_func, (bdPhi, nb_conds, jump_points)) + Np_ += nb_p_points + jump_points += nb_conds # print("Final value of Np_:", Np_) - # print("BdPhi:", bdPhi) + # print("BdPhi:", bdPhi) + + + + # def bdPhi_pn_body_func(i, vals): + # bdPhi, nb_conds, jump_normals = vals + + # support1 = cloud.sorted_local_supports[i+jump_normals-sum(Np)//2] + # support2 = cloud.sorted_local_supports[i+nb_conds+jump_normals-sum(Np)//2] + # support = jnp.concatenate((support1, support2)) + + # grads1 = jnp.nan_to_num(grad_rbf_vec(nodes[i+jump_normals-sum(Np)//2], nodes[support]), neginf=0., posinf=0.) + # grads2 = jnp.nan_to_num(grad_rbf_vec(nodes[i+nb_conds+jump_normals-sum(Np)//2], nodes[support]), neginf=0., posinf=0.) + + # # jax.debug.print("Nodes 1 and 2: \n {} {} \n", nodes[i+jump_normals-sum(Np)//2], nodes[i+nb_conds+jump_normals-sum(Np)//2]) + + # # grads1 = grad_rbf_vec(nodes[i], nodes[support]) + # # grads2 = grad_rbf_vec(nodes[i+nb_conds], nodes[support]) + + # if hasattr(cloud, "sorted_outward_normals"): + # normals1 = cloud.sorted_outward_normals[i-Ni-Nd+jump_normals-sum(Np)//2] + # normals2 = cloud.sorted_outward_normals[i-Ni-Nd+jump_normals-sum(Np)//2+nb_conds] + # # normals1 = cloud.outward_normals[i] + # # normals2 = cloud.outward_normals[i+nb_conds] + # else: + # normals1 = jnp.zeros((DIM,)) + # normals2 = jnp.zeros((DIM,)) + + # diff_grads = jnp.dot(grads1, normals1) - jnp.dot(grads2, -normals2) + # # diff_grads = jnp.zeros_like(jnp.dot(grads1, normals1)) + # # diff_grads = diff_grads / 10. + # # diff_grads = jnp.clip(diff_grads, -2., 2.) + # return bdPhi.at[i-Ni, support].set(diff_grads), nb_conds, jump_normals + + # Np_ = Ni+Nd+Nn+Nr + sum(Np)//2 + # jump_normals = 0 + # for nb_p_points in Np: + # nb_conds = nb_p_points//2 + # bdPhi, _, _ = jax.lax.fori_loop(Np_, Np_+nb_conds, bdPhi_pn_body_func, (bdPhi, nb_conds, jump_normals)) + # Np_ += nb_conds + # jump_normals += nb_p_points + # # print("Final value of Np_:", Np_) + # # print("BdPhi:", bdPhi) @@ -405,33 +461,107 @@ def bdPhi_pn_body_func(i, vals): # print("BdP:\n", bdP) + # print("All boundary nodes:", nodes[16:, :]) + + + + + + + # if len(Np) > 0: + # Np_ = Ni+Nd+Nn+Nr + # for nb_p_points in Np: + + # # node_ids_pd1 = [k for k,v in cloud.node_types.items() if v[:-1] == n_type] + + # nb_conds = nb_p_points//2 + # node_ids_pd1 = jnp.arange(Np_, Np_+nb_conds) + # node_ids_pd2 = jnp.arange(Np_+nb_conds, Np_+nb_p_points) + + # # print("These are the nodes 1:", nodes[node_ids_pd1]) + # # print("These are the nodes 2:", nodes[node_ids_pd2]) + + # for j in range(M): + # monomial_vec = jax.vmap(monomials[j], in_axes=(0,), out_axes=0) + # diff = monomial_vec(nodes[node_ids_pd1]) - monomial_vec(nodes[node_ids_pd2]) + # bdP = bdP.at[node_ids_pd1-Ni, j].set(diff) + # # print("BdP:\n", bdP) + + # Np_ += nb_p_points + + # jax.debug.print("BdP: \n {} \n", bdP) + + # half_Np = sum(Np)//2 + # Np_ = Ni+Nd+Nn+Nr + half_Np + # for nb_p_points in Np: + # nb_conds = nb_p_points//2 + # node_ids_pn1 = range(Np_-half_Np, Np_+nb_conds-half_Np) + # node_ids_pn2 = range(Np_+nb_conds-half_Np, Np_+nb_p_points-half_Np) + + # # print("These are the nodes 1:", nodes[jnp.array(node_ids_pn1)]) + # # print("These are the nodes 2:", nodes[jnp.array(node_ids_pn2)]) + + # normals_pn1 = jnp.stack([cloud.outward_normals[i] for i in node_ids_pn1], axis=0) + # normals_pn2 = jnp.stack([cloud.outward_normals[i] for i in node_ids_pn2], axis=0) + + # node_ids_pn1 = jnp.array(node_ids_pn1) + # node_ids_pn2 = jnp.array(node_ids_pn2) + + # dot_vec = jax.vmap(jnp.dot, in_axes=(0,0), out_axes=0) + # for j in range(M): + # grad_monomial = jax.grad(monomials[j]) + # grad_monomial_vec = jax.vmap(grad_monomial, in_axes=(0,), out_axes=0) + # grads1 = grad_monomial_vec(nodes[node_ids_pn1]) + # grads2 = grad_monomial_vec(nodes[node_ids_pn2]) + # diff_grads = dot_vec(grads1, normals_pn1) - dot_vec(grads2, -normals_pn2) + # bdP = bdP.at[node_ids_pn1-Ni+nb_conds, j].set(diff_grads) + + # Np_ += nb_p_points + # # print("BdP:\n", bdP) + + # # print("Numper of nodes of each type:", Nd, Nn, Nr, Np) + # jax.debug.print("BdP: \n {} \n", bdP) + + + + + + if len(Np) > 0: + jump_points = Ni+Nd+Nn+Nr Np_ = Ni+Nd+Nn+Nr for nb_p_points in Np: # node_ids_pd1 = [k for k,v in cloud.node_types.items() if v[:-1] == n_type] nb_conds = nb_p_points//2 - node_ids_pd1 = jnp.arange(Np_, Np_+nb_conds) - node_ids_pd2 = jnp.arange(Np_+nb_conds, Np_+nb_p_points) + node_ids_pd1 = jnp.arange(jump_points, jump_points+nb_conds) + node_ids_pd2 = jnp.arange(jump_points+nb_conds, jump_points+nb_p_points) + + # print("These are the nodes 1:", nodes[node_ids_pd1]) + # print("These are the nodes 2:", nodes[node_ids_pd2]) for j in range(M): monomial_vec = jax.vmap(monomials[j], in_axes=(0,), out_axes=0) diff = monomial_vec(nodes[node_ids_pd1]) - monomial_vec(nodes[node_ids_pd2]) - bdP = bdP.at[node_ids_pd1-Ni, j].set(diff) + bdP = bdP.at[jnp.arange(Np_,Np_+nb_conds)-Ni, j].set(diff) # print("BdP:\n", bdP) Np_ += nb_conds + jump_points += nb_p_points + + # jax.debug.print("BdP: \n {} \n", bdP) half_Np = sum(Np)//2 - Np_ = Ni+Nd+Nn+Nr + half_Np + jump_points = Ni+Nd+Nn+Nr + Np_ = Ni+Nd+Nn+Nr for nb_p_points in Np: nb_conds = nb_p_points//2 - node_ids_pn1 = range(Np_-half_Np, Np_+nb_conds-half_Np) - node_ids_pn2 = range(Np_+nb_conds-half_Np, Np_+nb_p_points-half_Np) + node_ids_pn1 = range(jump_points, jump_points+nb_conds) + node_ids_pn2 = range(jump_points+nb_conds, jump_points+nb_p_points) - # print("Node ids pn1:", node_ids_pn1) - # print("Node ids pn2:", node_ids_pn2) + # print("These are the nodes 1:", nodes[jnp.array(node_ids_pn1)]) + # print("These are the nodes 2:", nodes[jnp.array(node_ids_pn2)]) normals_pn1 = jnp.stack([cloud.outward_normals[i] for i in node_ids_pn1], axis=0) normals_pn2 = jnp.stack([cloud.outward_normals[i] for i in node_ids_pn2], axis=0) @@ -445,12 +575,21 @@ def bdPhi_pn_body_func(i, vals): grad_monomial_vec = jax.vmap(grad_monomial, in_axes=(0,), out_axes=0) grads1 = grad_monomial_vec(nodes[node_ids_pn1]) grads2 = grad_monomial_vec(nodes[node_ids_pn2]) - diff_grads = dot_vec(grads1, normals_pn1) - dot_vec(grads2, normals_pn2) - bdP = bdP.at[node_ids_pn1-Ni+half_Np, j].set(diff_grads) + diff_grads = dot_vec(grads1, normals_pn1) - dot_vec(grads2, -normals_pn2) + bdP = bdP.at[jnp.arange(Np_,Np_+nb_conds)-Ni+half_Np, j].set(diff_grads) Np_ += nb_conds + jump_points += nb_p_points # print("BdP:\n", bdP) + # print("Numper of nodes of each type:", Nd, Nn, Nr, Np) + # jax.debug.print("BdP again: \n {} \n", bdP) + + + + + + return bdPhi, bdP diff --git a/updec/cloud.py b/updec/cloud.py index afe9893..0494f60 100644 --- a/updec/cloud.py +++ b/updec/cloud.py @@ -27,7 +27,17 @@ def __init__(self, facet_types, support_size="max"): self.facet_precedence = {k:i for i,(k,v) in enumerate(facet_types.items())} ## Facet order of precedence usefull for corner nodes membership ## For each periodic facet type we encounter, we append a letter of the alphabet to it. This is useful for renumbering nodes; and clean for the user. - self.facet_types = {k:v+str(i) for i,(k,v) in enumerate(facet_types.items())} + # self.facet_types = {k:v+str(i) for i,(k,v) in enumerate(facet_types.items()) if v[0]=="p" else k:v} + ## Use for look here + new_facet_types = {} + for i, (k,v) in enumerate(facet_types.items()): + if v[0]=="p": + new_facet_types[k] = v+str(i) + else: + new_facet_types[k] = v + self.facet_types = new_facet_types + + print("Facet types:", self.facet_types) def print_global_indices(self): print(jnp.flip(self.global_indices.T, axis=0)) @@ -336,6 +346,19 @@ def animate(frame): + + + + + + + + + + + + + class SquareCloud(Cloud): def __init__(self, Nx=7, Ny=5, noise_key=None, **kwargs): super().__init__(**kwargs) @@ -430,7 +453,8 @@ def define_node_types(self): self.Nd = 0 self.Nn = 0 self.Nr = 0 - self.Np = {v[:-1]:0 for v in self.facet_types.values()} ## Number of nodes per periodic boundary + self.Np = {v[:-1]:0 for v in self.facet_types.values() if v[0]=="p"} ## Number of nodes per periodic boundary + # print("I got here:", self.Np) for f_id, f_type in self.facet_types.items(): if f_type == "d": @@ -440,9 +464,11 @@ def define_node_types(self): if f_type == "r": self.Nr += len(self.facet_nodes[f_id]) if f_type[0] == "p": + # print("I got here:", f_type, self.facet_nodes[f_id]) # print("Details here:", f_type, self.facet_nodes[f_id]) self.Np[f_type[:-1]] += len(self.facet_nodes[f_id]) ## Get periodic count as a list sorted by keys + # print("I got here:", self.Np) self.Np = [self.Np[k] for k in sorted(self.Np.keys())] self.Ni = self.N - self.Nd - self.Nn - self.Nr - sum(self.Np) @@ -478,6 +504,16 @@ def define_outward_normals(self): + + + + + + + + + + class GmshCloud(Cloud): """ Parses gmsh format 4.0.8, not the newer version """ diff --git a/updec/utils.py b/updec/utils.py index 9a418b6..3dc6e01 100644 --- a/updec/utils.py +++ b/updec/utils.py @@ -28,11 +28,11 @@ ## Euclidian distance def distance(node1, node2): - # diff = node1 - node2 + diff = node1 - node2 # return jnp.sum(diff*diff) ## Squared distance - return jnp.linalg.norm(node1 - node2) ## Carefull: not differentiable at 0 # return periodic_distance_squre(node1, node2, 1., 1.) - + # return jnp.linalg.norm(node1 - node2) ## Carefull: not differentiable at 0 + return jnp.sqrt(diff.T @ diff) def print_line_by_line(dictionary):