Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Jan 20, 2025
1 parent 20fe25c commit 1f5c619
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 211 deletions.
1 change: 1 addition & 0 deletions jaxpm/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):

return x[tuple(unpad_slice)]


def slice_pad_impl(x, pad_width):
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)

Expand Down
21 changes: 15 additions & 6 deletions jaxpm/kernels.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
import jax

from jaxpm.distributed import autoshmap


Expand All @@ -25,7 +26,8 @@ def fftk(k_array):
def interpolate_power_spectrum(input, k, pk, sharding=None):

def pk_fn(input):
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
return jax.tree.map(
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)

gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
Expand Down Expand Up @@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
return wts
else:
w = kvec[direction]
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
w)
wts = a * 1j
return wts

Expand All @@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values
"""
if fd:
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
kk = sum(
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
for ki in kvec)
else:
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
kk)


def longrange_kernel(kvec, r_split):
Expand Down Expand Up @@ -131,7 +137,10 @@ def cic_compensation(kvec):
wts: array
Complex kernel values
"""
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
kwts = [
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
for i in range(3)
]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts

Expand Down
81 changes: 52 additions & 29 deletions jaxpm/painting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
"""

positions = positions.reshape([-1, 3])
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions)
floor = jax.tree.map(jnp.floor , positions)
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
floor = jax.tree.map(jnp.floor, positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])

neighboor_coords = floor + connection
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords))
kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords))
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1)
, k) , kernel , weight)
kernel = jax.tree.map(
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
kernel, weight)
else:
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight)
kernel = jax.tree.map(
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
kernel, weight)

neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords)
neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
), neighboor_coords)

dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel)
mesh = jax.tree.map(
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
grid_mesh, neighboor_coords, kernel)

return mesh

Expand All @@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):

positions_structure = jax.tree.structure(positions)
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
grid_mesh = jax.tree.unflatten(positions_structure,
jax.tree.leaves(grid_mesh))
positions = positions.reshape((*grid_mesh.shape, 3))

halo_size, halo_extents = get_halo_size(halo_size, sharding)
Expand Down Expand Up @@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions):
# Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions)
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
# Floor the positions to get the base grid cell for each particle
floor = jax.tree.map(jnp.floor , positions)
floor = jax.tree.map(jnp.floor, positions)
# Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords)
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32')
,jnp.array(grid_mesh.shape)) , neighboor_coords)
neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
neighboor_coords)

# Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel)
grid_mesh = jax.tree.map(
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
neighboor_coords, kernel)
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable


Expand Down Expand Up @@ -157,21 +169,28 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_y, _ = halo_size[1]

original_shape = displacements.shape
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements)
particle_mesh = jax.tree.map(
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
displacements)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
else:
weight = weight.flatten()
# Padding is forced to be zero in a single gpu run

a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
jnp.arange(x.shape[1]),
jnp.arange(x.shape[2]),
indexing='ij') , axis=0), particle_mesh)

particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
jnp.arange(x.shape[1]),
jnp.arange(x.shape[2]),
indexing='ij'),
axis=0), particle_mesh)

particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
particle_mesh)
pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
Expand Down Expand Up @@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij')
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij') , axis=0), grid_mesh)

pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij'),
axis=0), grid_mesh)

pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])

Expand Down
36 changes: 22 additions & 14 deletions jaxpm/painting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape):
"""Multilinear enmeshing."""
base_indices = jax.tree.map(jnp.asarray , base_indices)
displacements = jax.tree.map(jnp.asarray , displacements)
base_indices = jax.tree.map(jnp.asarray, base_indices)
displacements = jax.tree.map(jnp.asarray, displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
Expand Down Expand Up @@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_displacements = particle_positions - new_indices * new_cell_size

if base_shape is not None:
new_displacements -= jax.tree.map(jnp.rint ,
new_displacements / grid_length
new_displacements -= jax.tree.map(
jnp.rint, new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected

new_indices = new_indices.astype(base_indices.dtype)
Expand All @@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
if base_shape is not None:
new_indices %= base_shape

weights = 1 - jax.tree.map(jnp.abs , new_displacements)
weights = 1 - jax.tree.map(jnp.abs, new_displacements)

if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
Expand All @@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# scatter
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
mesh_structure = jax.tree.structure(mesh)
val_flat = jax.tree.leaves(val)
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
mesh = jax.tree.map(
lambda m, v, i, f: m.at[i].add(
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
frac)
carry = mesh, offset, cell_size, mesh_shape
return carry, None

Expand All @@ -125,10 +129,10 @@ def scatter(pmid,
val=1.,
offset=0,
cell_size=1.):

ptcl_num, spatial_ndim = pmid.shape
val = jax.tree.map(jnp.asarray , val)
mesh = jax.tree.map(jnp.asarray , mesh)
val = jax.tree.map(jnp.asarray, val)
mesh = jax.tree.map(jnp.asarray, mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
if remainder is not None:
Expand All @@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array):
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape

mesh = jax.tree.map(jnp.asarray , mesh)
mesh = jax.tree.map(jnp.asarray, mesh)

val = jax.tree.map(jnp.asarray , val)
val = jax.tree.map(jnp.asarray, val)

if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: '
Expand Down Expand Up @@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
spatial_shape)

# gather
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
ind_structure = jax.tree.structure(ind)
frac_structure = jax.tree.structure(frac)
mesh_structure = jax.tree.structure(mesh)
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
val += jax.tree.map(
lambda m, i, f:
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
frac)

return carry, val
19 changes: 12 additions & 7 deletions jaxpm/pm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc

Expand All @@ -7,7 +8,7 @@
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
import jax


def pm_forces(positions,
mesh_shape=None,
Expand Down Expand Up @@ -52,10 +53,12 @@ def pm_forces(positions,
kvec, r_split=r_split)
# Computes gravitational forces
forces = [
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
) for i in range(3)]
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
]

forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
forces[0], forces[1], forces[2])

return forces

Expand All @@ -73,8 +76,9 @@ def lpt(cosmo,
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
shape=(*ic.shape, 3)) , initial_conditions)
particles = jax.tree.map(
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)

a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
Expand Down Expand Up @@ -198,7 +202,8 @@ def nbody_ode(a, state, args):
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
dvel)

return nbody_ode

Expand Down
8 changes: 5 additions & 3 deletions jaxpm/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
import jax

__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]


def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Parameters
Expand Down Expand Up @@ -100,11 +101,12 @@ def power_spectrum(mesh,
n_bins = len(kavg) + 2

# FFTs
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh)
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
mmk = meshk * jax.tree.map(
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)

# Sum powers
pk = jnp.empty((len(poles), n_bins))
Expand Down
Loading

0 comments on commit 1f5c619

Please sign in to comment.