diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/example_allgather_gemm_overlapped.py index cebf58ed1..61aff0dd9 100644 --- a/examples/distributed/example_allgather_gemm_overlapped.py +++ b/examples/distributed/example_allgather_gemm_overlapped.py @@ -93,7 +93,7 @@ def main( tid = T.get_thread_binding(0) T.clear(C_local) if tid == 0: - T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1) + T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1, dtype="uint32") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(A[pid_m * block_M, k * block_K], A_shared) T.copy(B[k * block_K, pid_n * block_N], B_shared) diff --git a/examples/distributed/intranode/example_alltoall.py b/examples/distributed/intranode/example_alltoall.py new file mode 100644 index 000000000..b2bba1493 --- /dev/null +++ b/examples/distributed/intranode/example_alltoall.py @@ -0,0 +1,114 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse + + +def alltoall(PE_num, M, N, block_M, block_N, threads): + assert block_N == N + + @T.prim_func + def main( + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + barrier: T.Tensor((PE_num), "int32"), + ): + # Currently not support tiled copy + with T.Kernel( + PE_num, T.ceildiv(M, block_M), T.ceildiv(N, block_N), + threads=threads) as (bx, by, bz): + rank = T.alloc_local([1], "int32") + num_ranks = T.alloc_local([1], "int32") + + dst_rank = bx + rank[0] = T.get_rank() + num_ranks[0] = T.get_num_ranks() + + T.put_block( + src=T.address_of(src[dst_rank * M + by * block_M, 0]), + dst=T.address_of(dst[rank[0] * M + by * block_M, 0]), + size=block_M * block_N, + dst_pe=dst_rank, + ) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + + return main + + +def run_alltoall(local_rank, num_ranks, args): + PE_num = args.PE_num + M = args.M + N = args.N + block_M = 32 + 
block_N = N + threads = 256 + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**34, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + kernel = tilelang.compile(alltoall(PE_num, M, N, block_M, block_N, threads)) + kernel.initialize(allocator=allocator) + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + + torch.cuda.synchronize() + dist.barrier(group_size) + + # Warmup + for _ in range(args.warmup): + kernel(src, dst, barrier) + dst.zero_() + torch.cuda.synchronize() + dist.barrier(group_size) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(args.iter): + kernel(src, dst, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) + end.record() + torch.cuda.synchronize() + dist.barrier(group_size) + elapsed_time = start.elapsed_time(end) / args.iter + print( + f"Rank {local_rank} Average Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s" + ) + + # Torch Reference + torch.cuda.synchronize() + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! ✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! 
❌ Max diff: {max_diff}") + print(f"dst: {dst}") + print(f"dst_ref: {dst_ref}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") + + args = parser.parse_args() + torch.multiprocessing.spawn(run_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py new file mode 100644 index 000000000..0c1d73aa2 --- /dev/null +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -0,0 +1,361 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse +from enum import IntEnum + +tilelang.disable_cache() + + +class Direction(IntEnum): + NORTH = 0 + SOUTH = 1 + WEST = 2 + EAST = 3 + SELF = 4 + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" +) +def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): + num_warps = threads // 32 + block_M = M // num_blocks + tile_M = block_M // num_warps + + @T.prim_func + def main_route( + # For each (src, dst) pair, the real transfer size is M * N + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank + buffer_transfer: T.Tensor((PE_num, PE_num, M, N), "float16"), + # Signal for 
each buffer + signal_transfer: T.Tensor((PE_num, PE_num, num_blocks, num_warps), "uint32"), + # Signal for finish + local_finish: T.Tensor((1), "uint32"), + global_finish: T.Tensor((1), "uint32"), + # Barrier for all blocks + barrier: T.Tensor((PE_num), "int32"), + ): + with T.Kernel(PE_num, PE_num, num_blocks, threads=threads) as (bx, by, bz): + tx = T.get_thread_binding() + warp_idx = tx // 32 + + rank = T.alloc_local([1], "uint32") + rank_x = T.alloc_local([1], "uint32") + rank_y = T.alloc_local([1], "uint32") + next_rank = T.alloc_local([1], "uint32") + to_dir = T.alloc_local([1], "uint32") + diff = T.alloc_local([1], "int32") + src_rank = T.alloc_local([1], "uint32") + dst_rank = T.alloc_local([1], "uint32") + old_local = T.alloc_local([1], "uint32") + old_global = T.alloc_local([1], "uint32") + num_tiles = T.alloc_local([1], "uint32") + flag = T.alloc_local([1], "bool") + + rank[0] = T.get_rank() + rank_x[0] = T.floordiv(rank[0], Y) + rank_y[0] = T.floormod(rank[0], Y) + next_rank[0] = rank[0] + + num_tiles[0] = M // tile_M + src_rank[0] = bx + dst_rank[0] = by + flag[0] = False + + # Prepare for routing + dst_rank_x = T.floordiv(dst_rank[0], Y) + dst_rank_y = T.floormod(dst_rank[0], Y) + to_dir[0] = Direction.SELF + # XY-routing: first route along X-axis, then Y-axis + if dst_rank_x != rank_x[0]: + # Calculate shortest path in Torus X-dimension + diff[0] = dst_rank_x - rank_x[0] + if diff[0] > T.floordiv(X, 2): + diff[0] -= X + elif diff[0] <= -T.floordiv(X, 2): + diff[0] += X + + if diff[0] < 0: + # Send North (up): neighbor receives in its north buffer + to_dir[0] = Direction.NORTH + next_rank[0] = T.floormod(rank_x[0] + X - 1, X) * Y + rank_y[0] + else: + # Send South (down): neighbor receives in its south buffer + to_dir[0] = Direction.SOUTH + next_rank[0] = T.floormod(rank_x[0] + 1, X) * Y + rank_y[0] + elif dst_rank_y != rank_y[0]: + # Calculate shortest path in Torus Y-dimension + diff[0] = dst_rank_y - rank_y[0] + if diff[0] > T.floordiv(Y, 2): + 
diff[0] -= Y + elif diff[0] <= -T.floordiv(Y, 2): + diff[0] += Y + + if diff[0] < 0: + # Send West (left): neighbor receives in its west buffer + to_dir[0] = Direction.WEST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + Y - 1, Y) + else: + # Send East (right): neighbor receives in its east buffer + to_dir[0] = Direction.EAST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) + + T.fence_gpu(sem=T.MemorySemantic.RELEASE) + + # Phase 1: Fully use all blocks to initially send from src to the target neighbor + # Split the tile_M to each block + + chunk_M = T.ceildiv(tile_M, PE_num) + chunk_start = bx * chunk_M + chunk_size = T.min(chunk_M, tile_M - chunk_start) + + # if src_rank[0] == rank[0]: + if dst_rank[0] != rank[0]: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + next_rank[0], + ) + if tx % 32 == 0: + T.atom_add_remote( + signal_transfer[rank[0], dst_rank[0], bz, warp_idx], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + -1, + ) + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + T.sync_warp() + T.fence_cta(sem=T.MemorySemantic.RELEASE) + + T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + + # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer + if tx % 32 == 0: + T.wait_ge( + signal_transfer[bx, dst_rank[0], bz, warp_idx], PE_num, scope=T.MemoryScope.SYSTEM) + T.sync_warp() + + if 
signal_transfer[bx, dst_rank[0], bz, warp_idx] == PE_num and flag[0] == False: + flag[0] = True + # Handle the transfer signal + if dst_rank[0] != rank[0]: + T.put_warp( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + dst_pe=next_rank[0], + ) + if tx % 32 == 0: + T.st( + signal_transfer[bx, dst_rank[0], bz, warp_idx], + PE_num, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "transfer rank") + T.sync_warp() + else: + # Current rank is the real destination of this chunk of data, the real source rank is the buffer index + T.put_warp( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(dst[bx * M + bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + -1, + ) + T.sync_warp() + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "dst rank") + if old_local[0] + 1 == (PE_num - 1) * num_tiles[0] + PE_num * num_tiles[0]: + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=i, + ) + if old_global[0] + 1 == PE_num: + # Send termination signals to wake up all waiting blocks on all PEs + for remote_pe in T.serial(PE_num): + for src_rank_idx in T.serial(PE_num): + for dst_rank_idx in T.serial(PE_num): + for bz_idx in T.serial(num_blocks): + for dst_tile in T.serial(num_warps): + T.st( + signal_transfer[src_rank_idx, dst_rank_idx, + bz_idx, dst_tile], + PE_num + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) + T.sync_warp() + + T.sync_threads() + + T.fence_sys(sem=T.MemorySemantic.RELEASE) + + return main_route + + +def 
run_torus_alltoall(local_rank, num_ranks, args): + NUM_SM = 148 + PE_num = args.PE_num + X, Y = args.X, args.Y + M, N = args.M, args.N + block_M, block_N = M // 4, N + threads = 128 + + num_blocks = M // block_M + num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) + num_warps = threads // 32 + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**35, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads) + kernel.initialize(allocator=allocator) + + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, + allocator=allocator).fill_(0) + + # Signals for each buffer slot in each direction + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), + torch.uint32, + allocator=allocator).fill_(0) + local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + + torch.cuda.synchronize() + dist.barrier(group_size) + + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + + torch.cuda.synchronize() + dist.barrier(group_size) + + print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") + + if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! 
✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}") + # Find differences + diff_mask = (dst != dst_ref) + diff_count = diff_mask.sum().item() + + if diff_count > 0: + # Check each rank's symmetric buffer at every M boundary + print(f"Rank {local_rank} found {diff_count} differences") + print(f"Rank {local_rank} checking symmetric buffer positions at M={M} boundaries:") + for rank_idx in range(PE_num): + start_idx = rank_idx * M + # Check first element of each rank's buffer + print( + f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={buffer_transfer[rank_idx, local_rank, 0, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" + ) + + if args.benchmark: + # Warmup + for _ in range(args.warmup): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) + + # Reinitialize + buffer_transfer.zero_() + signal_transfer.zero_() + local_finish.zero_() + global_finish.zero_() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + num_iters = args.iter + start_event.record() + for _ in range(num_iters): + # torch.cuda.profiler.start() + # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + # torch.cuda.profiler.stop() + torch.cuda.synchronize() + dist.barrier(group_size) + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters + # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements + # and receives (PE_num - 1) * M * N elements. + # For bandwidth calculation, we usually use the amount of data sent per rank. 
+ total_data_bytes = PE_num * M * N * 2 # float16 = 2 bytes + bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) + print(f"Rank {local_rank} Average Latency: {elapsed_time_ms:.4f} ms, Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--X", type=int, default=2) + parser.add_argument("--Y", type=int, default=4) + parser.add_argument("--benchmark", action="store_true", help="Run benchmark") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") + args = parser.parse_args() + + torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py new file mode 100644 index 000000000..e3ce0f471 --- /dev/null +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -0,0 +1,486 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse +from enum import IntEnum + +# tilelang.disable_cache() + + +class Direction(IntEnum): + NORTH = 0 + SOUTH = 1 + WEST = 2 + EAST = 3 + SELF = 4 + + +def compute_next_hop(src_rank, dst_rank, X, Y): + """Compute the next hop from src_rank towards dst_rank using XY-routing on a 2D torus.""" + if src_rank == dst_rank: + return src_rank # self, no hop needed + + src_x, src_y = src_rank // Y, src_rank % Y + dst_x, dst_y = dst_rank // Y, dst_rank % Y + + if dst_x != src_x: + # Route along X-axis first + diff = dst_x - src_x + if diff > X // 2: + 
diff -= X + elif diff <= -X // 2: + diff += X + + if diff < 0: + # North + next_x = (src_x + X - 1) % X + return next_x * Y + src_y + else: + # South + next_x = (src_x + 1) % X + return next_x * Y + src_y + else: + # Route along Y-axis + diff = dst_y - src_y + if diff > Y // 2: + diff -= Y + elif diff <= -Y // 2: + diff += Y + + if diff < 0: + # West + next_y = (src_y + Y - 1) % Y + return src_x * Y + next_y + else: + # East + next_y = (src_y + 1) % Y + return src_x * Y + next_y + + +def compute_expected_slots(PE_num, X, Y): + """ + Pre-compute the number of incoming slots each PE will receive for each dst_rank, + tracing the full XY-routing path through the torus. + + Returns a list of lists of shape (PE_num, PE_num) where expected_slots[me][dst_rank] is + the number of data chunks that PE 'me' will receive in its buffer for destination 'dst_rank'. + + This includes both: + - Final reception: dst_rank == me, each other PE sends one chunk → PE_num - 1 slots + - Forwarding: dst_rank != me, data passes through me on the way to dst_rank + + For each (src, final_dst) pair where src != final_dst, we trace the full multi-hop + XY-routing path. Every intermediate PE and the final destination PE each receive one slot. 
+ """ + # expected_slots[receiver_pe][dst_rank] = count of incoming slots + expected = [[0] * PE_num for _ in range(PE_num)] + + for src in range(PE_num): + for final_dst in range(PE_num): + if src == final_dst: + continue + # Trace the full path from src to final_dst + current = src + while current != final_dst: + next_hop = compute_next_hop(current, final_dst, X, Y) + # next_hop receives one slot for dst_rank=final_dst + expected[next_hop][final_dst] += 1 + current = next_hop + + return expected + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" +) +def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, num_warps, threads): + block_M = M // num_blocks + tile_M = block_M // num_warps + # Number of slots + num_slots = PE_num + + @T.prim_func + def main_route_opt( + # For each (src, dst) pair, the real transfer size is M * N + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + # buffer[dst_rank, slots, *, *]: This PE save slots for transferring data chunks for the real destination rank + buffer_transfer: T.Tensor((PE_num, num_slots, M, N), "float16"), + # slot_counter[dst_rank, num_blocks, num_warps]: Counter for allocating slot indices + slot_counter: T.Tensor((PE_num, num_blocks, num_warps), "uint32"), + # Per-slot ready flag: 0=not ready, 1=data ready + # signal_transfer[dst_rank, slot, num_blocks, num_warps] + signal_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "int32"), + # Src idx during transfer + src_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "uint32"), + # Pre-computed expected incoming slot counts per (PE, dst_rank) + # expected_slots[pe, dst_rank]: how many slots PE 'pe' will receive for dst_rank + expected_slots: T.Tensor((PE_num, PE_num), "int32"), + ): + with T.Kernel(PE_num, num_blocks, 
threads=threads) as (bx, bz): + tx = T.get_thread_binding() + warp_id = tx // 32 + + rank = T.alloc_local([1], "uint32") + rank_x = T.alloc_local([1], "uint32") + rank_y = T.alloc_local([1], "uint32") + next_rank = T.alloc_local([1], "uint32") + to_dir = T.alloc_local([1], "uint32") + diff = T.alloc_local([1], "int32") + dst_rank = T.alloc_local([1], "uint32") + old_counter = T.alloc_local([1], "uint32") + old_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") + cur_counter = T.alloc_local([1], "uint32") + cur_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") + slot_flag = T.alloc_local([1], "int32") + slot_flag_shared = T.alloc_shared([num_warps], "int32") + num_expected = T.alloc_local([1], "int32") + + rank[0] = T.get_rank() + rank_x[0] = T.floordiv(rank[0], Y) + rank_y[0] = T.floormod(rank[0], Y) + next_rank[0] = rank[0] + + dst_rank[0] = bx + + # Read pre-computed expected slot count for this PE and dst_rank + num_expected[0] = expected_slots[rank[0], dst_rank[0]] + + # Prepare for routing + dst_rank_x = T.floordiv(dst_rank[0], Y) + dst_rank_y = T.floormod(dst_rank[0], Y) + to_dir[0] = Direction.SELF + # XY-routing: first route along X-axis, then Y-axis + if dst_rank_x != rank_x[0]: + # Calculate shortest path in Torus X-dimension + diff[0] = dst_rank_x - rank_x[0] + if diff[0] > T.floordiv(X, 2): + diff[0] -= X + elif diff[0] <= -T.ceildiv(X, 2): + diff[0] += X + + if diff[0] < 0: + # Send North (up): neighbor receives in its north buffer + to_dir[0] = Direction.NORTH + next_rank[0] = T.floormod(rank_x[0] + X - 1, X) * Y + rank_y[0] + else: + # Send South (down): neighbor receives in its south buffer + to_dir[0] = Direction.SOUTH + next_rank[0] = T.floormod(rank_x[0] + 1, X) * Y + rank_y[0] + elif dst_rank_y != rank_y[0]: + # Calculate shortest path in Torus Y-dimension + diff[0] = dst_rank_y - rank_y[0] + if diff[0] > T.floordiv(Y, 2): + diff[0] -= Y + elif diff[0] <= -T.ceildiv(Y, 2): + diff[0] += Y + + if 
diff[0] < 0: + # Send West (left): neighbor receives in its west buffer + to_dir[0] = Direction.WEST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + Y - 1, Y) + else: + # Send East (right): neighbor receives in its east buffer + to_dir[0] = Direction.EAST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) + + # Phase 1: Fully use all blocks to initially send from src to the target neighbor + # Split the tile_M to each block + + chunk_M = T.min(block_M - warp_id * tile_M, tile_M) + if dst_rank[0] != rank[0]: + if tx % 32 == 0: + old_counter[0] = T.atom_add_remote( + slot_counter[dst_rank[0], bz, warp_id], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + # Write src idx to src_transfer + T.st( + src_transfer[dst_rank[0], old_counter[0], bz, warp_id], + rank[0], + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + old_counter_shared[dst_rank[0], bz, warp_id] = old_counter[0] + T.sync_warp() + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz, warp_id], bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + next_rank[0], + ) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + if tx % 32 == 0: + # Set per-slot ready flag: this specific slot is now ready + T.st( + signal_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz, warp_id], bz, warp_id], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + -1, + ) + T.sync_warp() + T.fence_cta(sem=T.MemorySemantic.RELEASE) + + # Phase 2: Poll per-slot ready flags sequentially and process each slot. 
+ # Use pre-computed expected_slots count to know exactly how many slots to process, + # eliminating the need for termination signals. + for slot_idx in T.serial(num_slots): + # Skip if no more expected slots or this block doesn't receive for this dst_rank + if slot_idx >= num_expected[0]: + T.loop_break() + + # Wait for this slot's data to become ready (flag != 0) + if tx % 32 == 0: + slot_flag[0] = T.wait_ne( + signal_transfer[dst_rank[0], slot_idx, bz, warp_id], + 0, + scope=T.MemoryScope.SYSTEM, + ) + slot_flag_shared[warp_id] = slot_flag[0] + T.sync_warp() + + + src_idx = src_transfer[dst_rank[0], slot_idx, bz, warp_id] + # Handle the transfer + if dst_rank[0] != rank[0]: + # Forward to next hop + if tx % 32 == 0: + cur_counter[0] = T.atom_add_remote( + slot_counter[dst_rank[0], bz, warp_id], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + # Write src idx to src_transfer on next hop + T.st( + src_transfer[dst_rank[0], cur_counter[0], bz, warp_id], + src_idx, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + cur_counter_shared[dst_rank[0], bz, warp_id] = cur_counter[0] + T.sync_warp() + T.put_warp( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M + warp_id * tile_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz, warp_id], bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + dst_pe=next_rank[0], + ) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_warp() + if tx % 32 == 0: + # Set per-slot ready flag on next hop + T.st( + signal_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz, warp_id], bz, warp_id], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + # Final destination: copy from buffer to dst + T.put_warp( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M + warp_id * tile_M, 
0]), + T.address_of(dst[src_transfer[dst_rank[0], slot_idx, bz, warp_id] * M + bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + -1, + ) + T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.sync_warp() + + T.sync_warp() + + return main_route_opt + + +def run_torus_alltoall(local_rank, num_ranks, args): + NUM_SM = 148 + PE_num = args.PE_num + X, Y = args.X, args.Y + M, N = args.M, args.N + blocks = args.blocks + threads = 256 + assert threads % 32 == 0, "threads must be divisible by 32" + num_warps = threads // 32 + + num_blocks = blocks + num_blocks = min(num_blocks, NUM_SM // PE_num) + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**34, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + + # Pre-compute expected slot counts based on XY-routing topology + expected_slots_host = compute_expected_slots(PE_num, X, Y) + if local_rank == 0: + print(f"Expected slots per PE (rows=receiver, cols=dst_rank):") + for pe in range(PE_num): + print(f" PE {pe}: {expected_slots_host[pe]}") + + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, num_warps, threads) + kernel.initialize(allocator=allocator) + + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, + allocator=allocator).fill_(-1) + + # Signals for each buffer slot in each direction + slot_counter = tilelang.tensor((PE_num, num_blocks, num_warps), torch.uint32, allocator=allocator).fill_(0) + # Per-slot ready flags: 0=not ready, 1=ready + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), + torch.int32, 
+ allocator=allocator).fill_(0) + + src_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), + torch.uint32, + allocator=allocator).fill_(PE_num) + + # Pre-computed expected slots tensor (same for all PEs) + expected_slots_tensor = tilelang.tensor((PE_num, PE_num), torch.int32, allocator=allocator) + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) + + torch.cuda.synchronize() + dist.barrier(group_size) + + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) + + torch.cuda.synchronize() + dist.barrier(group_size) + + print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") + + if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! ✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}") + # Find differences + diff_mask = (dst != dst_ref) + diff_count = diff_mask.sum().item() + + if diff_count > 0: + # Check each rank's symmetric buffer at every M boundary + print(f"Rank {local_rank} found {diff_count} differences") + print(f"Rank {local_rank} checking symmetric buffer positions at M={M} boundaries:") + for rank_idx in range(PE_num): + start_idx = rank_idx * M + # Check first element of each rank's buffer + print( + f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={dst[start_idx, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" + ) + for slot in range(PE_num): + print( + f"Rank {local_rank} slot={slot}: buffer={buffer_transfer[local_rank, slot, 0, 0].item():.6f}, src_transfer={src_transfer[local_rank, slot, 0, 0].item():.6f}" + ) + # Print first 10 differences with their coordinates + print(f"Rank {local_rank} first 10 differences:") + diff_indices = torch.nonzero(diff_mask, as_tuple=False) + for i in range(min(10, diff_indices.shape[0])): + idx = diff_indices[i] + row, col = idx[0].item(), idx[1].item() + print( + 
f"Rank {local_rank} Diff[{i}] at ({row}, {col}): dst={dst[row, col].item():.6f}, dst_ref={dst_ref[row, col].item():.6f}" + ) + + if args.benchmark: + # Warmup + for _ in range(args.warmup): + # Reinitialize + torch.cuda.synchronize() + dist.barrier(group_size) + buffer_transfer.fill_(-1) + src_transfer.fill_(PE_num) + slot_counter.zero_() + signal_transfer.zero_() + dst.zero_() + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) + torch.cuda.synchronize() + dist.barrier(group_size) + + # Reinitialize + buffer_transfer.fill_(-1) + src_transfer.fill_(PE_num) + slot_counter.zero_() + signal_transfer.zero_() + dst.zero_() + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) + torch.cuda.synchronize() + dist.barrier(group_size) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + num_iters = args.iter + start_event.record() + for _ in range(num_iters): + # torch.cuda.profiler.start() + # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) + # torch.cuda.profiler.stop() + torch.cuda.synchronize() + dist.barrier(group_size) + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters + # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements + # and receives (PE_num - 1) * M * N elements. + # For bandwidth calculation, we usually use the amount of data sent per rank. 
+ total_data_bytes = PE_num * M * N * 2 # float16 = 2 bytes + bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) + print(f"Rank {local_rank} Average Latency: {elapsed_time_ms:.4f} ms, Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--X", type=int, default=2) + parser.add_argument("--Y", type=int, default=4) + parser.add_argument("--benchmark", action="store_true", help="Run benchmark") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") + parser.add_argument("--blocks", type=int, default=1, help="Number of blocks") + args = parser.parse_args() + + torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index fba501e48..7f0291f0f 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -273,10 +273,15 @@ Stmt StOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; // Map integers to enum literal strings - const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", - "Semantic::ACQUIRE", "Semantic::RELEASE", - "Semantic::RELAXED"}; - const char *scope_str[] = {"Scope::CTA", "Scope::GPU", "Scope::SYS"}; + // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL + const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", + "Semantic::RELAXED", "Semantic::ACQUIRE", + "Semantic::RELEASE", "Semantic::ACQ_REL"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", + "Scope::SYS"}; + + ICHECK_LT(sem, 6); + ICHECK_LT(scope, 4); // Build function name: tl::st 
ss << "tl::st<" << sem_str[sem] << ", " << scope_str[scope] << ", " @@ -342,10 +347,15 @@ Stmt LdOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; // Map integers to enum literal strings - const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", - "Semantic::ACQUIRE", "Semantic::RELEASE", - "Semantic::RELAXED"}; - const char *scope_str[] = {"Scope::CTA", "Scope::GPU", "Scope::SYS"}; + // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL + const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", + "Semantic::RELAXED", "Semantic::ACQUIRE", + "Semantic::RELEASE", "Semantic::ACQ_REL"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", + "Scope::SYS"}; + + ICHECK_LT(sem, 6); + ICHECK_LT(scope, 4); // Build function name: tl::ld ss << "tl::ld<" << sem_str[sem] << ", " << scope_str[scope] << ", " @@ -382,6 +392,75 @@ TileOperator LdOpNode::Clone() const { return LdOp(node); } +AtomAddRemoteOp::AtomAddRemoteOp(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); + node->dst = args[0]; + ICHECK(node->dst.as()) << "dst must be a call node"; + ICHECK(node->dst.as()->op.same_as(builtin::address_of())) + << "dst must be address_of op"; + + node->value = args[1]; + node->sem = args[2].as().value()->value; + node->scope = args[3].as().value()->value; + node->dst_pe = args[4]; + data_ = std::move(node); + (void)vmap; +} + +bool AtomAddRemoteOpNode::is_distributed() const { + return !(dst_pe->IsInstance() && + dst_pe.as()->value == -1); +} + +Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + (void)analyzer; + (void)T; + Array new_args; + std::stringstream ss; + + // Map integers to semantic literal strings for PTX atom instruction + // Unified Mapping: 2: relaxed, 3: acquire, 4: release, 5: acq_rel + const char *sem_str[] = {"weak", "volatile", "relaxed", + "acquire", "release", "acq_rel"}; + const char *scope_str[] = { + "cta", 
"cluster", "gpu", + "sys"}; // Unified: 2: gpu, 3: system (mapped below) + + // Build function name: tl::ptx_atom_add__ + ss << "tl::ptx_atom_add_" << sem_str[sem] << "_" << scope_str[scope]; + + new_args.push_back(StringImm(ss.str())); + if (is_distributed()) { + PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); + PrimExpr local_base_ptr = + Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + PrimExpr offset_to_base = Sub( + Call(DataType::Handle(), tl::get_uintptr_t(), {dst}), local_base_ptr); + new_args.push_back( + Call(DataType::Handle(), tl::get_remote_base_ptr(), {dst_pe}) + + offset_to_base); + } else { + new_args.push_back(dst); + } + new_args.push_back(value); + + auto atom_add = Call(DataType::UInt(32), builtin::call_extern(), new_args); + return Evaluate(atom_add); +} + +LayoutMap AtomAddRemoteOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + (void)T; + (void)level; + return {}; +} + +TileOperator AtomAddRemoteOpNode::Clone() const { + auto node = make_object(*this); + return AtomAddRemoteOp(node); +} + TIR_REGISTER_TL_OP(PutOp, put) .set_num_inputs(7) .set_attr("TCallEffectKind", @@ -398,10 +477,16 @@ TIR_REGISTER_TL_OP(StOp, st).set_num_inputs(6).set_attr( TIR_REGISTER_TL_OP(LdOp, ld).set_num_inputs(7).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_REGISTER_TL_OP(AtomAddRemoteOp, atom_add_remote) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TVM_FFI_STATIC_INIT_BLOCK({ PutOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ GetOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ StOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ LdOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK({ AtomAddRemoteOpNode::RegisterReflection(); }); } // namespace tl } // namespace tvm diff --git a/src/op/remote_copy.h b/src/op/remote_copy.h index 3c118f33a..d93bb45e2 100644 --- a/src/op/remote_copy.h +++ 
b/src/op/remote_copy.h @@ -319,6 +319,62 @@ class LdOp : public TileOperator { static const Op &Get(); }; +class AtomAddRemoteOpNode : public TileOperatorNode { +public: + PrimExpr dst; ///< Destination address + PrimExpr value; ///< Value to atomically add + PrimExpr dst_pe; ///< Destination processing element (optional) + int scope; ///< Memory scope (0: CTA, 1: CLUSTER, 2: GPU, 3: SYS) + int sem; ///< Memory semantic (0: weak, 1: volatile, 2: relaxed, 3: acquire, 4: release, 5: acq_rel) + + bool is_distributed() const; + + static constexpr const char *_type_key = "tl.AtomAddRemoteOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(AtomAddRemoteOpNode, TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + static const Op &Get(); + TileOperator Clone() const override; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dst", &AtomAddRemoteOpNode::dst) + .def_ro("value", &AtomAddRemoteOpNode::value) + .def_ro("dst_pe", &AtomAddRemoteOpNode::dst_pe) + .def_ro("scope", &AtomAddRemoteOpNode::scope) + .def_ro("sem", &AtomAddRemoteOpNode::sem); + } + + bool SEqualReduce(const AtomAddRemoteOpNode *other, + SEqualReducer equal) const { + return equal(dst, other->dst) && equal(value, other->value) && + equal(dst_pe, other->dst_pe) && scope == other->scope && + sem == other->sem; + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dst); + hash_reduce(value); + hash_reduce(dst_pe); + hash_reduce(scope); + hash_reduce(sem); + } + + static constexpr bool _type_has_method_sequal_reduce = true; + static constexpr bool _type_has_method_shash_reduce = true; +}; + +class AtomAddRemoteOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(AtomAddRemoteOp, TileOperator, + AtomAddRemoteOpNode); + TVM_DLL AtomAddRemoteOp(Array args, BufferMap vmap); + static const Op &Get(); +}; + } // namespace tl } // namespace tvm diff 
--git a/src/op/sync.cc b/src/op/sync.cc index 892fc2220..f8d585476 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -51,9 +51,6 @@ TIR_DEFINE_TL_BUILTIN(wait_barrier_gpu) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(wait_eq).set_num_inputs(2).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - TIR_DEFINE_TL_BUILTIN(sync_barrier_gpu) .set_num_inputs(1) .set_attr("TCallEffectKind", @@ -137,10 +134,18 @@ PrimExpr BarrierBlocksOpNode::MakeLocalBarAddr(const LowerArgs &T) const { WaitOp::WaitOp(Array args, BufferMap vmap) { ObjectPtr node = make_object(); - node->relation = args[0].as()->value; + ICHECK_GE(args.size(), 4); + const auto *relation_node = args[0].as(); + ICHECK(relation_node) << "Wait relation must be an integer"; + node->relation = relation_node->value; node->addr = args[1]; node->expected = args[2]; node->peer = args[3]; + // scope parameter is optional, default to SYSTEM (3) for safety + node->scope = (args.size() > 4) ? args[4].as()->value : 3; + // semantic parameter is optional, default to ACQUIRE (3) for safety + node->semantic = (args.size() > 5) ? args[5].as()->value : 3; + node->dtype = (args.size() > 6) ? args[6].as()->value : "int32"; data_ = std::move(node); (void)vmap; } @@ -174,8 +179,13 @@ Stmt WaitOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { new_args.push_back(addr); } new_args.push_back(expected); - - auto wait = Call(DataType::Handle(), builtin::call_extern(), new_args); + // Pass scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM + new_args.push_back(IntImm(DataType::Int(32), scope)); + // Pass semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, + // 5=ACQ_REL + new_args.push_back(IntImm(DataType::Int(32), semantic)); + auto datatype = dtype == "int32" ? 
DataType::Int(32) : DataType::UInt(32); + auto wait = Call(datatype, builtin::call_extern(), new_args); return Evaluate(wait); } @@ -192,22 +202,22 @@ TileOperator WaitOpNode::Clone() const { } TIR_REGISTER_TL_OP(BarrierBlocksOp, barrier_blocks) - .set_num_inputs(1) + .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_REGISTER_TL_OP(WaitOp, wait) - .set_num_inputs(4) + .set_num_inputs(6) // relation, addr, expected, peer, scope, semantic .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_cta).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_cta).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_gpu).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_gpu).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_sys).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_sys).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_FFI_STATIC_INIT_BLOCK({ BarrierBlocksOpNode::RegisterReflection(); }); diff --git a/src/op/sync.h b/src/op/sync.h index 16487877e..a4c78b0ef 100644 --- a/src/op/sync.h +++ b/src/op/sync.h @@ -38,13 +38,6 @@ TVM_DLL const Op &arrive_barrier_gpu(); */ TVM_DLL const Op &wait_barrier_gpu(); -/*! - * \brief Wait until *addr == expected* for GPU-level synchronization - * void wait_eq(addr, expected) - */ - -TVM_DLL const Op &wait_eq(); - /*! * \brief TileOperatorNode for wait operation. * @@ -57,6 +50,10 @@ class WaitOpNode : public TileOperatorNode { PrimExpr expected; ///< The expected value to compare against. PrimExpr peer; ///< The peer to compare against. int relation; ///< The relation to compare against. 
+ int scope; ///< Memory scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM + int semantic; ///< Memory semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, + ///< 4=RELEASE, 5=ACQ_REL + std::string dtype; ///< The data type of the memory address, must be int32 or uint32 bool is_distributed() const; @@ -75,12 +72,17 @@ class WaitOpNode : public TileOperatorNode { .def_ro("addr", &WaitOpNode::addr) .def_ro("expected", &WaitOpNode::expected) .def_ro("peer", &WaitOpNode::peer) - .def_ro("relation", &WaitOpNode::relation); + .def_ro("relation", &WaitOpNode::relation) + .def_ro("scope", &WaitOpNode::scope) + .def_ro("semantic", &WaitOpNode::semantic) + .def_ro("dtype", &WaitOpNode::dtype); } bool SEqualReduce(const WaitOpNode *other, SEqualReducer equal) const { return equal(addr, other->addr) && equal(expected, other->expected) && - equal(peer, other->peer) && equal(relation, other->relation); + equal(peer, other->peer) && relation == other->relation && + scope == other->scope && semantic == other->semantic && + dtype == other->dtype; } void SHashReduce(SHashReducer hash_reduce) const { @@ -88,6 +90,9 @@ class WaitOpNode : public TileOperatorNode { hash_reduce(expected); hash_reduce(peer); hash_reduce(relation); + hash_reduce(scope); + hash_reduce(semantic); + hash_reduce(dtype); } static constexpr bool _type_has_method_sequal_reduce = true; diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index b39e9b042..3a500f813 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1477,6 +1477,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); int num_mma = Downcast(op->args[0])->value; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; + } else if (op->op.same_as(Op::Get("tl.wait"))) { + // relation, addr, expected, peer, scope, semantic + ICHECK_GE(op->args.size(), 4); + const char *relation_str[] = {"eq", "ne", "ge", "le", "gt", "lt"}; + int relation = 
Downcast(op->args[0])->value; + std::string addr = this->PrintExpr(op->args[1]); + std::string expected = this->PrintExpr(op->args[2]); + // peer is args[3], but wait_ne in sync.h doesn't take peer if it's already a pointer + // The Lower() in sync.cc handles distributed by converting to remote pointer + int scope = (op->args.size() > 4) ? Downcast(op->args[4])->value : 3; + int semantic = (op->args.size() > 5) ? Downcast(op->args[5])->value : 3; + + os << "tl::wait_" << relation_str[relation] << "(" << addr << ", " << expected + << ", " << scope << ", " << semantic << ");\n"; } else if (op->op.same_as(tl::pack_b16())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; @@ -1507,10 +1521,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::sync_grid())) { this->PrintIndent(); this->stream << "tl::sync_grid(" << this->PrintExpr(op->args[0]) << ");\n"; - } else if (op->op.same_as(tl::wait_eq())) { - this->PrintIndent(); - this->stream << "tl::wait_eq(" << this->PrintExpr(op->args[0]) << ", " - << this->PrintExpr(op->args[1]) << ");\n"; } else if (op->op.same_as(tl::atom_add())) { std::string func_name = "tl::ptx_atom_add_" + op->args[2].as()->value + "_" + @@ -2107,13 +2117,13 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "nvshmem_barrier_all()"; } else if (op->op.same_as(tl::fence_cta())) { this->use_distributed_ = true; - os << "tl::memory_fence_cta()"; + os << "tl::memory_fence_cta(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::fence_gpu())) { this->use_distributed_ = true; - os << "tl::memory_fence_gpu()"; + os << "tl::memory_fence_gpu(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::fence_sys())) { this->use_distributed_ = true; - os << "tl::memory_fence_sys()"; + os << "tl::memory_fence_sys(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::reinterpret())) { 
DataType tgt_dtype = op->dtype; DataType src_dtype = op->args[0]->dtype; diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index fe4607020..7696c4163 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -241,36 +241,40 @@ TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, return ret; } -TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.relaxed.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_acquire_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.acquire.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t ptx_atom_add_release_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_release_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.release.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.acq_rel.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); diff --git a/src/tl_templates/cuda/ldst.h b/src/tl_templates/cuda/ldst.h index c875832eb..247027187 100644 --- a/src/tl_templates/cuda/ldst.h +++ b/src/tl_templates/cuda/ldst.h @@ -3,8 +3,8 @@ #include "common.h" // Memory semantic and scope enums -enum class Semantic { WEAK, VOLATILE, ACQUIRE, RELEASE, RELAXED }; -enum class Scope { CTA, GPU, SYS }; 
+enum class Semantic { WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL }; +enum class Scope { CTA, CLUSTER, GPU, SYS }; #ifndef TL_ALWAYS_FALSE_V_DEFINED #define TL_ALWAYS_FALSE_V_DEFINED diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index cad94ee7e..26e5ecc7f 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -12,22 +12,73 @@ namespace tl { +enum class SyncScope { CTA = 0, CLUSTER = 1, GPU = 2, SYSTEM = 3 }; + +enum class SyncSemantic { + WEAK = 0, + VOLATILE = 1, + RELAXED = 2, + ACQUIRE = 3, + RELEASE = 4, + ACQ_REL = 5, + SC = 6 +}; + // Triggers a GPU trap for debugging TL_DEVICE void trap() { asm("trap;\n"); } // CTA-level memory fence -TL_DEVICE void memory_fence_cta() { - asm volatile("fence.acq_rel.cta;\n" ::: "memory"); +TL_DEVICE void memory_fence_cta(int sem) { + switch (sem) { + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.cta;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.cta;\n" ::: "memory"); + break; + } } // GPU-level memory fence -TL_DEVICE void memory_fence_gpu() { - asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); +TL_DEVICE void memory_fence_gpu(int sem) { + switch (sem) { + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.gpu;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); + break; + } } // System-level memory fence -TL_DEVICE void memory_fence_sys() { - asm volatile("fence.acq_rel.sys;\n" ::: "memory"); +TL_DEVICE void memory_fence_sys(int sem) { + switch (sem) { + case 
static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.sys;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.sys;\n" ::: "memory"); + break; + } } // GPU-level load with acquire semantics @@ -74,12 +125,13 @@ TL_DEVICE void init_barrier_gpu(uint32_t *barrier) { if (IS_MASTER_BLOCK() && IS_MASTER_THREAD()) { *barrier = BARRIER_MAGIC - kExpected; } - memory_fence_gpu(); // TODO: Is fence or sync needed here? + memory_fence_gpu(static_cast( + SyncSemantic::ACQ_REL)); // TODO: Is fence or sync needed here? } // Arrive at a GPU barrier (atomic increment) TL_DEVICE void arrive_barrier_gpu(uint32_t *barrier) { - memory_fence_gpu(); + memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); if (IS_MASTER_THREAD()) { atomic_add_release_gpu_u32(barrier, 1); } @@ -98,7 +150,7 @@ TL_DEVICE void wait_barrier_gpu(uint32_t *barrier) { // Synchronize at a GPU barrier (arrive + wait) TL_DEVICE void sync_barrier_gpu(uint32_t *barrier) { - // memory_fence_gpu(); + memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); __syncthreads(); if (IS_MASTER_THREAD()) { atomic_add_release_gpu_u32(barrier, 1); @@ -120,7 +172,7 @@ TL_DEVICE unsigned int sync_grids_arrive(uint32_t *barrier) { unsigned int expected = gridDim.x * gridDim.y * gridDim.z; unsigned int nb = 1; if (IS_MASTER_BLOCK()) { - nb = 0x80000000 - (expected - 1); + nb = BARRIER_MAGIC - (expected - 1); } asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) @@ -140,7 +192,7 @@ TL_DEVICE void sync_grids_wait(unsigned int oldArrive, uint32_t *barrier) { : "=r"(current_arrive) : "l"((unsigned int *)barrier) : "memory"); - } while (!(((oldArrive ^ current_arrive) & 0x80000000) != 0)); + } while (!(((oldArrive ^ current_arrive) & BARRIER_MAGIC) != 0)); } __syncthreads(); } @@ -150,7 +202,7 @@ 
TL_DEVICE void sync_grid(uint32_t *barrier) { sync_grids_wait(token, barrier); } -// Sync blocks at a system-level barrier with an optinal fence +// Sync blocks at a system-level barrier with an optional fence // TODO(wt): Add timeout handling template @@ -161,7 +213,7 @@ TL_DEVICE void barrier_blocks(int offset, int rank, int num_ranks) { #define FINISHED_SUM_TAG (1024) if constexpr (need_fence) { - memory_fence_sys(); + memory_fence_sys(static_cast(SyncSemantic::ACQ_REL)); __syncthreads(); } @@ -184,62 +236,180 @@ TL_DEVICE void barrier_blocks(int offset, int rank, int num_ranks) { #undef FINISHED_SUM_TAG } -template TL_DEVICE void wait_eq(void *ptr, T val) { +using WaitScope = SyncScope; +using WaitSemantic = SyncSemantic; + +// Load with volatile semantics (GPU scope, faster but no cross-PE guarantees) +template +TL_DEVICE T ld_wait_gpu(const T *ptr, WaitSemantic semantic) { + int ret = 0; + if constexpr (std::is_same_v) { + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } + return ret; + } else if constexpr (std::is_same_v || + std::is_same_v) { + // Cast to int* for ld_volatile_global, then cast back + const int *int_ptr = reinterpret_cast(ptr); + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); + } + return static_cast(ret); + } else { + return *reinterpret_cast(ptr); + } +} + +// Load with 
acquire.sys semantics (SYSTEM scope, required for proper cross-PE +// sync) +template +TL_DEVICE T ld_wait_sys(const T *ptr, WaitSemantic semantic) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + unsigned int ret = 0; + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); + } + return static_cast(ret); + } else { + // Fallback to volatile for other types + return *reinterpret_cast(ptr); + } +} + +// Generic load dispatcher based on scope and semantic +template +TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, + WaitSemantic semantic = WaitSemantic::ACQUIRE) { + if (scope == WaitScope::SYSTEM) { + return ld_wait_sys(ptr, semantic); + } else { + return ld_wait_gpu(ptr, semantic); + } +} + +template +TL_DEVICE T wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { + static_assert(std::is_same_v || std::is_pointer_v

, + "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_acquire(flag_ptr) != val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) != + val) ; + return ret; } -template TL_DEVICE void wait_ne(P ptr, T val) { +template +TL_DEVICE T wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_volatile_global(flag_ptr) == val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) == + val) ; + return ret; } -template TL_DEVICE void wait_ge(P ptr, T val) { +template +TL_DEVICE T wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_volatile_global(flag_ptr) < val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) < + val) ; + return ret; } -template TL_DEVICE void wait_le(P ptr, T val) { +template +TL_DEVICE T wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_volatile_global(flag_ptr) > val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) > + val) ; + return ret; } -template TL_DEVICE void wait_gt(P ptr, T val) { +template +TL_DEVICE T wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_volatile_global(flag_ptr) <= val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) <= + val) ; + return ret; } -template TL_DEVICE void wait_lt(P ptr, T val) { +template +TL_DEVICE T wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_volatile_global(flag_ptr) >= val) + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) >= + val) ; + return ret; } } // namespace tl diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 104b46c79..8ff89f773 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -138,7 +138,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { role = Role::kProducer; has_bulk_copy_ = true; } - if (call->op.same_as(loop_break()) || call->op.same_as(wait_eq())) + if (call->op.same_as(loop_break()) || call->op.same_as(WaitOp::Get())) role = Role::kBoth; if (call->op.same_as(get_clock())) role = Role::kBoth; diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 8db638308..e285bc655 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -3,10 +3,11 @@ from tilelang import tvm as tvm from tilelang.language import ptx_arrive_barrier, evaluate, address_of +from tilelang.language.utils import MemoryScope, MemorySemantic from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability from tvm import tir -from typing import Any, Literal +from typing import Any import tilelang.language as T from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad @@ -612,15 +613,16 @@ def sync_grid(barrier: PrimExpr): return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"), address_of(barrier)) -def barrier_blocks(barrier: PrimExpr): +def barrier_blocks(barrier: PrimExpr, need_fence: bool = True): """Barrier all blocks at a system-level barrier. 
Compare to sync_blocks, barrier_blocks have an extra system-level fence effect Args: barrier: The barrier to synchronize at, should be [num_ranks] of int32 + need_fence: Whether need fence. Default to True """ return tir.call_intrin("handle", tir.op.Op.get("tl.barrier_blocks"), address_of(barrier), - 1) # whether need fence + need_fence) # whether need fence def sync_blocks(barrier: PrimExpr): @@ -633,19 +635,19 @@ def sync_blocks(barrier: PrimExpr): 0) # whether need fence -def fence_cta(): +def fence_cta(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Create a memory fence at the block level (visible to all threads in the current block).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_cta")) + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_cta"), sem.value) -def fence_gpu(): +def fence_gpu(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Synchronize all threads at the GPU level (visible to all blocks on the current device).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_gpu")) + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_gpu"), sem.value) -def fence_sys(): +def fence_sys(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Synchronize all threads at the system level (visible in a node).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_sys")) + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_sys"), sem.value) def get_clock(): @@ -728,21 +730,28 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) -def atom_add(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relaxed"): +def atom_add(barrier: PrimExpr, + value: PrimExpr, + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.RELAXED): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """ - assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." 
- assert sem in ["relaxed", "acquire", "release", "acq_rel" - ], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." - return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem, - scope) + scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] + sem_str = { + MemorySemantic.RELAXED: "relaxed", + MemorySemantic.ACQUIRE: "acquire", + MemorySemantic.RELEASE: "release", + MemorySemantic.ACQ_REL: "acq_rel" + }[sem] + return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, + sem_str, scope_str) def ld( src: PrimExpr, value: PrimExpr, - scope: Literal["cta", "gpu", "sys"] = "gpu", - sem: Literal["weak", "volatile", "acquire", "release", "relaxed"] = "weak", + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.WEAK, na: bool = False, nc: bool = False, src_pe: tir.PrimExpr | tir.IntImm | None = -1, @@ -762,23 +771,17 @@ def ld( Returns: tir.Call: A handle to the load operation. """ - assert scope in ["cta", "gpu", "sys"], "Scope must be one of 'cta', 'gpu', or 'sys'." - assert sem in [ - "weak", "volatile", "acquire", "relaxed" - ], "Semantic must be one of 'weak', 'volatile', 'acquire', 'release', or 'relaxed'." 
- scope = {"cta": 0, "gpu": 1, "sys": 2}[scope] - sem = {"weak": 0, "volatile": 1, "acquire": 2, "release": 3, "relaxed": 4}[sem] na = 1 if na else 0 nc = 1 if nc else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem, scope, na, - nc, src_pe) + return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem.value, + scope.value, na, nc, src_pe) def st( dst: PrimExpr, value: PrimExpr, - scope: Literal["cta", "gpu", "sys"] = "gpu", - sem: Literal["weak", "volatile", "release", "relaxed"] = "weak", + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.WEAK, na: bool = False, dst_pe: tir.PrimExpr | tir.IntImm | None = -1, ): @@ -796,16 +799,59 @@ def st( Returns: tir.Call: A handle to the store operation. """ - assert scope in ["cta", "gpu", "sys"], "Scope must be one of 'cta', 'gpu', or 'sys'." - assert sem in ["weak", "volatile", "release", "relaxed" - ], "Semantic must be one of 'weak', 'volatile', 'release', or 'relaxed'." - - # convert to int - scope = {"cta": 0, "gpu": 1, "sys": 2}[scope] - sem = {"weak": 0, "volatile": 1, "acquire": 2, "release": 3, "relaxed": 4}[sem] na = 1 if na else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem, scope, na, - dst_pe) + return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem.value, + scope.value, na, dst_pe) + + +def atom_add_remote( + dst: PrimExpr, + value: PrimExpr, + scope: MemoryScope = MemoryScope.SYSTEM, + sem: MemorySemantic = MemorySemantic.RELAXED, + dst_pe: tir.PrimExpr | tir.IntImm | None = -1, +): + """Perform a remote atomic add operation with return value support + + Args: + dst: The destination address to store the value to. + value: The value to store. + scope: The memory scope. + sem: The memory semantic. + dst_pe: The destination processing element (PE) identifier. + Use -1 (default) for local PE, or a non-negative integer to target a remote PE. 
+ + Returns: + tir.Call: Returns the old value before the atomic add (uint32). + """ + scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] + sem_str = { + MemorySemantic.RELAXED: "relaxed", + MemorySemantic.ACQUIRE: "acquire", + MemorySemantic.RELEASE: "release", + MemorySemantic.ACQ_REL: "acq_rel" + }[sem] + + # Build the intrinsic function name + func_name = f"tl::ptx_atom_add_{sem_str}_{scope_str}" + + # If dst_pe is specified and not -1, compute remote address + is_remote = not (isinstance(dst_pe, tir.IntImm) and dst_pe.value == -1) + + if is_remote: + # Compute remote address: remote_base_ptr(dst_pe) + (address_of(dst) - remote_base_ptr(get_rank())) + local_rank = tir.Call("int64", tir.op.Op.get("tl.get_rank"), []) + local_base_ptr = tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [local_rank]) + offset_to_base = tir.Sub( + tir.Call("handle", tir.op.Op.get("tl.get_uintptr_t"), [address_of(dst)]), + local_base_ptr) + remote_ptr = tir.Add( + tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [dst_pe]), offset_to_base) + # Call the PTX intrinsic directly with remote address + return tir.call_extern("uint32", func_name, remote_ptr, value) + else: + # Local atomic add + return tir.call_extern("uint32", func_name, address_of(dst), value) def elect_one_sync(): diff --git a/tilelang/language/distributed/common.py b/tilelang/language/distributed/common.py index adb559e92..b311bfcde 100644 --- a/tilelang/language/distributed/common.py +++ b/tilelang/language/distributed/common.py @@ -5,6 +5,7 @@ from tvm.tir import address_of from tvm.tir import PrimExpr, IntImm from enum import Enum +from tilelang.language.utils import MemoryScope, MemorySemantic def get_rank(): @@ -122,41 +123,122 @@ class BinaryRelation(Enum): LT = 5 -def wait_eq(barrier: PrimExpr, expected: PrimExpr): +def wait_eq(barrier: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = 
MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *barrier == expected* for GPU-level synchronization. # todo: have different semantic compared to 3 fns below currently Args: barrier: The barrier to wait at expected: The expected value to wait for + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait_eq"), address_of(barrier), expected) + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.EQ.value, + address_of(barrier), expected, peer, scope.value, semantic.value) -def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr != expected""" - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, - address_of(ptr), expected, peer) +def wait_ne(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): + """Wait until *ptr != expected + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 + """ + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, + address_of(ptr), expected, peer, scope.value, semantic.value) + + +def wait_ge(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): + """Wait until *ptr >= expected -def wait_ge(ptr: 
PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr >= expected""" - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, - address_of(ptr), expected, peer) + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address. + Must be "int32" or "uint32" (defaults to "int32"). + """ + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, + address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr <= expected""" - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LE.value, - address_of(ptr), expected, peer) +def wait_le(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): + """Wait until *ptr <= expected + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 + """ + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.LE.value, + address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr > expected""" - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, - address_of(ptr), expected, peer) +def wait_gt(ptr: PrimExpr, +
expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): + """Wait until *ptr > expected -def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr < expected""" - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, - address_of(ptr), expected, peer) + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 + """ + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, + address_of(ptr), expected, peer, scope.value, semantic.value) + + +def wait_lt(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): + """Wait until *ptr < expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 + """ + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, + address_of(ptr), expected, peer, scope.value, semantic.value) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 161a09c45..e0d90b354 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -4,6 +4,7 @@ from tvm import tir from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tilelang import language as T +from enum import Enum def 
region(buffer: BufferLoad, access_type: str, *args: PrimExpr): @@ -158,3 +159,37 @@ def linear_index(*args: PrimExpr) -> PrimExpr: for idx, stride in zip(coords[1:], strides): linear = linear * stride + idx return linear + + +class MemoryScope(Enum): + """Memory scope for wait operations. + + - CTA: Uses ld.volatile.cta (faster, suitable for same-CTA synchronization) + - CLUSTER: Uses ld.volatile.cluster (faster, suitable for same-cluster synchronization) + - GPU: Uses ld.volatile.gpu (faster, suitable for same-GPU synchronization) + - SYSTEM: Uses ld.acquire.sys (required for cross-PE/NUMA synchronization with st.release.sys) + """ + CTA = 0 + CLUSTER = 1 + GPU = 2 + SYSTEM = 3 + + +class MemorySemantic(Enum): + """Memory semantic for memory operations. + + - WEAK: Uses ld.weak (no synchronization) + - VOLATILE: Uses ld.volatile (faster, suitable for same-GPU synchronization) + - RELAXED: Uses ld.relaxed (faster, suitable for same-GPU synchronization) + - ACQUIRE: Uses ld.acquire (slower, suitable for cross-PE/NUMA synchronization with st.release) + - RELEASE: Uses st.release on stores (pairs with ld.acquire for cross-PE/NUMA synchronization; PTX has no ld.release) + - ACQ_REL: Acquire-release ordering for atomic read-modify-write operations (e.g. atom.acq_rel) + - SC: Sequentially consistent ordering (strongest; e.g. fence.sc) + """ + WEAK = 0 + VOLATILE = 1 + RELAXED = 2 + ACQUIRE = 3 + RELEASE = 4 + ACQ_REL = 5 + SC = 6