[Mosaic GPU] Call mbarrier.try_wait only once mbarrier.test_wait fails
The llvm.expect intrinsic puts the loop at the end of the program, allowing
the whole barrier to be compiled to a test_wait + predicated branch that is
immediately followed by the continuation. This seems to make the happy path
a little faster, which can help reduce the barrier overhead for compute-bound
kernels.

PiperOrigin-RevId: 645007019
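
In plain Python terms, the fast path described above boils down to the control flow sketched below. This is only an illustration, not code from the commit: mbarrier_test_wait and mbarrier_try_wait are hypothetical stand-ins for the PTX mbarrier.test_wait inline assembly and the nvgpu.mbarrier_try_wait_parity op used in the diff, and unlikely stands in for the llvm.expect hint.

def unlikely(cond: bool) -> bool:
  # Stand-in for llvm.expect(cond, False): a hint that cond is usually False,
  # so the compiler lays out the guarded code out of the hot path.
  return cond

def wait_parity_sketch(barrier, parity, mbarrier_test_wait, mbarrier_try_wait):
  # Cheap, non-blocking probe of the barrier phase.
  ready = mbarrier_test_wait(barrier, parity)
  if unlikely(not ready):
    # Slow path, expected to be rare: fall back to the blocking wait loop.
    mbarrier_try_wait(barrier, parity, 10_000_000)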
apaszke authored and jax authors committed Jun 20, 2024
1 parent f2956a4 commit 97ce128
Showing 1 changed file with 43 additions and 5 deletions.
jax/experimental/mosaic/gpu/utils.py (43 additions, 5 deletions)
@@ -537,13 +537,32 @@ class Barrier:
   barrier_array: BarrierArray
   offset: ir.Value
 
-  def wait_parity(self, parity):
+  def wait_parity(self, parity, expect_wait=False):
+    i1 = ir.IntegerType.get_signless(1)
     index = ir.IndexType.get()
-    nvgpu.mbarrier_try_wait_parity(
-        self.barrier_array.value, parity, c(10000000, index), self.offset,
+    if expect_wait:
+      nvgpu.mbarrier_try_wait_parity(
+          self.barrier_array.value, parity, c(10000000, index), self.offset,
+      )
+      return
+    barrier_ptr = self.get_ptr()
+    barrier_ready = llvm.inline_asm(
+        i1,
+        [barrier_ptr, parity],
+        "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;",
+        "=b,l,r",
+        asm_dialect=0,
+        has_side_effects=True,
     )
+    should_wait = arith.xori(barrier_ready, c(1, i1))
+    should_wait = llvm.intr_expect(should_wait, c(0, i1))
+    with ir.InsertionPoint(scf.IfOp(should_wait).then_block):
+      nvgpu.mbarrier_try_wait_parity(
+          self.barrier_array.value, parity, c(10000000, index), self.offset,
+      )
+      scf.yield_([])
 
-  def wait(self):
+  def wait(self, expect_wait=False):
     i32 = ir.IntegerType.get_signless(32)
     parities = memref.load(self.barrier_array.phases, [])
     offset_i32 = arith.index_castui(i32, self.offset)
@@ -553,12 +572,31 @@ def wait(self):
     )
     new_parities = arith.xori(parities, bitmask)
     memref.store(new_parities, self.barrier_array.phases, [])
-    self.wait_parity(parity)
+    self.wait_parity(parity, expect_wait=expect_wait)
 
   def arrive(self):
     token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
     nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset)
 
+  def get_ptr(self):
+    i32 = ir.IntegerType.get_signless(32)
+    i64 = ir.IntegerType.get_signless(64)
+    ptr_ty = ir.Type.parse("!llvm.ptr<3>")
+    smem = ir.IntegerAttr.get(i64, 3)
+    num_barriers = self.barrier_array.num_barriers
+    mbarrier_ref_ty = ir.MemRefType.get((num_barriers,), i64, memory_space=smem)
+    mbarrier_ref = builtin.unrealized_conversion_cast(
+        [mbarrier_ref_ty], [self.barrier_array.value],
+    )
+    mbarrier_ref_ptr = memref.extract_aligned_pointer_as_index(mbarrier_ref)
+    barrier_arr_ptr = llvm.inttoptr(
+        ptr_ty, arith.index_cast(i64, mbarrier_ref_ptr),
+    )
+    offset_i32 = arith.index_cast(i32, self.offset)
+    return llvm.getelementptr(
+        ptr_ty, barrier_arr_ptr, [offset_i32], [-2147483648], i64,
+    )
+
 
 class Partition:
   source_bounds: tuple[int, ...]
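A possible call pattern for the new expect_wait flag is sketched below. This is illustrative only and not taken from the commit: barriers is assumed to be a BarrierArray as defined in this file, and when to pass expect_wait=True is left to the kernel author.

barrier = barriers[0]  # assumed: indexing a BarrierArray yields a Barrier
barrier.arrive()

# Default: probe with mbarrier.test_wait and only branch into the blocking
# try_wait loop if the barrier has not completed yet (the expected-fast path).
barrier.wait()

# If the caller knows it will almost certainly have to block (e.g. right after
# issuing a long async copy), it can skip the probe and wait directly.
barrier.wait(expect_wait=True)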
