Skip to content

Commit

Permalink
Moved omega from an attribute of the collision and stepper operations…
Browse files Browse the repository at this point in the history
… to an input of the methods.
  • Loading branch information
hsalehipour committed Jan 3, 2025
1 parent 5340e6c commit ec707cb
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 51 deletions.
3 changes: 1 addition & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down Expand Up @@ -127,7 +126,7 @@ def bc_profile_jax():
def run(self, num_steps, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,14 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
1 change: 0 additions & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
def setup_stepper(self):
# Create the base stepper
stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def initialize_fields(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -108,7 +107,7 @@ def setup_stepper(self):
def run(self, num_steps, print_interval, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
5 changes: 2 additions & 3 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def define_boundary_indices(self):
walls = np.unique(np.array(walls), axis=-1).tolist()

# Load the mesh (replace with your own mesh)
stl_filename = "../stl-files/DrivAer-Notchback.stl"
stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl"
mesh = trimesh.load_mesh(stl_filename, process=False)
mesh_vertices = mesh.vertices

Expand All @@ -98,7 +98,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -111,7 +110,7 @@ def run(self, num_steps, print_interval, post_process_interval=100):

start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
5 changes: 3 additions & 2 deletions examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(backend, precision_policy, grid_shape, num_steps):
boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)]

# Create stepper
stepper = IncompressibleNavierStokesStepper(omega=1.0, grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")
stepper = IncompressibleNavierStokesStepper(grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")

# Distribute if using JAX backend
if backend == ComputeBackend.JAX:
Expand All @@ -64,11 +64,12 @@ def run(backend, precision_policy, grid_shape, num_steps):
)

# Initialize fields and run simulation
omega = 1.0
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
start_time = time.time()

for i in range(num_steps):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i)
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i)
f_0, f_1 = f_1, f_0
wp.synchronize()

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega):

# Compute collision

compute_collision = BGK(omega=omega)
compute_collision = BGK()

f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = compute_collision(f_orig, f_eq, rho, u)
f_out = compute_collision(f_orig, f_eq, rho, u, omega)

assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq))

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega):
f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_eq = compute_macro(rho, u, f_eq)

compute_collision = BGK(omega=omega)
compute_collision = BGK()
f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega)

f_eq = f_eq.numpy()
f_out = f_out.numpy()
Expand Down
15 changes: 8 additions & 7 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,21 @@ class BGK(Collision):

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u):
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
fneq = f - feq
fout = f - self.compute_dtype(self.omega) * fneq
fout = f - self.compute_dtype(omega) * fneq
return fout

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_w = self.velocity_set.w
_omega = wp.constant(self.compute_dtype(self.omega))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

# Construct the functional
@wp.func
def functional(f: Any, feq: Any, rho: Any, u: Any):
def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
fneq = f - feq
fout = f - _omega * fneq
fout = f - omega * fneq
return fout

# Construct the warp kernel
Expand All @@ -42,6 +41,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -55,7 +55,7 @@ def kernel(
_feq[l] = feq[l, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, rho, u)
_fout = functional(_f, _feq, rho, u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -64,7 +64,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -74,6 +74,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
8 changes: 0 additions & 8 deletions xlb/operator/collision/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,12 @@ class Collision(Operator):
This class defines the collision step for the Lattice Boltzmann Method.
Parameters
----------
omega : float
Relaxation parameter for collision step. Default value is 0.6.
shear : bool
Flag to indicate whether the collision step requires the shear stress.
"""

def __init__(
self,
omega: float,
velocity_set: VelocitySet = None,
precision_policy=None,
compute_backend=None,
):
self.omega = omega
super().__init__(velocity_set, precision_policy, compute_backend)
2 changes: 1 addition & 1 deletion xlb/operator/collision/forced_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
):
assert collision_operator is not None
self.collision_operator = collision_operator
super().__init__(self.collision_operator.omega)
super().__init__()

assert forcing_scheme == "exact_difference", NotImplementedError(f"Force model {forcing_scheme} not implemented!")
assert force_vector.shape[0] == self.velocity_set.d, "Check the dimensions of the input force!"
Expand Down
26 changes: 16 additions & 10 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ class KBC(Collision):

def __init__(
self,
omega: float,
velocity_set: VelocitySet = None,
precision_policy=None,
compute_backend=None,
):
self.momentum_flux = MomentumFlux()
self.epsilon = 1e-32
self.beta = omega * 0.5
self.inv_beta = 1.0 / self.beta

super().__init__(
omega=omega,
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
Expand All @@ -49,6 +45,7 @@ def jax_implementation(
feq: jnp.ndarray,
rho: jnp.ndarray,
u: jnp.ndarray,
omega,
):
"""
KBC collision step for lattice.
Expand All @@ -74,13 +71,17 @@ def jax_implementation(
else:
raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set)))

# Compute required constants based on the input omega (omega is the inverse relaxation time)
beta = omega * 0.5
inv_beta = 1.0 / beta

# Perform collision
delta_h = fneq - delta_s
gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / (
gamma = inv_beta - (2.0 - inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / (
self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq)
)

fout = f - self.beta * (2.0 * delta_s + gamma[None, ...] * delta_h)
fout = f - beta * (2.0 * delta_s + gamma[None, ...] * delta_h)

return fout

Expand Down Expand Up @@ -185,8 +186,6 @@ def _construct_warp(self):
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_epsilon = wp.constant(self.compute_dtype(self.epsilon))
_beta = wp.constant(self.compute_dtype(self.beta))
_inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta))

@wp.func
def decompose_shear_d2q9(fneq: Any):
Expand Down Expand Up @@ -268,6 +267,7 @@ def functional(
feq: Any,
rho: Any,
u: Any,
omega: Any,
):
# Compute shear and delta_s
fneq = f - feq
Expand All @@ -278,6 +278,10 @@ def functional(
shear = decompose_shear_d2q9(fneq)
delta_s = shear * rho / self.compute_dtype(4.0)

# Compute required constants based on the input omega (omega is the inverse relaxation time)
_beta = omega * 0.5
_inv_beta = 1.0 / _beta

# Perform collision
delta_h = fneq - delta_s
two = self.compute_dtype(2.0)
Expand All @@ -296,6 +300,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -314,7 +319,7 @@ def kernel(
_rho = rho[0, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, _rho, _u)
_fout = functional(_f, _feq, _rho, _u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -323,7 +328,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -333,6 +338,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class QuadraticEquilibrium(Equilibrium):
def jax_implementation(self, rho, u):
cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0))
usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True)
w = self.velocity_set.w.reshape((-1,) + (1,) * self.velocity_set.d)
w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1))
feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
return feq

Expand Down
Loading

0 comments on commit ec707cb

Please sign in to comment.