Skip to content

Commit

Permalink
Merge pull request #100 from hsalehipour/omega_stepper_input
Browse files Browse the repository at this point in the history
Moved omega from an attribute of the collision to the input of its callable
  • Loading branch information
hsalehipour authored Jan 10, 2025
2 parents b54013b + c87b3a5 commit d4a92bc
Show file tree
Hide file tree
Showing 15 changed files with 56 additions and 59 deletions.
4 changes: 2 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
self.backend = backend
self.precision_policy = precision_policy
self.omega = omega

self.boundary_conditions = []
self.u_max = 0.04

Expand Down Expand Up @@ -75,7 +76,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down Expand Up @@ -127,7 +127,7 @@ def bc_profile_jax():
def run(self, num_steps, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,14 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)

def run(self, num_steps, post_process_interval=100):
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
Expand Down
1 change: 0 additions & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre
def setup_stepper(self):
# Create the base stepper
stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def initialize_fields(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -108,7 +107,7 @@ def setup_stepper(self):
def run(self, num_steps, print_interval, post_process_interval=100):
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def setup_boundary_conditions(self):

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
omega=self.omega,
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="KBC",
Expand All @@ -111,7 +110,7 @@ def run(self, num_steps, print_interval, post_process_interval=100):

start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if (i + 1) % print_interval == 0:
Expand Down
5 changes: 3 additions & 2 deletions examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(backend, precision_policy, grid_shape, num_steps):
boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)]

# Create stepper
stepper = IncompressibleNavierStokesStepper(omega=1.0, grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")
stepper = IncompressibleNavierStokesStepper(grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK")

# Distribute if using JAX backend
if backend == ComputeBackend.JAX:
Expand All @@ -64,11 +64,12 @@ def run(backend, precision_policy, grid_shape, num_steps):
)

# Initialize fields and run simulation
omega = 1.0
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
start_time = time.time()

for i in range(num_steps):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i)
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i)
f_0, f_1 = f_1, f_0
wp.synchronize()

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega):

# Compute collision

compute_collision = BGK(omega=omega)
compute_collision = BGK()

f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = compute_collision(f_orig, f_eq, rho, u)
f_out = compute_collision(f_orig, f_eq, rho, u, omega)

assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq))

Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega):
f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_eq = compute_macro(rho, u, f_eq)

compute_collision = BGK(omega=omega)
compute_collision = BGK()
f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)

f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u)
f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega)

f_eq = f_eq.numpy()
f_out = f_out.numpy()
Expand Down
2 changes: 1 addition & 1 deletion xlb/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def check_backend_support():
elif len(gpus) == 1:
print("Single-GPU support is available: 1 GPU detected.")

if jax.devices()[0].platform == "tpu":
elif jax.devices()[0].platform == "tpu":
tpus = jax.devices("tpu")
if len(tpus) > 1:
print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus)))
Expand Down
15 changes: 8 additions & 7 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,21 @@ class BGK(Collision):

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u):
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
fneq = f - feq
fout = f - self.compute_dtype(self.omega) * fneq
fout = f - self.compute_dtype(omega) * fneq
return fout

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_w = self.velocity_set.w
_omega = wp.constant(self.compute_dtype(self.omega))
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)

# Construct the functional
@wp.func
def functional(f: Any, feq: Any, rho: Any, u: Any):
def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
fneq = f - feq
fout = f - _omega * fneq
fout = f - self.compute_dtype(omega) * fneq
return fout

# Construct the warp kernel
Expand All @@ -42,6 +41,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -55,7 +55,7 @@ def kernel(
_feq[l] = feq[l, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, rho, u)
_fout = functional(_f, _feq, rho, u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -64,7 +64,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -74,6 +74,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
9 changes: 0 additions & 9 deletions xlb/operator/collision/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,12 @@ class Collision(Operator):
Base class for collision operators.
This class defines the collision step for the Lattice Boltzmann Method.
Parameters
----------
omega : float
Relaxation parameter for collision step. Default value is 0.6.
shear : bool
Flag to indicate whether the collision step requires the shear stress.
"""

def __init__(
self,
omega: float,
velocity_set: VelocitySet = None,
precision_policy=None,
compute_backend=None,
):
self.omega = omega
super().__init__(velocity_set, precision_policy, compute_backend)
16 changes: 9 additions & 7 deletions xlb/operator/collision/forced_collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
):
assert collision_operator is not None
self.collision_operator = collision_operator
super().__init__(self.collision_operator.omega)
super().__init__()

assert forcing_scheme == "exact_difference", NotImplementedError(f"Force model {forcing_scheme} not implemented!")
assert force_vector.shape[0] == self.velocity_set.d, "Check the dimensions of the input force!"
Expand All @@ -33,8 +33,8 @@ def __init__(

@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u):
fout = self.collision_operator(f, feq, rho, u)
def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
fout = self.collision_operator(f, feq, rho, u, omega)
fout = self.forcing_operator(fout, feq, rho, u)
return fout

Expand All @@ -45,8 +45,8 @@ def _construct_warp(self):

# Construct the functional
@wp.func
def functional(f: Any, feq: Any, rho: Any, u: Any):
fout = self.collision_operator.warp_functional(f, feq, rho, u)
def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
fout = self.collision_operator.warp_functional(f, feq, rho, u, omega)
fout = self.forcing_operator.warp_functional(fout, feq, rho, u)
return fout

Expand All @@ -58,6 +58,7 @@ def kernel(
fout: wp.array4d(dtype=Any),
rho: wp.array4d(dtype=Any),
u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -76,7 +77,7 @@ def kernel(
_rho = rho[0, index[0], index[1], index[2]]

# Compute the collision
_fout = functional(_f, _feq, _rho, _u)
_fout = functional(_f, _feq, _rho, _u, omega)

# Write the result
for l in range(self.velocity_set.q):
Expand All @@ -85,7 +86,7 @@ def kernel(
return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f, feq, fout, rho, u):
def warp_implementation(self, f, feq, fout, rho, u, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -95,6 +96,7 @@ def warp_implementation(self, f, feq, fout, rho, u):
fout,
rho,
u,
omega,
],
dim=f.shape[1:],
)
Expand Down
Loading

0 comments on commit d4a92bc

Please sign in to comment.