Skip to content

Commit

Permalink
updates for FullField
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Feb 5, 2025
1 parent 34bf577 commit 9b8e42a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions jaxpm/painting.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,39 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
positions = positions.reshape([-1, 2])
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
floor = jax.tree.map(jnp.floor, positions)
connection = jnp.array([[[0, 0], [1., 0], [0., 1], [1., 1]]])

neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]

neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
jnp.array(mesh.shape))
if weight is not None:
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
kernel = jax.tree.map(
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
kernel, weight)
else:
kernel = jax.tree.map(
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
kernel, weight)

neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(
nc.reshape([-1, 4, 2]).astype('int32'), jnp.array(mesh.shape)
), neighboor_coords)

dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0,
1))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
dnums)
mesh = jax.tree.map(
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 4]), dnums),
mesh, neighboor_coords, kernel)


return mesh


Expand Down
2 changes: 1 addition & 1 deletion jaxpm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,4 @@ def gaussian_smoothing(im, sigma):
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0, 0]

return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real
return jax.tree.map(lambda im : jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real , im)

0 comments on commit 9b8e42a

Please sign in to comment.