Skip to content

Commit

Permalink
Merge pull request #101 from hsalehipour/fix_jax_momentum_exchange
Browse files Browse the repository at this point in the history
Fixed the momentum exchange method in the JAX backend.
  • Loading branch information
hsalehipour authored Jan 17, 2025
2 parents d4a92bc + 4fe0905 commit 4e96b9d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 21 deletions.
27 changes: 12 additions & 15 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def __init__(
mesh_vertices,
)

# Set the flag for auxilary data recovery
self.needs_aux_recovery = True

# find and store the normal vector using indices
self._get_normal_vec(indices)

Expand Down Expand Up @@ -159,15 +156,15 @@ def functional(
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
_f_pre: Any,
_f_post: Any,
):
# Post-streaming values are only modified at missing direction
_f = f_post
_f = _f_post
for l in range(self.velocity_set.q):
# If the mask is missing then take the opposite index
if missing_mask[l] == wp.uint8(1):
_f[l] = f_pre[_opp_indices[l]]
_f[l] = _f_pre[_opp_indices[l]]
return _f

@wp.func
Expand All @@ -177,13 +174,13 @@ def update_bc_auxilary_data(
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
_f_pre: Any,
_f_post: Any,
):
# Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
# f_pre (post-streaming values of the current voxel). We use directions that leave the domain
# for storing this prepared data.
_f = f_post
_f = _f_post
nv = get_normal_vectors(missing_mask)
for l in range(self.velocity_set.q):
if missing_mask[l] == wp.uint8(1):
Expand All @@ -194,19 +191,19 @@ def update_bc_auxilary_data(
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
f_aux = self.compute_dtype(f_0[l, pull_index[0], pull_index[1], pull_index[2]])
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * f_pre[l] + sound_speed * f_aux
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[l] + sound_speed * f_aux
return _f

kernel = self._construct_kernel(functional)

return (functional, update_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, bc_mask, missing_mask],
dim=f_pre.shape[1:],
inputs=[_f_pre, _f_post, bc_mask, missing_mask],
dim=_f_pre.shape[1:],
)
return f_post
return _f_post
14 changes: 9 additions & 5 deletions xlb/operator/force/momentum_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def __init__(

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def jax_implementation(self, f, bc_mask, missing_mask):
def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
"""
Parameters
----------
f : jax.numpy.ndarray
f_0 : jax.numpy.ndarray
The post-collision distribution function at each node in the grid.
f_1 : jax.numpy.ndarray
The buffer field the same size as f_0 (only given as input for consistency with the WARP backened API.)
bc_mask : jax.numpy.ndarray
A grid field with 0 everywhere except for boundary nodes which are designated
by their respective boundary id's.
Expand All @@ -69,7 +71,7 @@ def jax_implementation(self, f, bc_mask, missing_mask):
The force exerted on the solid geometry at each boundary node.
"""
# Give the input post-collision populations, streaming once and apply the BC the find post-stream values.
f_post_collision = f
f_post_collision = f_0
f_post_stream = self.stream(f_post_collision)
f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask)

Expand All @@ -79,11 +81,13 @@ def jax_implementation(self, f, bc_mask, missing_mask):
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))

# the following will return force as a grid-based field with zero everywhere except for boundary nodes.
is_edge = jnp.logical_and(boundary, ~missing_mask[0])
opp = self.velocity_set.opp_indices
phi = f_post_collision[opp] + f_post_stream
phi = jnp.where(jnp.logical_and(boundary, missing_mask), phi, 0.0)
phi = jnp.where(jnp.logical_and(missing_mask, is_edge), phi, 0.0)
force = jnp.tensordot(self.velocity_set.c[:, opp], phi, axes=(-1, 0))
return force
force_net = jnp.sum(force, axis=(i + 1 for i in range(self.velocity_set.d)))
return force_net

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
Expand Down
2 changes: 1 addition & 1 deletion xlb/velocity_set/d3q19.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class D3Q19(VelocitySet):

def __init__(self, precision_policy, backend):
# Construct the velocity vectors and weights
c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T
c = np.array([ci for ci in itertools.product([0, -1, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T
w = np.zeros(19)
for i in range(19):
if np.sum(np.abs(c[:, i])) == 0:
Expand Down

0 comments on commit 4e96b9d

Please sign in to comment.