Skip to content

Commit

Permalink
Merge pull request #98 from hsalehipour/bc_warp_function_helper
Browse files Browse the repository at this point in the history
Added a helper function for bc related warp functions
  • Loading branch information
hsalehipour authored Jan 7, 2025
2 parents 5340e6c + e32a41d commit b54013b
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 210 deletions.
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC
from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition
from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class DoNothingBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class EquilibriumBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class ExtrapolationOutflowBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
BoundaryCondition,
ImplementationStep,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class FullwayBounceBackBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_grads_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class GradsApproximationBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class HalfwayBounceBackBC(BoundaryCondition):
Expand Down
100 changes: 13 additions & 87 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from xlb.precision_policy import PrecisionPolicy
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
from xlb.operator.boundary_condition import ZouHeBC, HelperFunctionsBC
from xlb.operator.macroscopic import SecondMoment as MomentumFlux


class RegularizedBC(ZouHeBC):
Expand Down Expand Up @@ -64,7 +62,6 @@ def __init__(
indices,
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -127,83 +124,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
_qi = self.velocity_set.qi
# TODO: related to _c_float: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = self.compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional_velocity(
Expand All @@ -219,7 +145,7 @@ def functional_velocity(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
Expand All @@ -228,18 +154,18 @@ def functional_velocity(
_u = -prescribed_value * normals

# calculate rho
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
_rho = fsum / (self.compute_dtype(1.0) + unormal)

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

@wp.func
Expand All @@ -256,23 +182,23 @@ def functional_pressure(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

if self.bc_type == "velocity":
Expand Down
66 changes: 10 additions & 56 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)
from xlb.operator.boundary_condition import HelperFunctionsBC
from xlb.operator.equilibrium import QuadraticEquilibrium
import jax


class ZouHeBC(BoundaryCondition):
Expand Down Expand Up @@ -272,55 +269,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional_velocity(
Expand All @@ -336,10 +290,10 @@ def functional_velocity(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# calculate rho
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
Expand All @@ -355,7 +309,7 @@ def functional_velocity(

# impose non-equilibrium bounceback
_feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, _feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, _feq, _missing_mask)
return _f

@wp.func
Expand All @@ -372,20 +326,20 @@ def functional_pressure(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# Find the value of rho from the missing directions
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask)
return _f

if self.bc_type == "velocity":
Expand Down
Loading

0 comments on commit b54013b

Please sign in to comment.