Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update loops to use jax calls and JIT some functions #27

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/2d/uniaxial_nodal_forces/mpm-nodal-forces.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type = "generator"
nelements = [3, 1]
element_length = [0.1, 0.1]
particle_element_ids = [0]
element = "Quadrilateral4Node"
element = "Quad4N"
entity_sets = "entity_sets.json"

[[mesh.constraints]]
Expand All @@ -46,7 +46,7 @@ id = 0
density = 1000
poisson_ratio = 0
youngs_modulus = 1000000
type = "LinearElastic"
type = "linear_elastic"

[[particles]]
file = "particles-2d-nodal-force.json"
Expand Down
9 changes: 8 additions & 1 deletion benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from pathlib import Path

import jax

jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down Expand Up @@ -32,3 +35,7 @@ def test_benchmarks():
result = jnp.load("results/uniaxial-nodal-forces/particles_0990.npz")
assert jnp.round(result["stress"][0, :, 0].min() - 0.9999990078443788, 5) == 0.0
assert jnp.round(result["stress"][0, :, 0].max() - 0.9999990292713694, 5) == 0.0


if __name__ == "__main__":
test_benchmarks()
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type = "generator"
nelements = [3, 1]
element_length = [0.1, 0.1]
particle_element_ids = [0]
element = "Quadrilateral4Node"
element = "Quad4N"
entity_sets = "entity_sets.json"

[[mesh.constraints]]
Expand All @@ -46,7 +46,7 @@ id = 0
density = 1000
poisson_ratio = 0
youngs_modulus = 1000000
type = "LinearElastic"
type = "linear_elastic"

[[particles]]
file = "particles-2d-particle-traction.json"
Expand Down
9 changes: 8 additions & 1 deletion benchmarks/2d/uniaxial_particle_traction/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from pathlib import Path

import jax

jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down Expand Up @@ -31,3 +34,7 @@ def test_benchmarks():
result = jnp.load("results/uniaxial-particle-traction/particles_0990.npz")
assert jnp.round(result["stress"][0, :, 0].min() - 0.750002924022295, 5) == 0.0
assert jnp.round(result["stress"][0, :, 0].max() - 0.9999997782938734, 5) == 0.0


if __name__ == "__main__":
test_benchmarks()
4 changes: 2 additions & 2 deletions benchmarks/2d/uniaxial_stress/mpm-uniaxial-stress.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type = "generator"
nelements = [1, 1]
element_length = [1, 1]
particle_element_ids = [0]
element = "Quadrilateral4Node"
element = "Quad4N"
entity_sets = "entity_sets.json"

[[mesh.constraints]]
Expand All @@ -47,7 +47,7 @@ id = 0
density = 1
poisson_ratio = 0
youngs_modulus = 1000
type = "LinearElastic"
type = "linear_elastic"

[[particles]]
file = "particles-2d-uniaxial-stress.json"
Expand Down
10 changes: 9 additions & 1 deletion benchmarks/2d/uniaxial_stress/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os
from pathlib import Path

import jax

jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand All @@ -19,3 +23,7 @@ def test_benchmarks():

assert jnp.round(result["stress"][0, :, 1].max() - true_stress_yy, 8) == 0.0
assert jnp.round(result["stress"][0, :, 0].max() - true_stress_xx, 8) == 0.0


if __name__ == "__main__":
test_benchmarks()
44 changes: 1 addition & 43 deletions diffmpm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,5 @@
from importlib.metadata import version
from pathlib import Path

import diffmpm.writers as writers
from diffmpm.io import Config
from diffmpm.solver import MPMExplicit

__all__ = ["MPM", "__version__"]
__all__ = ["__version__"]

__version__ = version("diffmpm")


class MPM:
def __init__(self, filepath):
self._config = Config(filepath)
mesh = self._config.parse()
out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
self._config.parsed_config["meta"]["title"],
)

write_format = self._config.parsed_config["output"].get("format", None)
if write_format is None or write_format.lower() == "none":
writer_func = None
elif write_format == "npz":
writer_func = writers.NPZWriter().write
else:
raise ValueError(f"Specified output format not supported: {write_format}")

if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
self.solver = MPMExplicit(
mesh,
self._config.parsed_config["meta"]["dt"],
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
sim_steps=self._config.parsed_config["meta"]["nsteps"],
out_steps=self._config.parsed_config["output"]["step_frequency"],
out_dir=out_dir,
writer_func=writer_func,
)
else:
raise ValueError("Wrong type of solver specified.")

def solve(self):
"""Solve the MPM simulation using JIT solver."""
arrays = self.solver.solve_jit(
self._config.parsed_config["external_loading"]["gravity"],
)
return arrays
2 changes: 1 addition & 1 deletion diffmpm/cli/mpm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click

from diffmpm import MPM
from diffmpm.mpm import MPM


@click.command() # type: ignore
Expand Down
39 changes: 32 additions & 7 deletions diffmpm/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,44 @@ def tree_unflatten(cls, aux_data, children):
del children
return cls(*aux_data)

def apply(self, obj, ids):
def apply_vel(self, vel, ids):
"""Apply constraint values to the passed object.

Parameters
----------
obj : diffmpm.node.Nodes, diffmpm.particle.Particles
obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be vel in the docstring or obj in the function, same issue with all the functions.

Object on which the constraint is applied
ids : array_like
The indices of the container `obj` on which the constraint
will be applied.
"""
obj.velocity = obj.velocity.at[ids, :, self.dir].set(self.velocity)
obj.momentum = obj.momentum.at[ids, :, self.dir].set(
obj.mass[ids, :, 0] * self.velocity
)
obj.acceleration = obj.acceleration.at[ids, :, self.dir].set(0)
velocity = vel.at[ids, :, self.dir].set(self.velocity)
return velocity

def apply_mom(self, mom, mass, ids):
"""Apply constraint values to the passed object.

Parameters
----------
obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState
Object on which the constraint is applied
ids : array_like
The indices of the container `obj` on which the constraint
will be applied.
"""
momentum = mom.at[ids, :, self.dir].set(mass[ids, :, 0] * self.velocity)
return momentum

def apply_acc(self, acc, ids):
"""Apply constraint values to the passed object.

Parameters
----------
obj : diffmpm.node.Nodes, diffmpm.particle._ParticlesState
Object on which the constraint is applied
ids : array_like
The indices of the container `obj` on which the constraint
will be applied.
"""
acceleration = acc.at[ids, :, self.dir].set(0)
return acceleration
Loading
Loading