diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 0cd17e02dce3..a620eb1f1302 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -537,13 +537,32 @@ class Barrier: barrier_array: BarrierArray offset: ir.Value - def wait_parity(self, parity): + def wait_parity(self, parity, expect_wait=False): + i1 = ir.IntegerType.get_signless(1) index = ir.IndexType.get() - nvgpu.mbarrier_try_wait_parity( - self.barrier_array.value, parity, c(10000000, index), self.offset, + if expect_wait: + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + return + barrier_ptr = self.get_ptr() + barrier_ready = llvm.inline_asm( + i1, + [barrier_ptr, parity], + "mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;", + "=b,l,r", + asm_dialect=0, + has_side_effects=True, ) + should_wait = arith.xori(barrier_ready, c(1, i1)) + should_wait = llvm.intr_expect(should_wait, c(0, i1)) + with ir.InsertionPoint(scf.IfOp(should_wait).then_block): + nvgpu.mbarrier_try_wait_parity( + self.barrier_array.value, parity, c(10000000, index), self.offset, + ) + scf.yield_([]) - def wait(self): + def wait(self, expect_wait=False): i32 = ir.IntegerType.get_signless(32) parities = memref.load(self.barrier_array.phases, []) offset_i32 = arith.index_castui(i32, self.offset) @@ -553,12 +572,31 @@ def wait(self): ) new_parities = arith.xori(parities, bitmask) memref.store(new_parities, self.barrier_array.phases, []) - self.wait_parity(parity) + self.wait_parity(parity, expect_wait=expect_wait) def arrive(self): token_ty = ir.Type.parse("!nvgpu.mbarrier.token") nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset) + def get_ptr(self): + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr<3>") + smem = ir.IntegerAttr.get(i64, 3) + num_barriers = self.barrier_array.num_barriers + mbarrier_ref_ty = ir.MemRefType.get((num_barriers,), i64, memory_space=smem) + mbarrier_ref = builtin.unrealized_conversion_cast( + [mbarrier_ref_ty], [self.barrier_array.value], + ) + mbarrier_ref_ptr = memref.extract_aligned_pointer_as_index(mbarrier_ref) + barrier_arr_ptr = llvm.inttoptr( + ptr_ty, arith.index_cast(i64, mbarrier_ref_ptr), + ) + offset_i32 = arith.index_cast(i32, self.offset) + return llvm.getelementptr( + ptr_ty, barrier_arr_ptr, [offset_i32], [-2147483648], i64, + ) + class Partition: source_bounds: tuple[int, ...]