Skip to content

Commit

Permalink
update callback functions; UH model; Gravity is applied to particles …
Browse files Browse the repository at this point in the history
…by default instead of grid
  • Loading branch information
Retief Lubbe committed Dec 3, 2024
1 parent 451e5ca commit b5c6777
Show file tree
Hide file tree
Showing 13 changed files with 451 additions and 641 deletions.
4 changes: 1 addition & 3 deletions hydraxmpm/config/ip_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def __init__(
project: str = "",
**kwargs: Generic,
):
jax.debug.print(
"Ignore the UserWarning from, the behavior is intended and expected."
)

# total_time: jnp.float32
self.dim = dim

Expand Down
14 changes: 7 additions & 7 deletions hydraxmpm/forces/forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class Forces(eqx.Module):

config: MPMConfig = eqx.field(static=True)

# def apply_on_nodes(
# self: Self,
# particles: Particles = None,
# nodes: Nodes = None,
# step: int = 0,
# ) -> Tuple[Nodes, Self]:
# return nodes, self
def apply_on_nodes(
self: Self,
particles: Particles = None,
nodes: Nodes = None,
step: int = 0,
) -> Tuple[Nodes, Self]:
return nodes, self

def apply_on_particles(
self: Self,
Expand Down
44 changes: 40 additions & 4 deletions hydraxmpm/forces/gravity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple

import chex
import jax
import equinox as eqx
import jax.numpy as jnp
from typing_extensions import Self
Expand All @@ -19,18 +20,21 @@ class Gravity(Forces):
gravity: chex.Array
increment: chex.Array
stop_ramp_step: jnp.int32

particle_gravity: bool = eqx.field(static=True, converter=lambda x: bool(x))

def __init__(
self: Self,
config: MPMConfig,
gravity: chex.Array = None,
increment: chex.Array = None,
stop_ramp_step: jnp.int32 = 0,
particle_gravity = True
) -> Self:
"""Initialize Gravity force on Nodes."""
self.gravity = gravity
self.increment = increment
self.stop_ramp_step = stop_ramp_step
self.particle_gravity = particle_gravity
super().__init__(config)

def apply_on_nodes(
Expand All @@ -41,6 +45,9 @@ def apply_on_nodes(
) -> Tuple[Nodes, Self]:
"""Apply gravity on the nodes."""

if self.particle_gravity:
return nodes, self

if self.increment is not None:
gravity = self.gravity + self.increment * jnp.minimum(
step, self.stop_ramp_step
Expand All @@ -51,12 +58,41 @@ def apply_on_nodes(
moment_gravity = nodes.mass_stack.reshape(-1, 1) * gravity * self.config.dt

new_moment_nt_stack = nodes.moment_nt_stack + moment_gravity

new_moment_stack = nodes.moment_stack + moment_gravity

new_nodes = eqx.tree_at(
lambda state: state.moment_nt_stack,
lambda state: (state.moment_nt_stack),
nodes,
new_moment_nt_stack,
(new_moment_nt_stack),
)

# self is updated if there is a gravity ramp
return new_nodes, self


def apply_on_particles(
self: Self,
particles: Particles = None,
nodes: Nodes = None,
step: int = 0,
) -> Tuple[Particles, Self]:

if not self.particle_gravity:
return particles, self

if self.increment is not None:
gravity = self.gravity + self.increment * jnp.minimum(
step, self.stop_ramp_step
)
else:
gravity = self.gravity

def get_gravitational_force(mass):
return mass*gravity

new_particles = eqx.tree_at(
lambda state: (state.force_stack),
particles,
(jax.vmap(get_gravitational_force)(particles.mass_stack)),
)
return new_particles, self
222 changes: 119 additions & 103 deletions hydraxmpm/forces/nodelevelset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from matplotlib.pyplot import sca
from typing_extensions import Self

from ..config.mpm_config import MPMConfig
Expand All @@ -27,7 +28,7 @@ def __init__(
id_stack: chex.Array = None,
velocity_stack: chex.Array = None,
mu: float = 0.0,
thickness=2,
thickness=4,
):
"""Initialize the rigid particles."""

Expand Down Expand Up @@ -91,113 +92,128 @@ def apply_on_nodes(
nodes: Nodes,
step: int = 0,
):
@partial(jax.vmap, in_axes=(0, 0))
def vmap_selected_nodes(n_id, levelset_vel):
normal = nodes.normal_stack.at[n_id].get()
moment_nt = nodes.moment_nt_stack.at[n_id].get()
mass = nodes.mass_stack.at[n_id].get()

scalar_norm = jnp.linalg.vector_norm(normal)

small_node_cut_off = mass > nodes.small_mass_cutoff

# def give_moment(normal):
# skip the nodes with small mass, due to numerical instability
# vel_nt = moment_nt / mass
# normal = normal/ scalar_norm

vel_nt = jax.lax.cond(
mass > nodes.small_mass_cutoff,
lambda x: x / mass,
lambda x: jnp.zeros_like(x),
moment_nt,
)

# normalize the normals
normal = jax.lax.cond(
mass > nodes.small_mass_cutoff,
lambda x: x / jnp.linalg.vector_norm(x),
lambda x: jnp.zeros_like(x),
normal,
)


# check if the velocity direction of the normal and apply contact
# dot product is 0 when the vectors are orthogonal
# and 1 when they are parallel
# if othogonal no contact is happening
# if parallel the contact is happening
delta_vel = vel_nt - levelset_vel

delta_vel_dot_normal = jnp.dot(delta_vel, normal)

delta_vel_padded = jnp.pad(
delta_vel,
self.config.padding,
mode="constant",
constant_values=0,
)

norm_padded = jnp.pad(
normal,
self.config.padding,
mode="constant",
constant_values=0,
)
delta_vel_cross_normal = jnp.cross(
delta_vel_padded, norm_padded
) # works only for vectors of len 3
norm_delta_vel_cross_normal = jnp.linalg.vector_norm(
delta_vel_cross_normal
)

omega = delta_vel_cross_normal / norm_delta_vel_cross_normal

mu_prime = jnp.minimum(
self.mu, norm_delta_vel_cross_normal / delta_vel_dot_normal
)

normal_cross_omega = jnp.cross(
norm_padded, omega
) # works only for vectors of len 3

tangent = (
(norm_padded + mu_prime * normal_cross_omega)
.at[: self.config.dim]
.get()
)

# sometimes tangent become nan if velocity is zero at initialization
# which causes problems
tangent = jnp.nan_to_num(tangent)

new_nodes_vel_nt = jax.lax.cond(
delta_vel_dot_normal > 0.0,
lambda x: x - delta_vel_dot_normal * tangent,
# lambda x: x - delta_vel_dot_normal*normal, # no friction debug
lambda x: x,
vel_nt,
)

node_moments_nt = new_nodes_vel_nt * mass

return node_moments_nt
@partial(jax.vmap, in_axes=(0, 0, 0, 0, 0,0), out_axes=(0, 0))
def vmap_selected_nodes(n_id, levelset_vel, moment_nt, moment, mass,normal):
# normal = nodes.normal_stack.at[n_id].get()
# scalar_norm = jnp.linalg.vector_norm(normal)

def calculate_velocity(mom):
# check if the velocity direction of the normal and apply contact
# dot product is 0 when the vectors are orthogonal
# and 1 when they are parallel
# if othogonal no contact is happening
# if parallel the contact is happening

# vel = mom / mass

vel = jax.lax.cond(
mass > nodes.small_mass_cutoff,
lambda x: x / mass,
lambda x: jnp.zeros_like(x),
mom,
)

normal_hat = jax.lax.cond(
mass > nodes.small_mass_cutoff,
lambda x: x / jnp.linalg.vector_norm(x),
lambda x: jnp.zeros_like(x),
normal,
)
normal_hat = jnp.nan_to_num(normal_hat)

# normal_hat =normal/scalar_norm

norm_padded = jnp.pad(
normal_hat,
self.config.padding,
mode="constant",
constant_values=0,
)
delta_vel = vel - levelset_vel

delta_vel_padded = jnp.pad(
delta_vel,
self.config.padding,
mode="constant",
constant_values=0,
)

delta_vel_dot_normal = jnp.dot(delta_vel, normal_hat)

delta_vel_cross_normal = jnp.cross(
delta_vel_padded, norm_padded
) # works only for vectors of len

norm_delta_vel_cross_normal = jnp.linalg.vector_norm(
delta_vel_cross_normal
)

omega = delta_vel_cross_normal / norm_delta_vel_cross_normal

mu_prime = jnp.minimum(
self.mu, norm_delta_vel_cross_normal / delta_vel_dot_normal
)

normal_cross_omega = jnp.cross(
norm_padded, omega
) # works only for vectors of len 3

tangent = (
(norm_padded + mu_prime * normal_cross_omega)
.at[: self.config.dim]
.get()
)

# sometimes tangent become nan if velocity is zero at initialization
# which causes problems
tangent = jnp.nan_to_num(tangent)

return jax.lax.cond(
(delta_vel_dot_normal > 0.0),
lambda x: x - delta_vel_dot_normal * tangent,
# lambda x: x
# - delta_vel_dot_normal
# * normal_hat, # uncomment for debug, no friction
lambda x: x,
vel,
)

vel_nt = calculate_velocity(moment_nt)

node_moment_nt = vel_nt * mass


vel = calculate_velocity(moment)
node_moment = vel * mass

return node_moment, node_moment_nt

levelset_moment_stack, levelset_moment_nt_stack = vmap_selected_nodes(
self.id_stack,
self.velocity_stack,
nodes.moment_nt_stack.at[self.id_stack].get(),
nodes.moment_stack.at[self.id_stack].get(),
nodes.mass_stack.at[self.id_stack].get(),
nodes.normal_stack.at[self.id_stack].get(),
)

# return give_moment(normal)
# return jax.lax.cond(
# small_node_cut_off,
# # * (scalar_norm > 0.0),
# give_moment,
# lambda x: jnp.zeros_like(x),
# normal
# )
new_moment_stack = nodes.moment_stack.at[self.id_stack].set(
levelset_moment_stack
)

levelset_moment_stack = vmap_selected_nodes(self.id_stack, self.velocity_stack)
new_moment_nt_stack = nodes.moment_nt_stack.at[self.id_stack].set(
levelset_moment_nt_stack
)

# new_nodes = eqx.tree_at(
# lambda state: (state.moment_nt_stack,state.moment_stack),
# nodes,
# (new_moment_nt_stack,new_moment_stack),
# )

new_nodes = eqx.tree_at(
lambda state: (state.moment_nt_stack),
nodes,
(nodes.moment_nt_stack.at[self.id_stack].set(levelset_moment_stack)),
(new_moment_nt_stack),
)

return new_nodes, self
Loading

0 comments on commit b5c6777

Please sign in to comment.