diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index af9aecd..3957e37 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width): return x[tuple(unpad_slice)] + def slice_pad_impl(x, pad_width): return jax.tree.map(lambda x: jnp.pad(x, pad_width), x) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index ef4b5c5..8fec20d 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,9 +1,10 @@ +import jax import jax.numpy as jnp import numpy as np from jax.lax import FftType from jax.sharding import PartitionSpec as P from jaxdecomp import fftfreq3d, get_output_specs -import jax + from jaxpm.distributed import autoshmap @@ -25,7 +26,8 @@ def fftk(k_array): def interpolate_power_spectrum(input, k, pk, sharding=None): def pk_fn(input): - return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input) + return jax.tree.map( + lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input) gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P() @@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1): return wts else: w = kvec[direction] - a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w) + a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), + w) wts = a * 1j return wts @@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False): Complex kernel values """ if fd: - kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec) + kk = sum( + jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) + for ki in kvec) else: kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec) kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk) - return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk) + return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros, + kk) def longrange_kernel(kvec, r_split): @@ -131,7 +137,10 @@ def cic_compensation(kvec): wts: array Complex kernel values """ - kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)] + kwts = [ + jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) + for i in range(3) + ] wts = (kwts[0] * kwts[1] * kwts[2])**(-2) return wts diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 0210166..b6b22a4 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): """ positions = positions.reshape([-1, 3]) - positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions) - floor = jax.tree.map(jnp.floor , positions) + positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions) + floor = jax.tree.map(jnp.floor, positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) neighboor_coords = floor + connection - kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords)) + kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords)) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] if weight is not None: if jax.tree.all(jax.tree.map(jnp.isscalar, weight)): - kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1) - , k) , kernel , weight) + kernel = jax.tree.map( + lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k), + kernel, weight) else: - kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight) + kernel = jax.tree.map( + lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k), + kernel, weight) - neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords) + neighboor_coords = jax.tree.map( + lambda nc: jnp.mod( + nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape) + ), neighboor_coords) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel) + mesh = jax.tree.map( + lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums), + grid_mesh, neighboor_coords, kernel) return mesh @@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): positions_structure = jax.tree.structure(positions) - grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh)) + grid_mesh = jax.tree.unflatten(positions_structure, + jax.tree.leaves(grid_mesh)) positions = positions.reshape((*grid_mesh.shape, 3)) halo_size, halo_extents = get_halo_size(halo_size, sharding) @@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions): # Reshape positions to a flat list of 3D coordinates positions = positions.reshape([-1, 3]) # Expand dimensions to calculate neighbor coordinates - positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions) + positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions) # Floor the positions to get the base grid cell for each particle - floor = jax.tree.map(jnp.floor , positions) + floor = jax.tree.map(jnp.floor, positions) # Define connections to calculate all neighbor coordinates connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) # Calculate the 8 neighboring coordinates neighboor_coords = floor + connection # Calculate kernel weights based on distance from each neighboring coordinate - kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords) + kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] # Modulo operation to wrap around edges if necessary - neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32') - ,jnp.array(grid_mesh.shape)) , neighboor_coords) + neighboor_coords = jax.tree.map( + lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)), + neighboor_coords) # Ensure grid_mesh shape is as expected # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel - grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel) + grid_mesh = jax.tree.map( + lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh, + neighboor_coords, kernel) return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable @@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): halo_y, _ = halo_size[1] original_shape = displacements.shape - particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements) + particle_mesh = jax.tree.map( + lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype), + displacements) if not jnp.isscalar(weight): if weight.shape != original_shape[:-1]: raise ValueError("Weight shape must match particle shape") @@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): weight = weight.flatten() # Padding is forced to be zero in a single gpu run - a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]), - jnp.arange(x.shape[1]), - jnp.arange(x.shape[2]), - indexing='ij') , axis=0), particle_mesh) - - particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh) - pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) + a, b, c = jax.tree.map( + lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]), + jnp.arange(x.shape[1]), + jnp.arange(x.shape[2]), + indexing='ij'), + axis=0), particle_mesh) + + particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size), + particle_mesh) + pmid = jax.tree.map( + lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, + c) return scatter(pmid.reshape([-1, 3]), displacements.reshape([-1, 3]), particle_mesh, @@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size): jnp.arange(original_shape[1]), jnp.arange(original_shape[2]), indexing='ij') - a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]), - jnp.arange(original_shape[1]), - jnp.arange(original_shape[2]), - indexing='ij') , axis=0), grid_mesh) - - pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) + a, b, c = jax.tree.map( + lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]), + jnp.arange(original_shape[1]), + jnp.arange(original_shape[2]), + indexing='ij'), + axis=0), grid_mesh) + + pmid = jax.tree.map( + lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, + c) pmid = pmid.reshape([-1, 3]) disp = disp.reshape([-1, 3]) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index 09e6ee5..76c063d 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays): def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape): """Multilinear enmeshing.""" - base_indices = jax.tree.map(jnp.asarray , base_indices) - displacements = jax.tree.map(jnp.asarray , displacements) + base_indices = jax.tree.map(jnp.asarray, base_indices) + displacements = jax.tree.map(jnp.asarray, displacements) with jax.experimental.enable_x64(): cell_size = jnp.float64( cell_size) if new_cell_size is not None else jnp.array( @@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_displacements = particle_positions - new_indices * new_cell_size if base_shape is not None: - new_displacements -= jax.tree.map(jnp.rint , - new_displacements / grid_length + new_displacements -= jax.tree.map( + jnp.rint, new_displacements / grid_length ) * grid_length # also abs(new_displacements) < new_cell_size is expected new_indices = new_indices.astype(base_indices.dtype) @@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, if base_shape is not None: new_indices %= base_shape - weights = 1 - jax.tree.map(jnp.abs , new_displacements) + weights = 1 - jax.tree.map(jnp.abs, new_displacements) if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None new_indices = jnp.where(new_indices < 0, new_shape, new_indices) @@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk): ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, spatial_shape) # scatter - ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) + ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)), + ind) mesh_structure = jax.tree.structure(mesh) val_flat = jax.tree.leaves(val) val_tree = jax.tree.unflatten(mesh_structure, val_flat) - mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac) + mesh = jax.tree.map( + lambda m, v, i, f: m.at[i].add( + jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind, + frac) carry = mesh, offset, cell_size, mesh_shape return carry, None @@ -125,10 +129,10 @@ def scatter(pmid, val=1., offset=0, cell_size=1.): - + ptcl_num, spatial_ndim = pmid.shape - val = jax.tree.map(jnp.asarray , val) - mesh = jax.tree.map(jnp.asarray , mesh) + val = jax.tree.map(jnp.asarray, val) + mesh = jax.tree.map(jnp.asarray, mesh) remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) carry = mesh, offset, cell_size, mesh.shape if remainder is not None: @@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array): def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.): ptcl_num, spatial_ndim = pmid.shape - mesh = jax.tree.map(jnp.asarray , mesh) + mesh = jax.tree.map(jnp.asarray, mesh) - val = jax.tree.map(jnp.asarray , val) + val = jax.tree.map(jnp.asarray, val) if mesh.shape[spatial_ndim:] != val.shape[1:]: raise ValueError('channel shape mismatch: ' @@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk): spatial_shape) # gather - ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) + ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)), + ind) frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac) ind_structure = jax.tree.structure(ind) frac_structure = jax.tree.structure(frac) mesh_structure = jax.tree.structure(mesh) - val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac) + val += jax.tree.map( + lambda m, i, f: + (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind, + frac) return carry, val diff --git a/jaxpm/pm.py b/jaxpm/pm.py index a4fffc7..46004ab 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp import jax_cosmo as jc @@ -7,7 +8,7 @@ from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx -import jax + def pm_forces(positions, mesh_shape=None, @@ -52,10 +53,12 @@ def pm_forces(positions, kvec, r_split=r_split) # Computes gravitational forces forces = [ - read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions - ) for i in range(3)] + read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions) + for i in range(3) + ] - forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2]) + forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1), + forces[0], forces[1], forces[2]) return forces @@ -73,8 +76,9 @@ def lpt(cosmo, """ paint_absolute_pos = particles is not None if particles is None: - particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, - shape=(*ic.shape, 3)) , initial_conditions) + particles = jax.tree.map( + lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)), + initial_conditions) a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) @@ -198,7 +202,8 @@ def nbody_ode(a, state, args): # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel) + return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos, + dvel) return nbody_ode diff --git a/jaxpm/utils.py b/jaxpm/utils.py index ab34da2..db33bb2 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,16 +1,17 @@ from functools import partial +import jax import jax.numpy as jnp import numpy as np from jax.scipy.stats import norm from scipy.special import legendre -import jax __all__ = [ 'power_spectrum', 'transfer', 'coherence', 'pktranscoh', 'cross_correlation_coefficients', 'gaussian_smoothing' ] + def _initialize_pk(mesh_shape, box_shape, kedges, los): """ Parameters @@ -100,11 +101,12 @@ def power_spectrum(mesh, n_bins = len(kavg) + 2 # FFTs - meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh) + meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh) if mesh2 is None: mmk = meshk.real**2 + meshk.imag**2 else: - mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2) + mmk = meshk * jax.tree.map( + lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2) # Sum powers pk = jnp.empty((len(poles), n_bins)) diff --git a/tests/conftest.py b/tests/conftest.py index 1ea04c8..868bd7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,18 +174,22 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor): return fpm_mesh + def compare_sharding(sharding1, sharding2): + def get_axis_size(sharding, idx): axis_name = sharding.spec[idx] if axis_name is None: return 1 else: return sharding.mesh.shape[sharding.spec[idx]] + def get_pdims_from_sharding(sharding): - return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))]) + return tuple( + [get_axis_size(sharding, i) for i in range(len(sharding.spec))]) pdims1 = get_pdims_from_sharding(sharding1) pdims2 = get_pdims_from_sharding(sharding2) - pdims1 = pdims1 + (1,) * (3 - len(pdims1)) - pdims2 = pdims2 + (1,) * (3 - len(pdims2)) - return pdims1 == pdims2 \ No newline at end of file + pdims1 = pdims1 + (1, ) * (3 - len(pdims1)) + pdims2 = pdims2 + (1, ) * (3 - len(pdims2)) + return pdims1 == pdims2 diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 9a7bc93..0097e35 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -1,14 +1,15 @@ +import jax import pytest from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from helpers import MSE, MSRE from jax import numpy as jnp - from jaxdecomp import ShardedArray + from jaxpm.distributed import uniform_particles from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.utils import power_spectrum -import jax + _TOLERANCE = 1e-4 _PM_TOLERANCE = 1e-3 @@ -17,7 +18,8 @@ @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): + fpm_lpt1_field, fpm_lpt2_field, cosmo, order, + shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} @@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, - fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): + fpm_lpt1_field, fpm_lpt2_field, cosmo, order, + shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} @@ -77,12 +80,13 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, assert type(dx) == ShardedArray assert type(lpt_field) == ShardedArray + @pytest.mark.single_device @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_nbody_absolute(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order , shardedArrayAPI): + cosmo, order, shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} @@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions, saveat = SaveAt(t1=True) - y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p) + y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]), + particles, dx, p) solutions = diffeqsolve(ode_fn, solver, @@ -135,7 +140,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, if shardedArrayAPI: assert type(dx) == ShardedArray - assert type( solutions.ys[-1, 0]) == ShardedArray + assert type(solutions.ys[-1, 0]) == ShardedArray assert type(final_field) == ShardedArray @@ -144,7 +149,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, @pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_nbody_relative(simulation_config, initial_conditions, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, - cosmo, order , shardedArrayAPI): + cosmo, order, shardedArrayAPI): mesh_shape, box_shape = simulation_config cosmo._workspace = {} @@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions, # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) solver = Dopri5() controller = PIDController(rtol=1e-9, @@ -167,7 +171,7 @@ def test_nbody_relative(simulation_config, initial_conditions, saveat = SaveAt(t1=True) - y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p) + y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) solutions = diffeqsolve(ode_fn, solver, @@ -192,5 +196,5 @@ def test_nbody_relative(simulation_config, initial_conditions, if shardedArrayAPI: assert type(dx) == ShardedArray - assert type( solutions.ys[-1, 0]) == ShardedArray + assert type(solutions.ys[-1, 0]) == ShardedArray assert type(final_field) == ShardedArray diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index 8054408..a93cf23 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -1,9 +1,12 @@ -from conftest import initialize_distributed , compare_sharding +from conftest import compare_sharding, initialize_distributed initialize_distributed() # ignore : E402 +from functools import partial # noqa : E402 + import jax # noqa : E402 import jax.numpy as jnp # noqa : E402 +import jax_cosmo as jc # noqa : E402 import pytest # noqa : E402 from diffrax import SaveAt # noqa : E402 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve @@ -12,13 +15,13 @@ from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P # noqa : E402 -from jaxpm.pm import pm_forces # noqa : E402 -from jaxpm.distributed import uniform_particles , fft3d # noqa : E402 +from jaxdecomp import ShardedArray # noqa : E402 + +from jaxpm.distributed import fft3d, uniform_particles # noqa : E402 from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 +from jaxpm.pm import pm_forces # noqa : E402 from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 -from jaxdecomp import ShardedArray # noqa : E402 -from functools import partial # noqa : E402 -import jax_cosmo as jc # noqa : E402 + _TOLERANCE = 3.0 # 🙃🙃 @@ -27,7 +30,7 @@ @pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False]) def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, - absolute_painting,shardedArrayAPI): + absolute_painting, shardedArrayAPI): mesh_shape, box_shape = simulation_config # SINGLE DEVICE RUN @@ -42,18 +45,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, if shardedArrayAPI: particles = ShardedArray(particles) # Initial displacement - dx, p, _ = lpt(cosmo, - ic, - particles, - a=0.1, - order=order) + dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order) ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) - y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) + y0 = jax.tree.map( + lambda particles, dx, p: jnp.stack([particles + dx, p]), particles, + dx, p) else: dx, p, _ = lpt(cosmo, ic, a=0.1, order=order) - ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) - y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) + ode_fn = ODETerm(make_diffrax_ode(mesh_shape, + paint_absolute_pos=False)) + y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -87,13 +88,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding = NamedSharding(mesh, P('x', 'y')) halo_size = mesh_shape[0] // 2 - ic = lax.with_sharding_constraint(initial_conditions, - sharding) + ic = lax.with_sharding_constraint(initial_conditions, sharding) print(f"sharded initial conditions {ic.sharding}") if shardedArrayAPI: - ic = ShardedArray(ic , sharding) + ic = ShardedArray(ic, sharding) cosmo._workspace = {} if absolute_painting: @@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode( - mesh_shape, + make_diffrax_ode(mesh_shape, halo_size=halo_size, sharding=sharding)) - y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) + y0 = jax.tree.map( + lambda particles, dx, p: jnp.stack([particles + dx, p]), particles, + dx, p) else: dx, p, _ = lpt(cosmo, ic, @@ -124,12 +125,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, halo_size=halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode( - mesh_shape, + make_diffrax_ode(mesh_shape, paint_absolute_pos=False, halo_size=halo_size, sharding=sharding)) - y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) + y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -161,7 +161,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, sharding=sharding) multi_device_final_field_g = process_allgather(multi_device_final_field, - tiled=True) + tiled=True) single_device_final_field_arr, = jax.tree.leaves(single_device_final_field) multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g) @@ -170,18 +170,19 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, if shardedArrayAPI: assert type(multi_device_final_field) == ShardedArray - assert compare_sharding(multi_device_final_field.sharding , sharding) - assert compare_sharding(multi_device_final_field.initial_sharding , sharding) + assert compare_sharding(multi_device_final_field.sharding, sharding) + assert compare_sharding(multi_device_final_field.initial_sharding, + sharding) assert mse < _TOLERANCE - @pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("absolute_painting", [True, False]) -def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2, - absolute_painting): +def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, + order, nbody_from_lpt1, nbody_from_lpt2, + absolute_painting): mesh_shape, box_shape = simulation_config # SINGLE DEVICE RUN @@ -196,55 +197,53 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde print(f"sharded initial conditions {initial_conditions.sharding}") - - initial_conditions = ShardedArray(initial_conditions , sharding) + initial_conditions = ShardedArray(initial_conditions, sharding) cosmo._workspace = {} @jax.jit - def forward_model(initial_conditions , cosmo): - + def forward_model(initial_conditions, cosmo): if absolute_painting: particles = uniform_particles(mesh_shape, sharding=sharding) particles = ShardedArray(particles, sharding) # Initial displacement dx, p, _ = lpt(cosmo, - initial_conditions, - particles, - a=0.1, - order=order, - halo_size=halo_size, - sharding=sharding) + initial_conditions, + particles, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode( - mesh_shape, - halo_size=halo_size, - sharding=sharding)) + make_diffrax_ode(mesh_shape, + halo_size=halo_size, + sharding=sharding)) - y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) + y0 = jax.tree.map( + lambda particles, dx, p: jnp.stack([particles + dx, p]), + particles, dx, p) else: dx, p, _ = lpt(cosmo, - initial_conditions, - a=0.1, - order=order, - halo_size=halo_size, - sharding=sharding) + initial_conditions, + a=0.1, + order=order, + halo_size=halo_size, + sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode( - mesh_shape, - paint_absolute_pos=False, - halo_size=halo_size, - sharding=sharding)) - y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) + make_diffrax_ode(mesh_shape, + paint_absolute_pos=False, + halo_size=halo_size, + sharding=sharding)) + y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) solver = Dopri5() controller = PIDController(rtol=1e-8, - atol=1e-8, - pcoeff=0.4, - icoeff=1, - dcoeff=0) + atol=1e-8, + pcoeff=0.4, + icoeff=1, + dcoeff=0) saveat = SaveAt(t1=True) @@ -260,9 +259,9 @@ def forward_model(initial_conditions , cosmo): if absolute_painting: multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), - solutions.ys[-1, 0], - halo_size=halo_size, - sharding=sharding) + solutions.ys[-1, 0], + halo_size=halo_size, + sharding=sharding) else: multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], halo_size=halo_size, @@ -271,30 +270,31 @@ def forward_model(initial_conditions , cosmo): return multi_device_final_field @jax.jit - def model(initial_conditions , cosmo): + def model(initial_conditions, cosmo): - final_field = forward_model(initial_conditions , cosmo) + final_field = forward_model(initial_conditions, cosmo) final_field, = jax.tree.leaves(final_field) - + return MSE(final_field, nbody_from_lpt1 if order == 1 else nbody_from_lpt2) - - obs_val = model(initial_conditions , cosmo) - shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5 + obs_val = model(initial_conditions, cosmo) - good_grads = jax.grad(model)(initial_conditions , cosmo) - off_grads = jax.grad(model)(shifted_initial_conditions , cosmo) + shifted_initial_conditions = initial_conditions + jax.random.normal( + jax.random.key(42), initial_conditions.shape) * 5 - assert compare_sharding(good_grads.sharding , initial_conditions.sharding) - assert compare_sharding(off_grads.sharding , initial_conditions.sharding) + good_grads = jax.grad(model)(initial_conditions, cosmo) + off_grads = jax.grad(model)(shifted_initial_conditions, cosmo) + + assert compare_sharding(good_grads.sharding, initial_conditions.sharding) + assert compare_sharding(off_grads.sharding, initial_conditions.sharding) @pytest.mark.distributed @pytest.mark.parametrize("absolute_painting", [True, False]) -def test_fwd_rev_gradients(cosmo,absolute_painting): +def test_fwd_rev_gradients(cosmo, absolute_painting): - mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0) + mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) # SINGLE DEVICE RUN cosmo._workspace = {} @@ -308,34 +308,54 @@ def test_fwd_rev_gradients(cosmo,absolute_painting): sharding) print(f"sharded initial conditions {initial_conditions.sharding}") - initial_conditions = ShardedArray(initial_conditions , sharding) + initial_conditions = ShardedArray(initial_conditions, sharding) cosmo._workspace = {} - @partial(jax.jit , static_argnums=(3,4 , 5)) - def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None): - + @partial(jax.jit, static_argnums=(3, 4, 5)) + def compute_forces(initial_conditions, + cosmo, + particles=None, + a=0.5, + halo_size=0, + sharding=None): + paint_absolute_pos = particles is not None if particles is None: - particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, - shape=(*ic.shape, 3)) , initial_conditions) + particles = jax.tree.map( + lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)), + initial_conditions) a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) delta_k = fft3d(initial_conditions) initial_force = pm_forces(particles, - delta=delta_k, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding) - - return initial_force[...,0] - - particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None - forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) - back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) - fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) - - assert compare_sharding(forces.sharding , initial_conditions.sharding) - assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding) - assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding) + delta=delta_k, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) + + return initial_force[..., 0] + + particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding), + sharding) if absolute_painting else None + forces = compute_forces(initial_conditions, + cosmo, + particles=particles, + halo_size=halo_size, + sharding=sharding) + back_gradient = jax.jacrev(compute_forces)(initial_conditions, + cosmo, + particles=particles, + halo_size=halo_size, + sharding=sharding) + fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions, + cosmo, + particles=particles, + halo_size=halo_size, + sharding=sharding) + + assert compare_sharding(forces.sharding, initial_conditions.sharding) + assert compare_sharding(back_gradient[0, 0, 0, ...].sharding, + initial_conditions.sharding) + assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding) diff --git a/tests/test_sharded_array.py b/tests/test_sharded_array.py index d73e525..93d23b6 100644 --- a/tests/test_sharded_array.py +++ b/tests/test_sharded_array.py @@ -1,31 +1,31 @@ import os + #os.environ["JAX_PLATFORM_NAME"] = "cpu" #os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" -import os os.environ["EQX_ON_ERROR"] = "nan" +from functools import partial + import jax import jax.numpy as jnp import jax_cosmo as jc +from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, + diffeqsolve) from jax.debug import visualize_array_sharding +from jax.experimental.mesh_utils import create_device_mesh +from jax.experimental.multihost_utils import process_allgather +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from jaxpm.distributed import uniform_particles from jaxpm.kernels import interpolate_power_spectrum -from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read +from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx from jaxpm.pm import linear_field, lpt, make_diffrax_ode -from functools import partial -from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve -from jaxpm.distributed import uniform_particles #assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices" - -from jax.experimental.mesh_utils import create_device_mesh -from jax.experimental.multihost_utils import process_allgather -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P - all_gather = partial(process_allgather, tiled=False) pdims = (2, 4) @@ -34,8 +34,8 @@ #sharding = NamedSharding(mesh, P('x', 'y')) sharding = None - from typing import NamedTuple + from jaxdecomp import ShardedArray mesh_shape = 64 @@ -43,42 +43,42 @@ halo_size = 2 snapshots = (0.5, 1.0) + class Params(NamedTuple): omega_c: float sigma8: float - initial_conditions : jnp.ndarray + initial_conditions: jnp.ndarray + -mesh_shape = (mesh_shape,) * 3 -box_size = (box_size,) * 3 +mesh_shape = (mesh_shape, ) * 3 +box_size = (box_size, ) * 3 omega_c = 0.25 sigma8 = 0.8 # Create a small function to generate the matter power spectrum k = jnp.logspace(-4, 1, 128) -pk = jc.power.linear_matter_power( - jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) +pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), + k) pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) initial_conditions = linear_field(mesh_shape, - box_size, - pk_fn, - seed=jax.random.PRNGKey(0), - sharding=sharding) - + box_size, + pk_fn, + seed=jax.random.PRNGKey(0), + sharding=sharding) #initial_conditions = ShardedArray(initial_conditions, sharding) params = Params(omega_c, sigma8, initial_conditions) - -@partial(jax.jit , static_argnums=(1 , 2,3,4 )) -def forward_model(params , mesh_shape,box_size,halo_size , snapshots): +@partial(jax.jit, static_argnums=(1, 2, 3, 4)) +def forward_model(params, mesh_shape, box_size, halo_size, snapshots): # Create initial conditions cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8) - particles = uniform_particles(mesh_shape , sharding) + particles = uniform_particles(mesh_shape, sharding) ic_structure = jax.tree.structure(params.initial_conditions) - particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles)) + particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles)) # Initial displacement dx, p, f = lpt(cosmo, params.initial_conditions, @@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots): # Evolve the simulation forward ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding)) + make_diffrax_ode(mesh_shape, + paint_absolute_pos=True, + halo_size=halo_size, + sharding=sharding)) solver = LeapfrogMidpoint() - y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p) + y0 = jax.tree.map( + lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0), + particles, dx, p) print(f"y0 structure: {jax.tree.structure(y0)}") stepsize_controller = ConstantStepSize() @@ -107,18 +112,17 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots): saveat=SaveAt(ts=snapshots), stepsize_controller=stepsize_controller) ode_solutions = [sol[0] for sol in res.ys] - - ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1]) - return particles + dx , ode_field + ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), + ode_solutions[-1]) + return particles + dx, ode_field ode_field = cic_paint_dx(ode_solutions[-1]) - return dx , ode_field - - + return dx, ode_field -lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots) +lpt_particles, ode_field = forward_model(params, mesh_shape, box_size, + halo_size, snapshots) import matplotlib.pyplot as plt @@ -127,11 +131,11 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots): plt.figure(figsize=(12, 6)) plt.subplot(121) -plt.imshow(lpt_field.sum(axis=0) , cmap='magma') +plt.imshow(lpt_field.sum(axis=0), cmap='magma') plt.colorbar() plt.title('LPT field') plt.subplot(122) -plt.imshow(ode_field.sum(axis=0) , cmap='magma') +plt.imshow(ode_field.sum(axis=0), cmap='magma') plt.colorbar() plt.title('ODE field') plt.show() @@ -144,4 +148,4 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots): #field = ShardedArray(field, sharding) # # -#cic_read_dx(field , particles ) \ No newline at end of file +#cic_read_dx(field , particles )