From b132a0e2aa182cd9ff1ac695b9ed2d01cefc3cd6 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 21 Dec 2024 23:14:45 +0100 Subject: [PATCH 1/3] update jaxdecomp version and test gradients --- pyproject.toml | 2 +- tests/test_gradients.py | 115 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 tests/test_gradients.py diff --git a/pyproject.toml b/pyproject.toml index a41096d..a204633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ readme = "README.md" requires-python = ">=3.9" license = { file = "LICENSE" } urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/JaxPM" } -dependencies = ["jax_cosmo", "jax>=0.4.30", "jaxdecomp>=0.2.2"] +dependencies = ["jax_cosmo", "jax>=0.4.35", "jaxdecomp>=0.2.3"] [tool.setuptools] packages = ["jaxpm"] diff --git a/tests/test_gradients.py b/tests/test_gradients.py new file mode 100644 index 0000000..1ac10b5 --- /dev/null +++ b/tests/test_gradients.py @@ -0,0 +1,115 @@ +import pytest +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from helpers import MSE +from jax import numpy as jnp + +from jaxpm.distributed import uniform_particles +from jaxpm.painting import cic_paint, cic_paint_dx +from jaxpm.pm import lpt, make_diffrax_ode +import jax + + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_grad_relative(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, _ = simulation_config + cosmo._workspace = {} + + @jax.jit + @jax.grad + def forward_model(initial_conditions, cosmo): + + # Initial displacement + dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint_dx(solutions.ys[-1, 0]) + + return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + + bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions , cosmo) + bad_ic = forward_model(bad_initial_conditions, cosmo) + + assert jnp.max(best_ic) < 1e-5 + assert jnp.max(bad_ic) > 1e-5 + +@pytest.mark.single_device +@pytest.mark.parametrize("order", [1, 2]) +def test_grad_absolute(simulation_config, initial_conditions, + lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, + cosmo, order): + + mesh_shape, _ = simulation_config + cosmo._workspace = {} + + @jax.jit + @jax.grad + def forward_model(initial_conditions, cosmo): + + # Initial displacement + particles = uniform_particles(mesh_shape) + dx, p, _ = lpt(cosmo, initial_conditions,particles, a=lpt_scale_factor, order=order) + + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=True)) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + y0 = jnp.stack([particles + dx, p]) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + stepsize_controller=controller, + saveat=saveat) + + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) + + return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + + bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions , cosmo) + bad_ic = forward_model(bad_initial_conditions, cosmo) + + assert jnp.max(best_ic) < 1e-5 + assert jnp.max(bad_ic) > 1e-5 + + From a924458f0d09432d36bd2a5d1406d78d082b7662 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 21 Dec 2024 23:24:19 +0100 Subject: [PATCH 2/3] Prepare for DTO tests --- tests/test_gradients.py | 150 ++++++++++++++++------------------------ 1 file changed, 60 insertions(+), 90 deletions(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 1ac10b5..b35d656 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,115 +1,85 @@ +import jax import pytest -from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve +from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint from helpers import MSE from jax import numpy as jnp from jaxpm.distributed import uniform_particles from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.pm import lpt, make_diffrax_ode -import jax @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) -def test_grad_relative(simulation_config, initial_conditions, - lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order): +@pytest.mark.parametrize("absolute_painting", [True, False]) +@pytest.mark.parametrize("adjoint", ['DTO', 'OTD']) +def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, + nbody_from_lpt1, nbody_from_lpt2, cosmo, order, + absolute_painting , adjoint): mesh_shape, _ = simulation_config cosmo._workspace = {} - - @jax.jit - @jax.grad - def forward_model(initial_conditions, cosmo): - - # Initial displacement - dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) - - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) - - solver = Dopri5() - controller = PIDController(rtol=1e-7, - atol=1e-7, - pcoeff=0.4, - icoeff=1, - dcoeff=0) - saveat = SaveAt(t1=True) + if adjoint == 'OTD': + pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") - y0 = jnp.stack([dx, p]) - - solutions = diffeqsolve(ode_fn, - solver, - t0=lpt_scale_factor, - t1=1.0, - dt0=None, - y0=y0, - stepsize_controller=controller, - saveat=saveat) - - final_field = cic_paint_dx(solutions.ys[-1, 0]) - - return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) - - - bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 - best_ic = forward_model(initial_conditions , cosmo) - bad_ic = forward_model(bad_initial_conditions, cosmo) - - assert jnp.max(best_ic) < 1e-5 - assert jnp.max(bad_ic) > 1e-5 + adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) -@pytest.mark.single_device -@pytest.mark.parametrize("order", [1, 2]) -def test_grad_absolute(simulation_config, initial_conditions, - lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order): - - mesh_shape, _ = simulation_config - cosmo._workspace = {} - @jax.jit @jax.grad def forward_model(initial_conditions, cosmo): - # Initial displacement - particles = uniform_particles(mesh_shape) - dx, p, _ = lpt(cosmo, initial_conditions,particles, a=lpt_scale_factor, order=order) - - ode_fn = ODETerm( - make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=True)) - - solver = Dopri5() - controller = PIDController(rtol=1e-7, - atol=1e-7, - pcoeff=0.4, - icoeff=1, - dcoeff=0) - - saveat = SaveAt(t1=True) - - y0 = jnp.stack([particles + dx, p]) - - solutions = diffeqsolve(ode_fn, - solver, - t0=lpt_scale_factor, - t1=1.0, - dt0=None, - y0=y0, - stepsize_controller=controller, - saveat=saveat) - - final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) - - return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) - - - bad_initial_conditions = initial_conditions + jax.random.normal(jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 - best_ic = forward_model(initial_conditions , cosmo) + # Initial displacement + if absolute_painting: + particles = uniform_particles(mesh_shape) + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + a=lpt_scale_factor, + order=order) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) + y0 = jnp.stack([particles + dx, p]) + + else: + dx, p, _ = lpt(cosmo, + initial_conditions, + a=lpt_scale_factor, + order=order) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) + y0 = jnp.stack([dx, p]) + + solver = Dopri5() + controller = PIDController(rtol=1e-7, + atol=1e-7, + pcoeff=0.4, + icoeff=1, + dcoeff=0) + + saveat = SaveAt(t1=True) + + solutions = diffeqsolve(ode_fn, + solver, + t0=lpt_scale_factor, + t1=1.0, + dt0=None, + y0=y0, + adjoint=adjoint, + stepsize_controller=controller, + saveat=saveat) + + if absolute_painting: + final_field = cic_paint(jnp.zeros(mesh_shape), solutions.ys[-1, 0]) + else: + final_field = cic_paint_dx(solutions.ys[-1, 0]) + + return MSE(final_field, + nbody_from_lpt1 if order == 1 else nbody_from_lpt2) + + bad_initial_conditions = initial_conditions + jax.random.normal( + jax.random.PRNGKey(0), initial_conditions.shape) * 0.5 + best_ic = forward_model(initial_conditions, cosmo) bad_ic = forward_model(bad_initial_conditions, cosmo) assert jnp.max(best_ic) < 1e-5 assert jnp.max(bad_ic) > 1e-5 - - From bbacd45dcfa5d3921abb295e51a1cda592d92caf Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sat, 21 Dec 2024 23:27:05 +0100 Subject: [PATCH 3/3] format --- tests/test_gradients.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_gradients.py b/tests/test_gradients.py index b35d656..bb48920 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -1,6 +1,7 @@ import jax import pytest -from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve , RecursiveCheckpointAdjoint, BacksolveAdjoint +from diffrax import (BacksolveAdjoint, Dopri5, ODETerm, PIDController, + RecursiveCheckpointAdjoint, SaveAt, diffeqsolve) from helpers import MSE from jax import numpy as jnp @@ -15,15 +16,16 @@ @pytest.mark.parametrize("adjoint", ['DTO', 'OTD']) def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, cosmo, order, - absolute_painting , adjoint): + absolute_painting, adjoint): mesh_shape, _ = simulation_config cosmo._workspace = {} if adjoint == 'OTD': - pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") + pytest.skip("OTD adjoint not implemented yet (needs PFFT3D JVP)") - adjoint = RecursiveCheckpointAdjoint() if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) + adjoint = RecursiveCheckpointAdjoint( + ) if adjoint == 'DTO' else BacksolveAdjoint(solver=Dopri5()) @jax.jit @jax.grad