Skip to content

Commit

Permalink
Add Newtonian material
Browse files Browse the repository at this point in the history
  • Loading branch information
chahak13 committed Jul 15, 2023
1 parent 0f2e190 commit da3fbe2
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 1 deletion.
1 change: 1 addition & 0 deletions diffmpm/materials/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from diffmpm.materials._base import _Material
from diffmpm.materials.simple import SimpleMaterial
from diffmpm.materials.linear_elastic import LinearElastic
from diffmpm.materials.newtonian import Newtonian
115 changes: 114 additions & 1 deletion diffmpm/materials/newtonian.py
Original file line number Diff line number Diff line change
@@ -1 +1,114 @@
#!/usr/bin/env python3
import jax.numpy as jnp
from jax import Array, lax
from jax.typing import ArrayLike

from ._base import _Material


class Newtonian(_Material):
"""Newtonian fluid material model."""

_props = ("density", "bulk_modulus", "dynamic_viscosity")
state_vars = ("pressure",)

def __init__(self, material_properties: dict):
"""Create a Newtonian material.
Parameters
----------
material_properties: dict
Dictionary with material properties. For newtonian
materials, `density`, `bulk_modulus` and `dynamic_viscosity`
are required keys.
"""
self.validate_props(material_properties)
compressibility = 1

if material_properties.get("incompressible", False):
compressibility = 0

self.properties = {
**material_properties,
"compressibility": compressibility,
}

def __repr__(self):
return f"Newtonian(props={self.properties})"

def initialize_state_variables(self, nparticles: int) -> dict:
"""Return initial state variables dictionary.
Parameters
----------
nparticles : int
Number of particles being simulated with this material.
Returns
-------
dict
Dictionary of state variables initialized with values
decided by material type.
"""
state_vars_dict = {var: jnp.zeros((nparticles, 1)) for var in self.state_vars}
return state_vars_dict

def _thermodynamic_pressure(self, volumetric_strain: ArrayLike) -> Array:
return -self.properties["bulk_modulus"] * volumetric_strain

def compute_stress(self, particles):
"""Compute material stress."""
ndim = particles.loc.shape[-1]
if ndim not in {2, 3}:
raise ValueError(f"Cannot compute stress for {ndim}-d Newotonian material.")
volumetric_strain_rate = (
particles.strain_rate[:, 0] + particles.strain_rate[:, 1]
)
particles.state_vars["pressure"] = (
particles.state_vars["pressure"]
.at[:]
.add(
self.properties["compressibility"]
* self._thermodynamic_pressure(particles.dvolumetric_strain)
)
)

volumetric_stress_component = self.properties["compressibility"] * (
-particles.state_vars["pressure"]
- (2 * self.properties["dynamic_viscosity"] * volumetric_strain_rate / 3)
)

stress = jnp.zeros_like(particles.stress)
stress = stress.at[:, 0].set(
volumetric_stress_component
+ 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 0]
)
stress = stress.at[:, 1].set(
volumetric_stress_component
+ 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 1]
)

extra_component_2 = lax.select(
ndim == 3,
2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 2],
jnp.zeros_like(particles.strain_rate[:, 2]),
)
stress = stress.at[:, 2].set(volumetric_stress_component + extra_component_2)

stress = stress.at[:, 3].set(
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 3]
)

component_4 = lax.select(
ndim == 3,
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 4],
jnp.zeros_like(particles.strain_rate[:, 4]),
)
stress = stress.at[:, 4].set(component_4)
component_5 = lax.select(
ndim == 3,
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 5],
jnp.zeros_like(particles.strain_rate[:, 5]),
)
stress = stress.at[:, 5].set(component_5)

return stress
90 changes: 90 additions & 0 deletions tests/test_newtonian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import jax.numpy as jnp
import pytest
from diffmpm.constraint import Constraint
from diffmpm.element import Quadrilateral4Node
from diffmpm.materials import Newtonian
from diffmpm.node import Nodes
from diffmpm.particle import Particles

particles_element_targets = [
(
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
Newtonian(
{
"density": 1000,
"bulk_modulus": 8333333.333333333,
"dynamic_viscosity": 8.9e-4,
}
),
jnp.array([0]),
),
Quadrilateral4Node(
(1, 1),
1,
(4.0, 4.0),
[(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))],
Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))),
),
jnp.array(
[
-52083.3333338896,
-52083.3333355583,
-52083.3333305521,
-0.0000041719,
0,
0,
]
).reshape(1, 6, 1),
),
(
Particles(
jnp.array([[0.5, 0.5]]).reshape(1, 1, 2),
Newtonian(
{
"density": 1000,
"bulk_modulus": 8333333.333333333,
"dynamic_viscosity": 8.9e-4,
"incompressible": True,
}
),
jnp.array([0]),
),
Quadrilateral4Node(
(1, 1),
1,
(4.0, 4.0),
[(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))],
Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))),
),
jnp.array(
[
-0.0000033375,
-0.00000500625,
0,
-0.0000041719,
0,
0,
]
).reshape(1, 6, 1),
),
]


@pytest.mark.parametrize(
"particles, element, target",
particles_element_targets,
)
def test_compute_stress(particles, element, target):
dt = 1
particles.update_natural_coords(element)
if element.constraints:
element.apply_boundary_constraints()
particles.compute_strain(element, dt)
stress = particles.material.compute_stress(particles)
assert jnp.allclose(stress, target)


def test_init():
with pytest.raises(KeyError):
Newtonian({"dynamic_viscosity": 1, "density": 1})

0 comments on commit da3fbe2

Please sign in to comment.