From c859a165ae47fd1e132693d2cb19b4d951b41c0e Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 30 Jan 2026 06:14:16 +0000 Subject: [PATCH 01/28] [Fence] Add fence options for barrier_blocks --- src/op/distributed.cc | 6 ++++-- src/op/sync.cc | 2 +- tilelang/language/builtin.py | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 84a23afa7..481c4b729 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -194,8 +194,10 @@ TIR_DEFINE_TL_BUILTIN(CpengineCpAsync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(get_rank).set_num_inputs(0).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_rank) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(get_num_ranks) .set_num_inputs(0) diff --git a/src/op/sync.cc b/src/op/sync.cc index 892fc2220..487f381c6 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -192,7 +192,7 @@ TileOperator WaitOpNode::Clone() const { } TIR_REGISTER_TL_OP(BarrierBlocksOp, barrier_blocks) - .set_num_inputs(1) + .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 8db638308..03efd4359 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -612,15 +612,16 @@ def sync_grid(barrier: PrimExpr): return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"), address_of(barrier)) -def barrier_blocks(barrier: PrimExpr): +def barrier_blocks(barrier: PrimExpr, need_fence: bool = True): """Barrier all blocks at a system-level barrier. Compare to sync_blocks, barrier_blocks have an extra system-level fence effect Args: barrier: The barrier to synchronize at, should be [num_ranks] of int32 + need_fence: Whether need fence. 
Default to True """ return tir.call_intrin("handle", tir.op.Op.get("tl.barrier_blocks"), address_of(barrier), - 1) # whether need fence + need_fence) # whether need fence def sync_blocks(barrier: PrimExpr): From cd4e50911f471e6616d3edf645123df814c9aa7c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 05:30:41 +0000 Subject: [PATCH 02/28] [Feature] Add remote atomic-add and more scopes/semantics for wait op --- .../intranode/example_alltoall_route2x4.py | 297 ++++++++++++++++++ src/op/remote_copy.cc | 69 ++++ src/op/remote_copy.h | 54 ++++ src/op/sync.cc | 6 +- src/op/sync.h | 8 +- src/tl_templates/cuda/atomic.h | 12 +- src/tl_templates/cuda/sync.h | 102 ++++-- tilelang/language/builtin.py | 49 +++ tilelang/language/distributed/common.py | 75 ++++- 9 files changed, 632 insertions(+), 40 deletions(-) create mode 100644 examples/distributed/intranode/example_alltoall_route2x4.py diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py new file mode 100644 index 000000000..31ba2b462 --- /dev/null +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -0,0 +1,297 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse +from pynvshmem import Amo +from enum import IntEnum + +tilelang.disable_cache() + +class Direction(IntEnum): + NORTH = 0 + SOUTH = 1 + WEST = 2 + EAST = 3 + SELF = 4 + + +@tilelang.jit(pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +},debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") +def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): + buffer_size = PE_num + + @T.prim_func + def main( + # For each (src, dst) pair, the real transfer size is M * N + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), 
"float16"), + # buffer[dir, slot, rank, *, *]: This PE save some slots for transferring data chunks for the real destination rank + buffer_direction: T.Tensor((4, buffer_size, PE_num, M, N), "float16"), + # Signal for each buffer slot in each direction + signal_direction: T.Tensor((4, buffer_size, PE_num), "uint32"), + # Signal for finish + local_finish: T.Tensor((1), "uint32"), + global_finish: T.Tensor((1), "uint32"), + # Barrier for all blocks + barrier: T.Tensor((PE_num), "int32"), + ): + with T.Kernel(PE_num, 4, threads=threads) as (bx, by): + tx = T.get_thread_binding() + + rank = T.alloc_local([1], "uint32") + rank_x = T.alloc_local([1], "uint32") + rank_y = T.alloc_local([1], "uint32") + next_rank = T.alloc_local([1], "uint32") + from_dir = T.alloc_local([1], "uint32") + to_dir = T.alloc_local([1], "uint32") + diff = T.alloc_local([1], "int32") + dst_rank = T.alloc_local([1], "uint32") + old_local = T.alloc_local([1], "uint32") + old_global = T.alloc_local([1], "uint32") + + # TODO maybe we do not need send/recv count, for each data to dst rank will pass current rank **at most** once + # Counter for sending data in each direction (cumulative across hops) + send_count = T.alloc_local([4], "uint32") + + # Counter for receiving/waiting data from each direction + recv_count = T.alloc_local([4], "uint32") + + rank[0] = T.get_rank() + rank_x[0] = T.floordiv(rank[0], Y) + rank_y[0] = T.floormod(rank[0], Y) + next_rank[0] = 0 + + for i in T.serial(4): + send_count[i] = 0 + recv_count[i] = 0 + + dst_rank[0] = bx + from_dir[0] = by + + # Prepare for routing + dst_rank_x = T.floordiv(dst_rank[0], Y) + dst_rank_y = T.floormod(dst_rank[0], Y) + to_dir[0] = Direction.SELF + # XY-routing: first route along X-axis, then Y-axis + if dst_rank_x != rank_x[0]: + # Calculate shortest path in Torus X-dimension + diff[0] = dst_rank_x - rank_x[0] + if diff[0] > T.floordiv(X, 2): + diff[0] -= X + elif diff[0] < -T.floordiv(X, 2): + diff[0] += X + + if diff[0] < 0: + # Send North 
(up): neighbor receives in its north buffer + to_dir[0] = Direction.NORTH + next_rank[0] = T.floormod(rank_x[0] + X - 1, X) * Y + rank_y[0] + else: + # Send South (down): neighbor receives in its south buffer + to_dir[0] = Direction.SOUTH + next_rank[0] = T.floormod(rank_x[0] + 1, X) * Y + rank_y[0] + elif dst_rank_y != rank_y[0]: + # Calculate shortest path in Torus Y-dimension + diff[0] = dst_rank_y - rank_y[0] + if diff[0] > T.floordiv(Y, 2): + diff[0] -= Y + elif diff[0] < -T.floordiv(Y, 2): + diff[0] += Y + + if diff[0] < 0: + # Send West (left): neighbor receives in its west buffer + to_dir[0] = Direction.WEST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + Y - 1, Y) + else: + # Send East (right): neighbor receives in its east buffer + to_dir[0] = Direction.EAST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) + + # Phase 1: Initial send from src to the target neighbor (choose one random 'by' is OK) + if dst_rank[0] != rank[0] and from_dir[0] == 0: + T.put_block( + T.address_of(src[dst_rank[0] * M, 0]), + # T.address_of(buffer_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0], 0, 0]), + T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), + M * N, + next_rank[0], + ) + if tx == 0: + T.st( + signal_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0]], + 1, + scope='sys', + sem="release", + dst_pe=next_rank[0] + ) + T.sync_threads() + send_count[to_dir[0]] += 1 + # if tx == 0 and rank[0] == 1: + # T.print(to_dir[0], msg="to_dir") + # T.print(send_count[to_dir[0]], msg="send_count") + + T.barrier_blocks(barrier) + + # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer + # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal + with T.While(global_finish[0] < PE_num): + # Check if already finished before waiting to avoid deadlock + # if global_finish[0] >= PE_num: + # T.loop_break() + + # if tx == 0 and rank[0] == 0: + # T.print(from_dir[0], 
msg="from_dir") + # T.print(recv_count[from_dir[0]], msg="recv_count") + # T.print(dst_rank[0], msg="wait signal for rank:") + # T.print(signal_direction[from_dir[0], 0, dst_rank[0]], msg="wait signal for rank:") + if tx == 0: + # T.wait_gt(signal_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0]], 0) + T.wait_gt(signal_direction[from_dir[0], 0, dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) + T.sync_threads() + + # if tx == 0 and rank[0] == 0: + # T.print(rank[0], "rank") + # T.print(dst_rank[0], "dst_rank") + # T.print(from_dir[0], "from_dir") + # T.print(to_dir[0], "to_dir") + + # if signal_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0]] == 1: + if signal_direction[from_dir[0], 0, dst_rank[0]] == 1: + # Only handle the transfer signal + if to_dir[0] != Direction.SELF: + T.put_block( + # T.address_of(buffer_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0], 0, 0]), + # T.address_of(buffer_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0], 0, 0]), + T.address_of(buffer_direction[from_dir[0], 0, dst_rank[0], 0, 0]), + T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), + M * N, + next_rank[0], + ) + if tx == 0: + T.st( + # signal_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0]], + signal_direction[to_dir[0], 0, dst_rank[0]], + 1, + scope="sys", + sem="release", + dst_pe=next_rank[0], + ) + # if tx == 0 and rank[0] == 1: + # T.print(next_rank[0], msg="send signal to next_rank") + # T.print(to_dir[0], msg="to_dir") + # T.print(send_count[to_dir[0]], msg="send_count") + # T.print(dst_rank[0], msg="dst_rank") + T.sync_threads() + send_count[to_dir[0]] += 1 + else: + # Current rank is the real destination of this chunk of data, the real source rank is the buffer index + # T.copy(buffer_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) + T.copy(buffer_direction[from_dir[0], 0, dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) + if tx == 0: + # Use the OLD value returned by 
atom_add to ensure only ONE block executes the global notification + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope="gpu", + sem="release", + ) + # T.print(old_local[0], "old_local") + if old_local[0] + 2 == PE_num: + # T.print(msg="Last chunk received, notifying all PEs") + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope="sys", + sem="release", + dst_pe=i, + ) + # T.print(old_global[0], "old_global (from last PE)") + if old_global[0] + 1 == PE_num: + # T.print(msg="This is the last PE! Sending termination signals to all PEs") + # Send termination signals to wake up all waiting blocks on all PEs + for remote_pe in T.serial(PE_num): + for direction in T.serial(4): + for slot in T.serial(buffer_size): + for dst_r in T.serial(PE_num): + T.st( + signal_direction[direction, slot, dst_r], + 2, + scope="sys", + sem="release", + dst_pe=remote_pe, + ) + # Fence to ensure all remote writes are visible to other PEs + # T.fence_sys() + T.sync_threads() + recv_count[from_dir[0]] += 1 + + return main + +def run_torus_alltoall(local_rank, num_ranks, args): + PE_num = args.PE_num + X, Y = args.X, args.Y + M, N = args.M, args.N + block_M, block_N = 16, N + threads = 128 + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**32, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads) + kernel.initialize(allocator=allocator) + + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_(0, 1) + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + + buffer_size = PE_num + buffer_direction = tilelang.tensor((4, buffer_size, PE_num, M, N), torch.float16, allocator=allocator).zero_() + + # Signals for each buffer slot in each direction + signal_direction = 
tilelang.tensor((4, buffer_size, PE_num), torch.uint32, allocator=allocator).fill_(0) + local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + + torch.cuda.synchronize() + dist.barrier(group_size) + + kernel(src, dst, buffer_direction, signal_direction, local_finish, global_finish, barrier) + + torch.cuda.synchronize() + dist.barrier(group_size) + + print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") + + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! ✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}") + + dist.destroy_process_group() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=128) + parser.add_argument("--N", type=int, default=128) + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--X", type=int, default=2) + parser.add_argument("--Y", type=int, default=4) + args = parser.parse_args() + + torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index fba501e48..dff1c98b1 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -382,6 +382,70 @@ TileOperator LdOpNode::Clone() const { return LdOp(node); } +AtomAddRemoteOp::AtomAddRemoteOp(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); + node->dst = args[0]; + ICHECK(node->dst.as()) << "dst must be a call node"; + ICHECK(node->dst.as()->op.same_as(builtin::address_of())) + << "dst must be address_of op"; + + 
node->value = args[1]; + node->sem = args[2].as().value()->value; + node->scope = args[3].as().value()->value; + node->dst_pe = args[4]; + data_ = std::move(node); + (void)vmap; +} + +bool AtomAddRemoteOpNode::is_distributed() const { + return !(dst_pe->IsInstance() && + dst_pe.as()->value == -1); +} + +Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + (void)analyzer; + (void)T; + Array new_args; + std::stringstream ss; + + // Map integers to semantic literal strings for PTX atom instruction + const char *sem_str[] = {"relaxed", "acquire", "release", "acq_rel"}; + const char *scope_str[] = {"gpu", "sys"}; + + // Build function name: tl::ptx_atom_add__ + ss << "tl::ptx_atom_add_" << sem_str[sem] << "_" << scope_str[scope]; + + new_args.push_back(StringImm(ss.str())); + if (is_distributed()) { + PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); + PrimExpr local_base_ptr = + Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + PrimExpr offset_to_base = Sub( + Call(DataType::Handle(), tl::get_uintptr_t(), {dst}), local_base_ptr); + new_args.push_back( + Call(DataType::Handle(), tl::get_remote_base_ptr(), {dst_pe}) + + offset_to_base); + } else { + new_args.push_back(dst); + } + new_args.push_back(value); + + auto atom_add = Call(DataType::UInt(32), builtin::call_extern(), new_args); + return Evaluate(atom_add); +} + +LayoutMap AtomAddRemoteOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + (void)T; + (void)level; + return {}; +} + +TileOperator AtomAddRemoteOpNode::Clone() const { + auto node = make_object(*this); + return AtomAddRemoteOp(node); +} + TIR_REGISTER_TL_OP(PutOp, put) .set_num_inputs(7) .set_attr("TCallEffectKind", @@ -398,10 +462,15 @@ TIR_REGISTER_TL_OP(StOp, st).set_num_inputs(6).set_attr( TIR_REGISTER_TL_OP(LdOp, ld).set_num_inputs(7).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_REGISTER_TL_OP(AtomAddRemoteOp, 
atom_add_remote).set_num_inputs(5).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TVM_FFI_STATIC_INIT_BLOCK({ PutOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ GetOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ StOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ LdOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK({ AtomAddRemoteOpNode::RegisterReflection(); }); + } // namespace tl } // namespace tvm diff --git a/src/op/remote_copy.h b/src/op/remote_copy.h index 3c118f33a..94645a89a 100644 --- a/src/op/remote_copy.h +++ b/src/op/remote_copy.h @@ -319,6 +319,60 @@ class LdOp : public TileOperator { static const Op &Get(); }; +class AtomAddRemoteOpNode : public TileOperatorNode { +public: + PrimExpr dst; ///< Destination address + PrimExpr value; ///< Value to atomically add + PrimExpr dst_pe; ///< Destination processing element (optional) + int scope; ///< Memory scope (0: GPU, 1: SYS) + int sem; ///< Memory semantic (0: relaxed, 1: acquire, 2: release, 3: acq_rel) + + bool is_distributed() const; + + static constexpr const char *_type_key = "tl.AtomAddRemoteOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(AtomAddRemoteOpNode, TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + static const Op &Get(); + TileOperator Clone() const override; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dst", &AtomAddRemoteOpNode::dst) + .def_ro("value", &AtomAddRemoteOpNode::value) + .def_ro("dst_pe", &AtomAddRemoteOpNode::dst_pe) + .def_ro("scope", &AtomAddRemoteOpNode::scope) + .def_ro("sem", &AtomAddRemoteOpNode::sem); + } + + bool SEqualReduce(const AtomAddRemoteOpNode *other, SEqualReducer equal) const { + return equal(dst, other->dst) && equal(value, other->value) && + equal(dst_pe, other->dst_pe) && scope == other->scope && 
+ sem == other->sem; + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dst); + hash_reduce(value); + hash_reduce(dst_pe); + hash_reduce(scope); + hash_reduce(sem); + } + + static constexpr bool _type_has_method_sequal_reduce = true; + static constexpr bool _type_has_method_shash_reduce = true; +}; + +class AtomAddRemoteOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(AtomAddRemoteOp, TileOperator, AtomAddRemoteOpNode); + TVM_DLL AtomAddRemoteOp(Array args, BufferMap vmap); + static const Op &Get(); +}; + } // namespace tl } // namespace tvm diff --git a/src/op/sync.cc b/src/op/sync.cc index 487f381c6..e7e1448ba 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -141,6 +141,8 @@ WaitOp::WaitOp(Array args, BufferMap vmap) { node->addr = args[1]; node->expected = args[2]; node->peer = args[3]; + // scope parameter is optional, default to SYSTEM (1) for safety + node->scope = (args.size() > 4) ? args[4].as()->value : 1; data_ = std::move(node); (void)vmap; } @@ -174,6 +176,8 @@ Stmt WaitOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { new_args.push_back(addr); } new_args.push_back(expected); + // Pass scope: 0=GPU (volatile), 1=SYSTEM (acquire.sys) + new_args.push_back(IntImm(DataType::Int(32), scope)); auto wait = Call(DataType::Handle(), builtin::call_extern(), new_args); return Evaluate(wait); @@ -197,7 +201,7 @@ TIR_REGISTER_TL_OP(BarrierBlocksOp, barrier_blocks) Integer(CallEffectKind::kOpaque)); TIR_REGISTER_TL_OP(WaitOp, wait) - .set_num_inputs(4) + .set_num_inputs(5) // relation, addr, expected, peer, scope (scope is optional, default=1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/sync.h b/src/op/sync.h index 16487877e..427749838 100644 --- a/src/op/sync.h +++ b/src/op/sync.h @@ -57,6 +57,7 @@ class WaitOpNode : public TileOperatorNode { PrimExpr expected; ///< The expected value to compare against. PrimExpr peer; ///< The peer to compare against. 
int relation; ///< The relation to compare against. + int scope; ///< Memory scope: 0=GPU (volatile), 1=SYSTEM (acquire.sys) bool is_distributed() const; @@ -75,12 +76,14 @@ class WaitOpNode : public TileOperatorNode { .def_ro("addr", &WaitOpNode::addr) .def_ro("expected", &WaitOpNode::expected) .def_ro("peer", &WaitOpNode::peer) - .def_ro("relation", &WaitOpNode::relation); + .def_ro("relation", &WaitOpNode::relation) + .def_ro("scope", &WaitOpNode::scope); } bool SEqualReduce(const WaitOpNode *other, SEqualReducer equal) const { return equal(addr, other->addr) && equal(expected, other->expected) && - equal(peer, other->peer) && equal(relation, other->relation); + equal(peer, other->peer) && relation == other->relation && + scope == other->scope; } void SHashReduce(SHashReducer hash_reduce) const { @@ -88,6 +91,7 @@ class WaitOpNode : public TileOperatorNode { hash_reduce(expected); hash_reduce(peer); hash_reduce(relation); + hash_reduce(scope); } static constexpr bool _type_has_method_sequal_reduce = true; diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h index fe4607020..7696c4163 100644 --- a/src/tl_templates/cuda/atomic.h +++ b/src/tl_templates/cuda/atomic.h @@ -241,36 +241,40 @@ TL_DEVICE uint32_t ptx_atom_add_acq_rel_gpu(const uint32_t *ptr, return ret; } -TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_relaxed_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.relaxed.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t ptx_atom_add_acquire_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_acquire_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.acquire.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t 
ptx_atom_add_release_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_release_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.release.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); return ret; } -TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(const uint32_t *ptr, +TL_DEVICE uint32_t ptx_atom_add_acq_rel_sys(unsigned long addr, uint32_t value) { uint32_t ret; + const uint32_t *ptr = reinterpret_cast(addr); asm volatile("atom.add.acq_rel.sys.global.u32 %0, [%1], %2;\n" : "=r"(ret) : "l"(ptr), "r"(value)); diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index cad94ee7e..298e97ab4 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -68,6 +68,12 @@ TL_DEVICE int ld_acquire(const int *ptr) { return ret; } +TL_DEVICE int ld_acquire_sys(const int *ptr) { + int ret = 0; + asm volatile("ld.global.acquire.sys.b32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + return ret; +} + // Initialize a GPU barrier template TL_DEVICE void init_barrier_gpu(uint32_t *barrier) { @@ -120,7 +126,7 @@ TL_DEVICE unsigned int sync_grids_arrive(uint32_t *barrier) { unsigned int expected = gridDim.x * gridDim.y * gridDim.z; unsigned int nb = 1; if (IS_MASTER_BLOCK()) { - nb = 0x80000000 - (expected - 1); + nb = BARRIER_MAGIC - (expected - 1); } asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;" : "=r"(oldArrive) @@ -140,7 +146,7 @@ TL_DEVICE void sync_grids_wait(unsigned int oldArrive, uint32_t *barrier) { : "=r"(current_arrive) : "l"((unsigned int *)barrier) : "memory"); - } while (!(((oldArrive ^ current_arrive) & 0x80000000) != 0)); + } while (!(((oldArrive ^ current_arrive) & BARRIER_MAGIC) != 0)); } __syncthreads(); } @@ -150,7 +156,7 @@ TL_DEVICE void sync_grid(uint32_t *barrier) { sync_grids_wait(token, barrier); } -// Sync blocks at a system-level barrier with an optinal fence +// Sync blocks at a system-level barrier with an optional 
fence // TODO(wt): Add timeout handling template @@ -184,6 +190,61 @@ TL_DEVICE void barrier_blocks(int offset, int rank, int num_ranks) { #undef FINISHED_SUM_TAG } +// Memory scope for wait operations +enum class WaitScope { + GPU = 0, // Use ld.volatile.global - suitable for same-GPU synchronization + SYSTEM = 1 // Use ld.acquire.sys - required for cross-PE/NUMA synchronization +}; + +// Load with volatile semantics (GPU scope, faster but no cross-PE guarantees) +template +TL_DEVICE T ld_wait_gpu(const T *ptr) { + if constexpr (std::is_same_v) { + return ld_volatile_global(ptr); + } else if constexpr (std::is_same_v || std::is_same_v) { + // Cast to int* for ld_volatile_global, then cast back + const int *int_ptr = reinterpret_cast(ptr); + return static_cast(ld_volatile_global(int_ptr)); + } else { + return *reinterpret_cast(ptr); + } +} + +// Load with acquire.sys semantics (SYSTEM scope, required for proper cross-PE sync) +template +TL_DEVICE T ld_wait_sys(const T *ptr) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + unsigned int ret = 0; + asm volatile("ld.global.acquire.sys.b32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + return static_cast(ret); + } else { + // Fallback to volatile for other types + return *reinterpret_cast(ptr); + } +} + +// Generic load dispatcher based on scope (runtime) +// Note: When scope is a compile-time constant, compiler should optimize away the branch +template + TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope) { + if (scope == WaitScope::SYSTEM) { + return ld_wait_sys(ptr); + } else { + return ld_wait_gpu(ptr); + } +} + +// Compile-time scope selection (better performance when scope is known at compile time) +template + TL_DEVICE T ld_wait_static(const T *ptr) { + if constexpr (scope == WaitScope::SYSTEM) { + return ld_wait_sys(ptr); + } else { + return ld_wait_gpu(ptr); + } +} + template TL_DEVICE void wait_eq(void *ptr, T val) { T *flag_ptr = reinterpret_cast(ptr); // Spin-loop @@ -192,53 
+253,58 @@ template TL_DEVICE void wait_eq(void *ptr, T val) { ; } -template TL_DEVICE void wait_ne(P ptr, T val) { +template +TL_DEVICE void wait_ne(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop +// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU #pragma unroll 1 - while (ld_volatile_global(flag_ptr) == val) + while (ld_wait_generic(flag_ptr, scope) == val) ; } -template TL_DEVICE void wait_ge(P ptr, T val) { +template +TL_DEVICE void wait_ge(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop +// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU #pragma unroll 1 - while (ld_volatile_global(flag_ptr) < val) + while (ld_wait_generic(flag_ptr, scope) < val) ; } -template TL_DEVICE void wait_le(P ptr, T val) { +template +TL_DEVICE void wait_le(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop +// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU #pragma unroll 1 - while (ld_volatile_global(flag_ptr) > val) + while (ld_wait_generic(flag_ptr, scope) > val) ; } -template TL_DEVICE void wait_gt(P ptr, T val) { +template +TL_DEVICE void wait_gt(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop +// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU #pragma unroll 1 - while (ld_volatile_global(flag_ptr) <= val) + while (ld_wait_generic(flag_ptr, scope) <= val) ; } -template TL_DEVICE void wait_lt(P ptr, T val) { +template +TL_DEVICE void wait_lt(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop +// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU #pragma unroll 1 - while (ld_volatile_global(flag_ptr) >= val) + while (ld_wait_generic(flag_ptr, scope) >= val) ; } diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 03efd4359..ec71c5d60 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -809,6 +809,55 @@ def st( dst_pe) +def atom_add_remote( + dst: PrimExpr, + value: PrimExpr, + scope: Literal["gpu", "sys"] = "sys", + sem: Literal["relaxed", "acquire", "release", "acq_rel"] = "relaxed", + dst_pe: tir.PrimExpr | tir.IntImm | None = -1, +): + """Perform a remote atomic add operation with return value support + + Args: + dst: The destination address to store the value to. + value: The value to store. + scope: The memory scope. + sem: The memory semantic. + dst_pe: The destination processing element (PE) identifier. + Use -1 (default) for local PE, or a non-negative integer to target a remote PE. + + Returns: + tir.Call: Returns the old value before the atomic add (uint32). + """ + assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." + assert sem in ["relaxed", "acquire", "release", "acq_rel" + ], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." 
+ + # Build the intrinsic function name + func_name = f"tl::ptx_atom_add_{sem}_{scope}" + + # If dst_pe is specified and not -1, compute remote address + is_remote = not (isinstance(dst_pe, tir.IntImm) and dst_pe.value == -1) + + if is_remote: + # Compute remote address: remote_base_ptr(dst_pe) + (address_of(dst) - remote_base_ptr(get_rank())) + local_rank = tir.Call("int64", tir.op.Op.get("tl.get_rank"), []) + local_base_ptr = tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [local_rank]) + offset_to_base = tir.Sub( + tir.Call("handle", tir.op.Op.get("tl.get_uintptr_t"), [address_of(dst)]), + local_base_ptr + ) + remote_ptr = tir.Add( + tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [dst_pe]), + offset_to_base + ) + # Call the PTX intrinsic directly with remote address + return tir.call_extern("uint32", func_name, remote_ptr, value) + else: + # Local atomic add + return tir.call_extern("uint32", func_name, address_of(dst), value) + + def elect_one_sync(): """Efficiently elect exactly one lane within a warp.""" return tir.call_intrin("bool", tir.op.Op.get("tl.elect_one_sync")) diff --git a/tilelang/language/distributed/common.py b/tilelang/language/distributed/common.py index adb559e92..2be420e9a 100644 --- a/tilelang/language/distributed/common.py +++ b/tilelang/language/distributed/common.py @@ -122,6 +122,16 @@ class BinaryRelation(Enum): LT = 5 +class WaitScope(Enum): + """Memory scope for wait operations. + + - GPU: Uses ld.volatile.global (faster, suitable for same-GPU synchronization) + - SYSTEM: Uses ld.acquire.sys (required for cross-PE/NUMA synchronization with st.release.sys) + """ + GPU = 0 + SYSTEM = 1 + + def wait_eq(barrier: PrimExpr, expected: PrimExpr): """Wait until *barrier == expected* for GPU-level synchronization. 
# todo: have different semantic compared to 3 fns below currently @@ -132,31 +142,66 @@ def wait_eq(barrier: PrimExpr, expected: PrimExpr): return tir.call_intrin("handle", tir.op.Op.get("tl.wait_eq"), address_of(barrier), expected) -def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr != expected""" +def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): + """Wait until *ptr != expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, - address_of(ptr), expected, peer) + address_of(ptr), expected, peer, scope.value) -def wait_ge(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr >= expected""" +def wait_ge(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): + """Wait until *ptr >= expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, - address_of(ptr), expected, peer) + address_of(ptr), expected, peer, scope.value) -def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr <= expected""" +def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): + """Wait until *ptr <= expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + """ return 
tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LE.value, - address_of(ptr), expected, peer) + address_of(ptr), expected, peer, scope.value) -def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr > expected""" +def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): + """Wait until *ptr > expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, - address_of(ptr), expected, peer) + address_of(ptr), expected, peer, scope.value) -def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1): - """Wait until *ptr < expected""" +def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): + """Wait until *ptr < expected + + Args: + ptr: The memory address to wait on + expected: The value to compare against + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, - address_of(ptr), expected, peer) + address_of(ptr), expected, peer, scope.value) From e37fea4ef123550d5690c90d15b07a37c6dec8b2 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 06:59:19 +0000 Subject: [PATCH 03/28] [Misc] Remove unused code --- .../intranode/example_alltoall_route2x4.py | 53 +------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 31ba2b462..0ad26fb9b 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ 
b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -52,22 +52,11 @@ def main( dst_rank = T.alloc_local([1], "uint32") old_local = T.alloc_local([1], "uint32") old_global = T.alloc_local([1], "uint32") - - # TODO maybe we do not need send/recv count, for each data to dst rank will pass current rank **at most** once - # Counter for sending data in each direction (cumulative across hops) - send_count = T.alloc_local([4], "uint32") - - # Counter for receiving/waiting data from each direction - recv_count = T.alloc_local([4], "uint32") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) rank_y[0] = T.floormod(rank[0], Y) next_rank[0] = 0 - - for i in T.serial(4): - send_count[i] = 0 - recv_count[i] = 0 dst_rank[0] = bx from_dir[0] = by @@ -114,57 +103,33 @@ def main( if dst_rank[0] != rank[0] and from_dir[0] == 0: T.put_block( T.address_of(src[dst_rank[0] * M, 0]), - # T.address_of(buffer_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0], 0, 0]), T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0]], + signal_direction[to_dir[0], 0, dst_rank[0]], 1, scope='sys', sem="release", dst_pe=next_rank[0] ) T.sync_threads() - send_count[to_dir[0]] += 1 - # if tx == 0 and rank[0] == 1: - # T.print(to_dir[0], msg="to_dir") - # T.print(send_count[to_dir[0]], msg="send_count") T.barrier_blocks(barrier) # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal with T.While(global_finish[0] < PE_num): - # Check if already finished before waiting to avoid deadlock - # if global_finish[0] >= PE_num: - # T.loop_break() - - # if tx == 0 and rank[0] == 0: - # T.print(from_dir[0], msg="from_dir") - # T.print(recv_count[from_dir[0]], msg="recv_count") - # T.print(dst_rank[0], msg="wait signal for rank:") - # 
T.print(signal_direction[from_dir[0], 0, dst_rank[0]], msg="wait signal for rank:") if tx == 0: - # T.wait_gt(signal_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0]], 0) T.wait_gt(signal_direction[from_dir[0], 0, dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) T.sync_threads() - # if tx == 0 and rank[0] == 0: - # T.print(rank[0], "rank") - # T.print(dst_rank[0], "dst_rank") - # T.print(from_dir[0], "from_dir") - # T.print(to_dir[0], "to_dir") - - # if signal_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0]] == 1: if signal_direction[from_dir[0], 0, dst_rank[0]] == 1: # Only handle the transfer signal if to_dir[0] != Direction.SELF: T.put_block( - # T.address_of(buffer_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0], 0, 0]), - # T.address_of(buffer_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0], 0, 0]), T.address_of(buffer_direction[from_dir[0], 0, dst_rank[0], 0, 0]), T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), M * N, @@ -172,35 +137,24 @@ def main( ) if tx == 0: T.st( - # signal_direction[to_dir[0], send_count[to_dir[0]], dst_rank[0]], signal_direction[to_dir[0], 0, dst_rank[0]], 1, scope="sys", sem="release", dst_pe=next_rank[0], ) - # if tx == 0 and rank[0] == 1: - # T.print(next_rank[0], msg="send signal to next_rank") - # T.print(to_dir[0], msg="to_dir") - # T.print(send_count[to_dir[0]], msg="send_count") - # T.print(dst_rank[0], msg="dst_rank") T.sync_threads() - send_count[to_dir[0]] += 1 else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - # T.copy(buffer_direction[from_dir[0], recv_count[from_dir[0]], dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) T.copy(buffer_direction[from_dir[0], 0, dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) if tx == 0: - # Use the OLD value returned by atom_add to ensure only ONE block executes the global notification old_local[0] = T.atom_add( local_finish[0], 1, scope="gpu", sem="release", ) - # 
T.print(old_local[0], "old_local") if old_local[0] + 2 == PE_num: - # T.print(msg="Last chunk received, notifying all PEs") for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], @@ -209,9 +163,7 @@ def main( sem="release", dst_pe=i, ) - # T.print(old_global[0], "old_global (from last PE)") if old_global[0] + 1 == PE_num: - # T.print(msg="This is the last PE! Sending termination signals to all PEs") # Send termination signals to wake up all waiting blocks on all PEs for remote_pe in T.serial(PE_num): for direction in T.serial(4): @@ -224,10 +176,7 @@ def main( sem="release", dst_pe=remote_pe, ) - # Fence to ensure all remote writes are visible to other PEs - # T.fence_sys() T.sync_threads() - recv_count[from_dir[0]] += 1 return main From 12a98d0d61e180d8f75c0c3f49125e270644cf97 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 07:07:35 +0000 Subject: [PATCH 04/28] [Example] Remove redundant buffer --- .../intranode/example_alltoall_route2x4.py | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 0ad26fb9b..eb106bd96 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -20,9 +20,8 @@ class Direction(IntEnum): @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -},debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") +}) def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): - buffer_size = PE_num @T.prim_func def main( @@ -30,9 +29,9 @@ def main( src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), # buffer[dir, slot, rank, *, *]: This PE save some slots for transferring data chunks for the real destination rank - 
buffer_direction: T.Tensor((4, buffer_size, PE_num, M, N), "float16"), + buffer_direction: T.Tensor((4, PE_num, M, N), "float16"), # Signal for each buffer slot in each direction - signal_direction: T.Tensor((4, buffer_size, PE_num), "uint32"), + signal_direction: T.Tensor((4, PE_num), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), @@ -103,13 +102,13 @@ def main( if dst_rank[0] != rank[0] and from_dir[0] == 0: T.put_block( T.address_of(src[dst_rank[0] * M, 0]), - T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), + T.address_of(buffer_direction[to_dir[0], dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[to_dir[0], 0, dst_rank[0]], + signal_direction[to_dir[0], dst_rank[0]], 1, scope='sys', sem="release", @@ -123,21 +122,21 @@ def main( # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal with T.While(global_finish[0] < PE_num): if tx == 0: - T.wait_gt(signal_direction[from_dir[0], 0, dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) + T.wait_gt(signal_direction[from_dir[0], dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) T.sync_threads() - if signal_direction[from_dir[0], 0, dst_rank[0]] == 1: + if signal_direction[from_dir[0], dst_rank[0]] == 1: # Only handle the transfer signal if to_dir[0] != Direction.SELF: T.put_block( - T.address_of(buffer_direction[from_dir[0], 0, dst_rank[0], 0, 0]), - T.address_of(buffer_direction[to_dir[0], 0, dst_rank[0], 0, 0]), + T.address_of(buffer_direction[from_dir[0], dst_rank[0], 0, 0]), + T.address_of(buffer_direction[to_dir[0], dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[to_dir[0], 0, dst_rank[0]], + signal_direction[to_dir[0], dst_rank[0]], 1, scope="sys", sem="release", @@ -146,7 +145,7 @@ def main( T.sync_threads() else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.copy(buffer_direction[from_dir[0], 0, dst_rank[0], 
0, 0], dst[dst_rank[0] * M, 0]) + T.copy(buffer_direction[from_dir[0], dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) if tx == 0: old_local[0] = T.atom_add( local_finish[0], @@ -167,15 +166,14 @@ def main( # Send termination signals to wake up all waiting blocks on all PEs for remote_pe in T.serial(PE_num): for direction in T.serial(4): - for slot in T.serial(buffer_size): - for dst_r in T.serial(PE_num): - T.st( - signal_direction[direction, slot, dst_r], - 2, - scope="sys", - sem="release", - dst_pe=remote_pe, - ) + for dst_r in T.serial(PE_num): + T.st( + signal_direction[direction, dst_r], + 2, + scope="sys", + sem="release", + dst_pe=remote_pe, + ) T.sync_threads() return main @@ -203,11 +201,10 @@ def run_torus_alltoall(local_rank, num_ranks, args): src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_(0, 1) dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() - buffer_size = PE_num - buffer_direction = tilelang.tensor((4, buffer_size, PE_num, M, N), torch.float16, allocator=allocator).zero_() + buffer_direction = tilelang.tensor((4, PE_num, M, N), torch.float16, allocator=allocator).zero_() # Signals for each buffer slot in each direction - signal_direction = tilelang.tensor((4, buffer_size, PE_num), torch.uint32, allocator=allocator).fill_(0) + signal_direction = tilelang.tensor((4, PE_num), torch.uint32, allocator=allocator).fill_(0) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() From 231dad1495a9da8ae0c5b72c0b4d1b57540d9a47 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 07:11:31 +0000 Subject: [PATCH 05/28] [Example] Remove direction-related buffer --- .../intranode/example_alltoall_route2x4.py | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git 
a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index eb106bd96..89e15aec5 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -20,7 +20,7 @@ class Direction(IntEnum): @tilelang.jit(pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -}) +},debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): @T.prim_func @@ -28,10 +28,10 @@ def main( # For each (src, dst) pair, the real transfer size is M * N src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), - # buffer[dir, slot, rank, *, *]: This PE save some slots for transferring data chunks for the real destination rank - buffer_direction: T.Tensor((4, PE_num, M, N), "float16"), - # Signal for each buffer slot in each direction - signal_direction: T.Tensor((4, PE_num), "uint32"), + # buffer[rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank + buffer_direction: T.Tensor((PE_num, M, N), "float16"), + # Signal for each buffer + signal_direction: T.Tensor((PE_num), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), @@ -102,13 +102,13 @@ def main( if dst_rank[0] != rank[0] and from_dir[0] == 0: T.put_block( T.address_of(src[dst_rank[0] * M, 0]), - T.address_of(buffer_direction[to_dir[0], dst_rank[0], 0, 0]), + T.address_of(buffer_direction[dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[to_dir[0], dst_rank[0]], + signal_direction[dst_rank[0]], 1, scope='sys', sem="release", @@ -122,21 +122,21 @@ def main( # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal with T.While(global_finish[0] < PE_num): if tx == 0: - 
T.wait_gt(signal_direction[from_dir[0], dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) + T.wait_gt(signal_direction[dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) T.sync_threads() - if signal_direction[from_dir[0], dst_rank[0]] == 1: + if signal_direction[dst_rank[0]] == 1: # Only handle the transfer signal if to_dir[0] != Direction.SELF: T.put_block( - T.address_of(buffer_direction[from_dir[0], dst_rank[0], 0, 0]), - T.address_of(buffer_direction[to_dir[0], dst_rank[0], 0, 0]), + T.address_of(buffer_direction[dst_rank[0], 0, 0]), + T.address_of(buffer_direction[dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[to_dir[0], dst_rank[0]], + signal_direction[dst_rank[0]], 1, scope="sys", sem="release", @@ -145,7 +145,7 @@ def main( T.sync_threads() else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.copy(buffer_direction[from_dir[0], dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) + T.copy(buffer_direction[dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) if tx == 0: old_local[0] = T.atom_add( local_finish[0], @@ -165,15 +165,14 @@ def main( if old_global[0] + 1 == PE_num: # Send termination signals to wake up all waiting blocks on all PEs for remote_pe in T.serial(PE_num): - for direction in T.serial(4): - for dst_r in T.serial(PE_num): - T.st( - signal_direction[direction, dst_r], - 2, - scope="sys", - sem="release", - dst_pe=remote_pe, - ) + for dst_r in T.serial(PE_num): + T.st( + signal_direction[dst_r], + 2, + scope="sys", + sem="release", + dst_pe=remote_pe, + ) T.sync_threads() return main @@ -201,10 +200,10 @@ def run_torus_alltoall(local_rank, num_ranks, args): src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_(0, 1) dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() - buffer_direction = tilelang.tensor((4, PE_num, M, N), torch.float16, allocator=allocator).zero_() + buffer_direction = tilelang.tensor((PE_num, M, N), 
torch.float16, allocator=allocator).zero_() # Signals for each buffer slot in each direction - signal_direction = tilelang.tensor((4, PE_num), torch.uint32, allocator=allocator).fill_(0) + signal_direction = tilelang.tensor((PE_num), torch.uint32, allocator=allocator).fill_(0) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() From b496a5415d0115a340263d9ff6fa119c14b7f396 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 08:28:16 +0000 Subject: [PATCH 06/28] [Refactor] Unified scope and semantic representation in tilescale language system --- .../intranode/example_alltoall_route2x4.py | 57 ++++----- src/op/remote_copy.cc | 25 ++-- src/op/sync.cc | 20 +-- src/op/sync.h | 16 +-- src/target/codegen_cuda.cc | 4 - src/tl_templates/cuda/ldst.h | 2 +- src/tl_templates/cuda/sync.h | 119 +++++++++++------- src/transform/warp_specialized_rewriter.cc | 2 +- tilelang/language/builtin.py | 48 +++---- tilelang/language/distributed/common.py | 68 ++++++---- tilelang/language/utils.py | 34 ++++- 11 files changed, 231 insertions(+), 164 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 89e15aec5..a00022923 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist import argparse -from pynvshmem import Amo from enum import IntEnum tilelang.disable_cache() @@ -29,23 +28,22 @@ def main( src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), # buffer[rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank - buffer_direction: T.Tensor((PE_num, M, N), "float16"), + buffer_transfer: 
T.Tensor((PE_num, M, N), "float16"), # Signal for each buffer - signal_direction: T.Tensor((PE_num), "uint32"), + signal_transfer: T.Tensor((PE_num), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), # Barrier for all blocks barrier: T.Tensor((PE_num), "int32"), ): - with T.Kernel(PE_num, 4, threads=threads) as (bx, by): + with T.Kernel(PE_num, threads=threads) as (bx): tx = T.get_thread_binding() rank = T.alloc_local([1], "uint32") rank_x = T.alloc_local([1], "uint32") rank_y = T.alloc_local([1], "uint32") next_rank = T.alloc_local([1], "uint32") - from_dir = T.alloc_local([1], "uint32") to_dir = T.alloc_local([1], "uint32") diff = T.alloc_local([1], "int32") dst_rank = T.alloc_local([1], "uint32") @@ -58,7 +56,6 @@ def main( next_rank[0] = 0 dst_rank[0] = bx - from_dir[0] = by # Prepare for routing dst_rank_x = T.floordiv(dst_rank[0], Y) @@ -98,20 +95,20 @@ def main( to_dir[0] = Direction.EAST next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) - # Phase 1: Initial send from src to the target neighbor (choose one random 'by' is OK) - if dst_rank[0] != rank[0] and from_dir[0] == 0: + # Phase 1: Initial send from src to the target neighbor + if dst_rank[0] != rank[0]: T.put_block( T.address_of(src[dst_rank[0] * M, 0]), - T.address_of(buffer_direction[dst_rank[0], 0, 0]), + T.address_of(buffer_transfer[dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[dst_rank[0]], + signal_transfer[dst_rank[0]], 1, - scope='sys', - sem="release", + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0] ) T.sync_threads() @@ -122,44 +119,44 @@ def main( # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal with T.While(global_finish[0] < PE_num): if tx == 0: - T.wait_gt(signal_direction[dst_rank[0]], 0, scope=T.WaitScope.SYSTEM) + T.wait_gt(signal_transfer[dst_rank[0]], 0, scope=T.MemoryScope.SYSTEM) T.sync_threads() - if 
signal_direction[dst_rank[0]] == 1: + if signal_transfer[dst_rank[0]] == 1: # Only handle the transfer signal if to_dir[0] != Direction.SELF: T.put_block( - T.address_of(buffer_direction[dst_rank[0], 0, 0]), - T.address_of(buffer_direction[dst_rank[0], 0, 0]), + T.address_of(buffer_transfer[dst_rank[0], 0, 0]), + T.address_of(buffer_transfer[dst_rank[0], 0, 0]), M * N, next_rank[0], ) if tx == 0: T.st( - signal_direction[dst_rank[0]], + signal_transfer[dst_rank[0]], 1, - scope="sys", - sem="release", + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) T.sync_threads() else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.copy(buffer_direction[dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) + T.copy(buffer_transfer[dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) if tx == 0: old_local[0] = T.atom_add( local_finish[0], 1, - scope="gpu", - sem="release", + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, ) if old_local[0] + 2 == PE_num: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], 1, - scope="sys", - sem="release", + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, dst_pe=i, ) if old_global[0] + 1 == PE_num: @@ -167,10 +164,10 @@ def main( for remote_pe in T.serial(PE_num): for dst_r in T.serial(PE_num): T.st( - signal_direction[dst_r], + signal_transfer[dst_r], 2, - scope="sys", - sem="release", + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, dst_pe=remote_pe, ) T.sync_threads() @@ -200,10 +197,10 @@ def run_torus_alltoall(local_rank, num_ranks, args): src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_(0, 1) dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() - buffer_direction = tilelang.tensor((PE_num, M, N), torch.float16, allocator=allocator).zero_() + buffer_transfer = tilelang.tensor((PE_num, M, N), torch.float16, allocator=allocator).zero_() # 
Signals for each buffer slot in each direction - signal_direction = tilelang.tensor((PE_num), torch.uint32, allocator=allocator).fill_(0) + signal_transfer = tilelang.tensor((PE_num), torch.uint32, allocator=allocator).fill_(0) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() @@ -211,7 +208,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): torch.cuda.synchronize() dist.barrier(group_size) - kernel(src, dst, buffer_direction, signal_direction, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) torch.cuda.synchronize() dist.barrier(group_size) diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index dff1c98b1..75063139a 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -273,10 +273,14 @@ Stmt StOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; // Map integers to enum literal strings + // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", - "Semantic::ACQUIRE", "Semantic::RELEASE", - "Semantic::RELAXED"}; - const char *scope_str[] = {"Scope::CTA", "Scope::GPU", "Scope::SYS"}; + "Semantic::RELAXED", "Semantic::ACQUIRE", + "Semantic::RELEASE", "Semantic::ACQ_REL"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", "Scope::SYS"}; + + ICHECK_LT(sem, 6); + ICHECK_LT(scope, 4); // Build function name: tl::st ss << "tl::st<" << sem_str[sem] << ", " << scope_str[scope] << ", " @@ -342,10 +346,14 @@ Stmt LdOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; // Map integers to enum literal strings + // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL const char *sem_str[] = 
{"Semantic::WEAK", "Semantic::VOLATILE", - "Semantic::ACQUIRE", "Semantic::RELEASE", - "Semantic::RELAXED"}; - const char *scope_str[] = {"Scope::CTA", "Scope::GPU", "Scope::SYS"}; + "Semantic::RELAXED", "Semantic::ACQUIRE", + "Semantic::RELEASE", "Semantic::ACQ_REL"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", "Scope::SYS"}; + + ICHECK_LT(sem, 6); + ICHECK_LT(scope, 4); // Build function name: tl::ld ss << "tl::ld<" << sem_str[sem] << ", " << scope_str[scope] << ", " @@ -409,8 +417,9 @@ Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) c std::stringstream ss; // Map integers to semantic literal strings for PTX atom instruction - const char *sem_str[] = {"relaxed", "acquire", "release", "acq_rel"}; - const char *scope_str[] = {"gpu", "sys"}; + // Unified Mapping: 2: relaxed, 3: acquire, 4: release, 5: acq_rel + const char *sem_str[] = {"weak", "volatile", "relaxed", "acquire", "release", "acq_rel"}; + const char *scope_str[] = {"cta", "cluster", "gpu", "sys"}; // Unified: 2: gpu, 3: system (mapped below) // Build function name: tl::ptx_atom_add__ ss << "tl::ptx_atom_add_" << sem_str[sem] << "_" << scope_str[scope]; diff --git a/src/op/sync.cc b/src/op/sync.cc index e7e1448ba..a40f7f6de 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -51,9 +51,6 @@ TIR_DEFINE_TL_BUILTIN(wait_barrier_gpu) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(wait_eq).set_num_inputs(2).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); - TIR_DEFINE_TL_BUILTIN(sync_barrier_gpu) .set_num_inputs(1) .set_attr("TCallEffectKind", @@ -137,12 +134,17 @@ PrimExpr BarrierBlocksOpNode::MakeLocalBarAddr(const LowerArgs &T) const { WaitOp::WaitOp(Array args, BufferMap vmap) { ObjectPtr node = make_object(); - node->relation = args[0].as()->value; + ICHECK_GE(args.size(), 4); + const auto *relation_node = args[0].as(); + ICHECK(relation_node) << "Wait relation must be an integer"; + 
node->relation = relation_node->value; node->addr = args[1]; node->expected = args[2]; node->peer = args[3]; - // scope parameter is optional, default to SYSTEM (1) for safety - node->scope = (args.size() > 4) ? args[4].as()->value : 1; + // scope parameter is optional, default to SYSTEM (3) for safety + node->scope = (args.size() > 4) ? args[4].as()->value : 3; + // semantic parameter is optional, default to ACQUIRE (3) for safety + node->semantic = (args.size() > 5) ? args[5].as()->value : 3; data_ = std::move(node); (void)vmap; } @@ -176,8 +178,10 @@ Stmt WaitOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { new_args.push_back(addr); } new_args.push_back(expected); - // Pass scope: 0=GPU (volatile), 1=SYSTEM (acquire.sys) + // Pass scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM new_args.push_back(IntImm(DataType::Int(32), scope)); + // Pass semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, 5=ACQ_REL + new_args.push_back(IntImm(DataType::Int(32), semantic)); auto wait = Call(DataType::Handle(), builtin::call_extern(), new_args); return Evaluate(wait); @@ -201,7 +205,7 @@ TIR_REGISTER_TL_OP(BarrierBlocksOp, barrier_blocks) Integer(CallEffectKind::kOpaque)); TIR_REGISTER_TL_OP(WaitOp, wait) - .set_num_inputs(5) // relation, addr, expected, peer, scope (scope is optional, default=1) + .set_num_inputs(6) // relation, addr, expected, peer, scope, semantic .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/sync.h b/src/op/sync.h index 427749838..82d15eeaa 100644 --- a/src/op/sync.h +++ b/src/op/sync.h @@ -38,13 +38,6 @@ TVM_DLL const Op &arrive_barrier_gpu(); */ TVM_DLL const Op &wait_barrier_gpu(); -/*! - * \brief Wait until *addr == expected* for GPU-level synchronization - * void wait_eq(addr, expected) - */ - -TVM_DLL const Op &wait_eq(); - /*! * \brief TileOperatorNode for wait operation. 
* @@ -57,7 +50,8 @@ class WaitOpNode : public TileOperatorNode { PrimExpr expected; ///< The expected value to compare against. PrimExpr peer; ///< The peer to compare against. int relation; ///< The relation to compare against. - int scope; ///< Memory scope: 0=GPU (volatile), 1=SYSTEM (acquire.sys) + int scope; ///< Memory scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM + int semantic; ///< Memory semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, 5=ACQ_REL bool is_distributed() const; @@ -77,13 +71,14 @@ class WaitOpNode : public TileOperatorNode { .def_ro("expected", &WaitOpNode::expected) .def_ro("peer", &WaitOpNode::peer) .def_ro("relation", &WaitOpNode::relation) - .def_ro("scope", &WaitOpNode::scope); + .def_ro("scope", &WaitOpNode::scope) + .def_ro("semantic", &WaitOpNode::semantic); } bool SEqualReduce(const WaitOpNode *other, SEqualReducer equal) const { return equal(addr, other->addr) && equal(expected, other->expected) && equal(peer, other->peer) && relation == other->relation && - scope == other->scope; + scope == other->scope && semantic == other->semantic; } void SHashReduce(SHashReducer hash_reduce) const { @@ -92,6 +87,7 @@ class WaitOpNode : public TileOperatorNode { hash_reduce(peer); hash_reduce(relation); hash_reduce(scope); + hash_reduce(semantic); } static constexpr bool _type_has_method_sequal_reduce = true; diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index b39e9b042..b80fb1c09 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1507,10 +1507,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::sync_grid())) { this->PrintIndent(); this->stream << "tl::sync_grid(" << this->PrintExpr(op->args[0]) << ");\n"; - } else if (op->op.same_as(tl::wait_eq())) { - this->PrintIndent(); - this->stream << "tl::wait_eq(" << this->PrintExpr(op->args[0]) << ", " - << this->PrintExpr(op->args[1]) << ");\n"; } else if 
(op->op.same_as(tl::atom_add())) { std::string func_name = "tl::ptx_atom_add_" + op->args[2].as()->value + "_" + diff --git a/src/tl_templates/cuda/ldst.h b/src/tl_templates/cuda/ldst.h index c875832eb..5bccbcb90 100644 --- a/src/tl_templates/cuda/ldst.h +++ b/src/tl_templates/cuda/ldst.h @@ -4,7 +4,7 @@ // Memory semantic and scope enums enum class Semantic { WEAK, VOLATILE, ACQUIRE, RELEASE, RELAXED }; -enum class Scope { CTA, GPU, SYS }; +enum class Scope { CTA, CLUSTER, GPU, SYS }; #ifndef TL_ALWAYS_FALSE_V_DEFINED #define TL_ALWAYS_FALSE_V_DEFINED diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 298e97ab4..0bbb0a729 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -12,6 +12,22 @@ namespace tl { +enum class SyncScope { + CTA = 0, + CLUSTER = 1, + GPU = 2, + SYSTEM = 3 +}; + +enum class SyncSemantic { + WEAK = 0, + VOLATILE = 1, + RELAXED = 2, + ACQUIRE = 3, + RELEASE = 4, + ACQ_REL = 5 +}; + // Triggers a GPU trap for debugging TL_DEVICE void trap() { asm("trap;\n"); } @@ -68,12 +84,6 @@ TL_DEVICE int ld_acquire(const int *ptr) { return ret; } -TL_DEVICE int ld_acquire_sys(const int *ptr) { - int ret = 0; - asm volatile("ld.global.acquire.sys.b32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); - return ret; -} - // Initialize a GPU barrier template TL_DEVICE void init_barrier_gpu(uint32_t *barrier) { @@ -190,21 +200,35 @@ TL_DEVICE void barrier_blocks(int offset, int rank, int num_ranks) { #undef FINISHED_SUM_TAG } -// Memory scope for wait operations -enum class WaitScope { - GPU = 0, // Use ld.volatile.global - suitable for same-GPU synchronization - SYSTEM = 1 // Use ld.acquire.sys - required for cross-PE/NUMA synchronization -}; +using WaitScope = SyncScope; +using WaitSemantic = SyncSemantic; // Load with volatile semantics (GPU scope, faster but no cross-PE guarantees) template -TL_DEVICE T ld_wait_gpu(const T *ptr) { +TL_DEVICE T ld_wait_gpu(const T *ptr, WaitSemantic semantic) { + int ret = 0; if 
constexpr (std::is_same_v) { - return ld_volatile_global(ptr); + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } + return ret; } else if constexpr (std::is_same_v || std::is_same_v) { // Cast to int* for ld_volatile_global, then cast back const int *int_ptr = reinterpret_cast(ptr); - return static_cast(ld_volatile_global(int_ptr)); + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + } + return static_cast(ret); } else { return *reinterpret_cast(ptr); } @@ -212,11 +236,18 @@ TL_DEVICE T ld_wait_gpu(const T *ptr) { // Load with acquire.sys semantics (SYSTEM scope, required for proper cross-PE sync) template -TL_DEVICE T ld_wait_sys(const T *ptr) { +TL_DEVICE T ld_wait_sys(const T *ptr, WaitSemantic semantic) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { unsigned int ret = 0; - asm volatile("ld.global.acquire.sys.b32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + if (semantic == WaitSemantic::RELAXED) { + asm volatile("ld.global.relaxed.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } else if (semantic == WaitSemantic::VOLATILE) { + asm volatile("ld.global.volatile.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } else { + // Default to acquire + asm volatile("ld.global.acquire.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + } return static_cast(ret); } else { // Fallback to volatile for other types @@ 
-224,87 +255,79 @@ TL_DEVICE T ld_wait_sys(const T *ptr) { } } -// Generic load dispatcher based on scope (runtime) -// Note: When scope is a compile-time constant, compiler should optimize away the branch +// Generic load dispatcher based on scope and semantic template - TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope) { +TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, WaitSemantic semantic = WaitSemantic::ACQUIRE) { if (scope == WaitScope::SYSTEM) { - return ld_wait_sys(ptr); - } else { - return ld_wait_gpu(ptr); - } -} - -// Compile-time scope selection (better performance when scope is known at compile time) -template - TL_DEVICE T ld_wait_static(const T *ptr) { - if constexpr (scope == WaitScope::SYSTEM) { - return ld_wait_sys(ptr); + return ld_wait_sys(ptr, semantic); } else { - return ld_wait_gpu(ptr); + return ld_wait_gpu(ptr, semantic); } } -template TL_DEVICE void wait_eq(void *ptr, T val) { +template +TL_DEVICE void wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { + static_assert(std::is_same_v || std::is_pointer_v

, + "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_acquire(flag_ptr) != val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) != val) ; } template -TL_DEVICE void wait_ne(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { +TL_DEVICE void wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU +// Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, scope) == val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) == val) ; } template -TL_DEVICE void wait_ge(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { +TL_DEVICE void wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU +// Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, scope) < val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) < val) ; } template -TL_DEVICE void wait_le(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { +TL_DEVICE void wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU +// Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, scope) > val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) > val) ; } template -TL_DEVICE void wait_gt(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { +TL_DEVICE void wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU +// Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, scope) <= val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) <= val) ; } template -TL_DEVICE void wait_lt(P ptr, T val, WaitScope scope = WaitScope::SYSTEM) { +TL_DEVICE void wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); -// Spin-loop: use SYSTEM scope (acquire.sys) for cross-PE sync, GPU scope (volatile) for same-GPU +// Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, scope) >= val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) >= val) ; } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 104b46c79..8ff89f773 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -138,7 +138,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { role = Role::kProducer; has_bulk_copy_ = true; } - if (call->op.same_as(loop_break()) || call->op.same_as(wait_eq())) + if (call->op.same_as(loop_break()) || call->op.same_as(WaitOp::Get())) role = Role::kBoth; if (call->op.same_as(get_clock())) role = Role::kBoth; diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index ec71c5d60..1b157c650 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -3,6 +3,7 @@ from tilelang import tvm as tvm from tilelang.language import ptx_arrive_barrier, evaluate, address_of +from tilelang.language.utils import MemoryScope, MemorySemantic from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability from tvm import tir @@ -729,21 +730,20 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) -def atom_add(barrier: PrimExpr, value: PrimExpr, scope: str = "gpu", sem: str = "relaxed"): +def atom_add(barrier: PrimExpr, value: PrimExpr, scope: MemoryScope = MemoryScope.GPU, sem: MemorySemantic = MemorySemantic.RELAXED): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """ - assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." 
- assert sem in ["relaxed", "acquire", "release", "acq_rel" - ], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." - return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem, - scope) + scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] + sem_str = {MemorySemantic.RELAXED: "relaxed", MemorySemantic.ACQUIRE: "acquire", MemorySemantic.RELEASE: "release", MemorySemantic.ACQ_REL: "acq_rel"}[sem] + return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem_str, + scope_str) def ld( src: PrimExpr, value: PrimExpr, - scope: Literal["cta", "gpu", "sys"] = "gpu", - sem: Literal["weak", "volatile", "acquire", "release", "relaxed"] = "weak", + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.WEAK, na: bool = False, nc: bool = False, src_pe: tir.PrimExpr | tir.IntImm | None = -1, @@ -763,23 +763,17 @@ def ld( Returns: tir.Call: A handle to the load operation. """ - assert scope in ["cta", "gpu", "sys"], "Scope must be one of 'cta', 'gpu', or 'sys'." - assert sem in [ - "weak", "volatile", "acquire", "relaxed" - ], "Semantic must be one of 'weak', 'volatile', 'acquire', 'release', or 'relaxed'." 
- scope = {"cta": 0, "gpu": 1, "sys": 2}[scope] - sem = {"weak": 0, "volatile": 1, "acquire": 2, "release": 3, "relaxed": 4}[sem] na = 1 if na else 0 nc = 1 if nc else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem, scope, na, + return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem.value, scope.value, na, nc, src_pe) def st( dst: PrimExpr, value: PrimExpr, - scope: Literal["cta", "gpu", "sys"] = "gpu", - sem: Literal["weak", "volatile", "release", "relaxed"] = "weak", + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.WEAK, na: bool = False, dst_pe: tir.PrimExpr | tir.IntImm | None = -1, ): @@ -797,23 +791,16 @@ def st( Returns: tir.Call: A handle to the store operation. """ - assert scope in ["cta", "gpu", "sys"], "Scope must be one of 'cta', 'gpu', or 'sys'." - assert sem in ["weak", "volatile", "release", "relaxed" - ], "Semantic must be one of 'weak', 'volatile', 'release', or 'relaxed'." - - # convert to int - scope = {"cta": 0, "gpu": 1, "sys": 2}[scope] - sem = {"weak": 0, "volatile": 1, "acquire": 2, "release": 3, "relaxed": 4}[sem] na = 1 if na else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem, scope, na, + return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem.value, scope.value, na, dst_pe) def atom_add_remote( dst: PrimExpr, value: PrimExpr, - scope: Literal["gpu", "sys"] = "sys", - sem: Literal["relaxed", "acquire", "release", "acq_rel"] = "relaxed", + scope: MemoryScope = MemoryScope.SYSTEM, + sem: MemorySemantic = MemorySemantic.RELAXED, dst_pe: tir.PrimExpr | tir.IntImm | None = -1, ): """Perform a remote atomic add operation with return value support @@ -829,12 +816,11 @@ def atom_add_remote( Returns: tir.Call: Returns the old value before the atomic add (uint32). """ - assert scope in ["gpu", "sys"], "Scope must be one of 'gpu', or 'sys'." 
- assert sem in ["relaxed", "acquire", "release", "acq_rel" - ], "Semantic must be one of 'relaxed', 'acquire', 'release', or 'acq_rel'." + scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] + sem_str = {MemorySemantic.RELAXED: "relaxed", MemorySemantic.ACQUIRE: "acquire", MemorySemantic.RELEASE: "release", MemorySemantic.ACQ_REL: "acq_rel"}[sem] # Build the intrinsic function name - func_name = f"tl::ptx_atom_add_{sem}_{scope}" + func_name = f"tl::ptx_atom_add_{sem_str}_{scope_str}" # If dst_pe is specified and not -1, compute remote address is_remote = not (isinstance(dst_pe, tir.IntImm) and dst_pe.value == -1) diff --git a/tilelang/language/distributed/common.py b/tilelang/language/distributed/common.py index 2be420e9a..543bc4cb7 100644 --- a/tilelang/language/distributed/common.py +++ b/tilelang/language/distributed/common.py @@ -5,6 +5,7 @@ from tvm.tir import address_of from tvm.tir import PrimExpr, IntImm from enum import Enum +from tilelang.language.utils import MemoryScope, MemorySemantic def get_rank(): @@ -122,27 +123,29 @@ class BinaryRelation(Enum): LT = 5 -class WaitScope(Enum): - """Memory scope for wait operations. - - - GPU: Uses ld.volatile.global (faster, suitable for same-GPU synchronization) - - SYSTEM: Uses ld.acquire.sys (required for cross-PE/NUMA synchronization with st.release.sys) - """ - GPU = 0 - SYSTEM = 1 - - -def wait_eq(barrier: PrimExpr, expected: PrimExpr): +def wait_eq(barrier: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *barrier == expected* for GPU-level synchronization. 
# todo: have different semantic compared to 3 fns below currently Args: barrier: The barrier to wait at expected: The expected value to wait for + peer: The PE to wait on (-1 for local) + scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait_eq"), address_of(barrier), expected) + return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.EQ.value, + address_of(barrier), expected, peer, scope.value, semantic.value) -def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): +def wait_ne(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr != expected Args: @@ -150,12 +153,17 @@ def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope expected: The value to compare against peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, - address_of(ptr), expected, peer, scope.value) + address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_ge(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): +def wait_ge(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr >= expected Args: @@ -163,12 +171,17 @@ def wait_ge(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope expected: The value to compare against peer: The PE to wait on (-1 for local) scope: Memory scope 
(GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, - address_of(ptr), expected, peer, scope.value) + address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): +def wait_le(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr <= expected Args: @@ -176,12 +189,17 @@ def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope expected: The value to compare against peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LE.value, - address_of(ptr), expected, peer, scope.value) + address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): +def wait_gt(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr > expected Args: @@ -189,12 +207,17 @@ def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope expected: The value to compare against peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, - address_of(ptr), expected, peer, scope.value) + 
address_of(ptr), expected, peer, scope.value, semantic.value) -def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: WaitScope = WaitScope.SYSTEM): +def wait_lt(ptr: PrimExpr, + expected: PrimExpr, + peer: PrimExpr | None = -1, + scope: MemoryScope = MemoryScope.SYSTEM, + semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr < expected Args: @@ -202,6 +225,7 @@ def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope expected: The value to compare against peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) + semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) """ return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, - address_of(ptr), expected, peer, scope.value) + address_of(ptr), expected, peer, scope.value, semantic.value) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 161a09c45..979226d08 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -4,7 +4,7 @@ from tvm import tir from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tilelang import language as T - +from enum import Enum def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): """ @@ -158,3 +158,35 @@ def linear_index(*args: PrimExpr) -> PrimExpr: for idx, stride in zip(coords[1:], strides): linear = linear * stride + idx return linear + + +class MemoryScope(Enum): + """Memory scope for wait operations. 
+ + - CTA: Uses ld.volatile.cta (faster, suitable for same-CTA synchronization) + - CLUSTER: Uses ld.volatile.cluster (faster, suitable for same-cluster synchronization) + - GPU: Uses ld.volatile.gpu (faster, suitable for same-GPU synchronization) + - SYSTEM: Uses ld.acquire.sys (required for cross-PE/NUMA synchronization with st.release.sys) + """ + CTA = 0 + CLUSTER = 1 + GPU = 2 + SYSTEM = 3 + + +class MemorySemantic(Enum): + """Memory semantic for memory operations. + + - WEAK: Uses ld.weak (no synchronization) + - VOLATILE: Uses ld.volatile (faster, suitable for same-GPU synchronization) + - RELAXED: Uses ld.relaxed (faster, suitable for same-GPU synchronization) + - ACQUIRE: Uses ld.acquire (slower, suitable for cross-PE/NUMA synchronization with st.release) + - RELEASE: Uses ld.release (slower, suitable for cross-PE/NUMA synchronization with st.acquire) + - ACQ_REL: Uses ld.acq_rel (both acquire and release semantics) + """ + WEAK = 0 + VOLATILE = 1 + RELAXED = 2 + ACQUIRE = 3 + RELEASE = 4 + ACQ_REL = 5 From 6f072ab9f9ae5e404f08c03e326f5c00da9162f9 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 09:23:02 +0000 Subject: [PATCH 07/28] [Misc] Add fence options --- src/op/sync.cc | 6 ++-- src/target/codegen_cuda.cc | 6 ++-- src/tl_templates/cuda/sync.h | 62 +++++++++++++++++++++++++++++------- tilelang/language/builtin.py | 12 +++---- tilelang/language/utils.py | 1 + 5 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/op/sync.cc b/src/op/sync.cc index a40f7f6de..7fd5c725d 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -209,13 +209,13 @@ TIR_REGISTER_TL_OP(WaitOp, wait) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_cta).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_cta).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_gpu).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_gpu).set_num_inputs(1).set_attr( 
"TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(fence_sys).set_num_inputs(0).set_attr( +TIR_DEFINE_TL_BUILTIN(fence_sys).set_num_inputs(1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TVM_FFI_STATIC_INIT_BLOCK({ BarrierBlocksOpNode::RegisterReflection(); }); diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index b80fb1c09..333048b4b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2103,13 +2103,13 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "nvshmem_barrier_all()"; } else if (op->op.same_as(tl::fence_cta())) { this->use_distributed_ = true; - os << "tl::memory_fence_cta()"; + os << "tl::memory_fence_cta(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::fence_gpu())) { this->use_distributed_ = true; - os << "tl::memory_fence_gpu()"; + os << "tl::memory_fence_gpu(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::fence_sys())) { this->use_distributed_ = true; - os << "tl::memory_fence_sys()"; + os << "tl::memory_fence_sys(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::reinterpret())) { DataType tgt_dtype = op->dtype; DataType src_dtype = op->args[0]->dtype; diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 0bbb0a729..8f485488a 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -25,25 +25,65 @@ enum class SyncSemantic { RELAXED = 2, ACQUIRE = 3, RELEASE = 4, - ACQ_REL = 5 + ACQ_REL = 5, + SC = 6 }; // Triggers a GPU trap for debugging TL_DEVICE void trap() { asm("trap;\n"); } // CTA-level memory fence -TL_DEVICE void memory_fence_cta() { - asm volatile("fence.acq_rel.cta;\n" ::: "memory"); +TL_DEVICE void memory_fence_cta(int sem) { + switch (sem) { + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm 
volatile("fence.release.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.cta;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.cta;\n" ::: "memory"); + break; + } } // GPU-level memory fence -TL_DEVICE void memory_fence_gpu() { - asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); +TL_DEVICE void memory_fence_gpu(int sem) { + switch (sem) { + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.gpu;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); + break; + } } // System-level memory fence -TL_DEVICE void memory_fence_sys() { - asm volatile("fence.acq_rel.sys;\n" ::: "memory"); +TL_DEVICE void memory_fence_sys(int sem) { + switch (sem) { + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.sys;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.sys;\n" ::: "memory"); + break; + } } // GPU-level load with acquire semantics @@ -90,12 +130,12 @@ TL_DEVICE void init_barrier_gpu(uint32_t *barrier) { if (IS_MASTER_BLOCK() && IS_MASTER_THREAD()) { *barrier = BARRIER_MAGIC - kExpected; } - memory_fence_gpu(); // TODO: Is fence or sync needed here? + memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); // TODO: Is fence or sync needed here? 
} // Arrive at a GPU barrier (atomic increment) TL_DEVICE void arrive_barrier_gpu(uint32_t *barrier) { - memory_fence_gpu(); + memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); if (IS_MASTER_THREAD()) { atomic_add_release_gpu_u32(barrier, 1); } @@ -114,7 +154,7 @@ TL_DEVICE void wait_barrier_gpu(uint32_t *barrier) { // Synchronize at a GPU barrier (arrive + wait) TL_DEVICE void sync_barrier_gpu(uint32_t *barrier) { - // memory_fence_gpu(); + memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); __syncthreads(); if (IS_MASTER_THREAD()) { atomic_add_release_gpu_u32(barrier, 1); @@ -177,7 +217,7 @@ TL_DEVICE void barrier_blocks(int offset, int rank, int num_ranks) { #define FINISHED_SUM_TAG (1024) if constexpr (need_fence) { - memory_fence_sys(); + memory_fence_sys(static_cast(SyncSemantic::ACQ_REL)); __syncthreads(); } diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 1b157c650..922526e48 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -635,19 +635,19 @@ def sync_blocks(barrier: PrimExpr): 0) # whether need fence -def fence_cta(): +def fence_cta(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Create a memory fence at the block level (visible to all threads in the current block).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_cta")) + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_cta"), sem.value) -def fence_gpu(): +def fence_gpu(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Synchronize all threads at the GPU level (visible to all blocks on the current device).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_gpu")) + return tir.call_intrin("handle", tir.op.Op.get("tl.fence_gpu"), sem.value) -def fence_sys(): +def fence_sys(sem: MemorySemantic = MemorySemantic.ACQ_REL): """Synchronize all threads at the system level (visible in a node).""" - return tir.call_intrin("handle", tir.op.Op.get("tl.fence_sys")) + return tir.call_intrin("handle", 
tir.op.Op.get("tl.fence_sys"), sem.value) def get_clock(): diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 979226d08..8e0170d3c 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -190,3 +190,4 @@ class MemorySemantic(Enum): ACQUIRE = 3 RELEASE = 4 ACQ_REL = 5 + SC = 6 From 268f54a2bcd45490edb0b855f7b7e56c34c10557 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 11:44:39 +0000 Subject: [PATCH 08/28] [BugFix] Intermediate buffer for each path --- .../intranode/example_alltoall_route2x4.py | 108 +++++++++++------- 1 file changed, 64 insertions(+), 44 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index a00022923..a87649b44 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -27,17 +27,17 @@ def main( # For each (src, dst) pair, the real transfer size is M * N src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), - # buffer[rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank - buffer_transfer: T.Tensor((PE_num, M, N), "float16"), + # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank + buffer_transfer: T.Tensor((PE_num, PE_num, M, N), "float16"), # Signal for each buffer - signal_transfer: T.Tensor((PE_num), "uint32"), + signal_transfer: T.Tensor((PE_num, PE_num), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), # Barrier for all blocks barrier: T.Tensor((PE_num), "int32"), ): - with T.Kernel(PE_num, threads=threads) as (bx): + with T.Kernel(PE_num, PE_num, T.ceildiv(M, block_M), threads=threads) as (bx, by, bz): tx = T.get_thread_binding() rank = T.alloc_local([1], "uint32") @@ -46,17 +46,20 @@ def main( next_rank = 
T.alloc_local([1], "uint32") to_dir = T.alloc_local([1], "uint32") diff = T.alloc_local([1], "int32") + src_rank = T.alloc_local([1], "uint32") dst_rank = T.alloc_local([1], "uint32") old_local = T.alloc_local([1], "uint32") old_global = T.alloc_local([1], "uint32") + num_block_M = T.alloc_local([1], "uint32") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) rank_y[0] = T.floormod(rank[0], Y) - next_rank[0] = 0 + next_rank[0] = rank[0] - dst_rank[0] = bx - + num_block_M[0] = T.ceildiv(M, block_M) + src_rank[0] = bx + dst_rank[0] = by # Prepare for routing dst_rank_x = T.floordiv(dst_rank[0], Y) dst_rank_y = T.floormod(dst_rank[0], Y) @@ -96,45 +99,54 @@ def main( next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) # Phase 1: Initial send from src to the target neighbor - if dst_rank[0] != rank[0]: - T.put_block( - T.address_of(src[dst_rank[0] * M, 0]), - T.address_of(buffer_transfer[dst_rank[0], 0, 0]), - M * N, - next_rank[0], - ) - if tx == 0: - T.st( - signal_transfer[dst_rank[0]], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0] + if src_rank[0] == rank[0]: + if dst_rank[0] != rank[0]: + T.put_block( + T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), + T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M, 0]), + # T.address_of(dst[rank[0] * M + by * block_M, 0]), + block_M * N, + next_rank[0], + ) + if tx == 0: + T.st( + signal_transfer[rank[0], dst_rank[0]], + rank[0], + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0] + ) + T.sync_threads() + else: + T.put_block( + T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), + T.address_of(dst[rank[0] * M + bz * block_M, 0]), + block_M * N, + -1, ) - T.sync_threads() T.barrier_blocks(barrier) # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer - # Signal values: 0 = no signal, 1 = data ready, 2 = termination signal + # Signal values: represent 
the src_rank with T.While(global_finish[0] < PE_num): if tx == 0: - T.wait_gt(signal_transfer[dst_rank[0]], 0, scope=T.MemoryScope.SYSTEM) + T.wait_le(signal_transfer[bx, dst_rank[0]], PE_num, scope=T.MemoryScope.SYSTEM) T.sync_threads() - if signal_transfer[dst_rank[0]] == 1: + if signal_transfer[bx, dst_rank[0]] < PE_num: # Only handle the transfer signal if to_dir[0] != Direction.SELF: T.put_block( - T.address_of(buffer_transfer[dst_rank[0], 0, 0]), - T.address_of(buffer_transfer[dst_rank[0], 0, 0]), - M * N, + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), + block_M * N, next_rank[0], ) if tx == 0: T.st( - signal_transfer[dst_rank[0]], - 1, + signal_transfer[bx, dst_rank[0]], + bx, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], @@ -142,7 +154,12 @@ def main( T.sync_threads() else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.copy(buffer_transfer[dst_rank[0], 0, 0], dst[dst_rank[0] * M, 0]) + T.put_block( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), + T.address_of(dst[bx * M + bz * block_M, 0]), + block_M * N, + -1, + ) if tx == 0: old_local[0] = T.atom_add( local_finish[0], @@ -150,7 +167,7 @@ def main( scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - if old_local[0] + 2 == PE_num: + if old_local[0] + 2 == PE_num * num_block_M[0]: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], @@ -162,23 +179,26 @@ def main( if old_global[0] + 1 == PE_num: # Send termination signals to wake up all waiting blocks on all PEs for remote_pe in T.serial(PE_num): - for dst_r in T.serial(PE_num): - T.st( - signal_transfer[dst_r], - 2, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) + for src_rank in T.serial(PE_num): + for dst_rank in T.serial(PE_num): + T.st( + signal_transfer[src_rank, dst_rank], + PE_num, 
+ scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) T.sync_threads() + T.barrier_blocks(barrier) + return main def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = 16, N + block_M, block_N = M, N threads = 128 local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) @@ -194,13 +214,13 @@ def run_torus_alltoall(local_rank, num_ranks, args): kernel = torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads) kernel.initialize(allocator=allocator) - src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_(0, 1) + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() - buffer_transfer = tilelang.tensor((PE_num, M, N), torch.float16, allocator=allocator).zero_() + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, allocator=allocator).zero_() # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num), torch.uint32, allocator=allocator).fill_(0) + signal_transfer = tilelang.tensor((PE_num, PE_num), torch.uint32, allocator=allocator).fill_(PE_num + 1) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() From e8d036ac48681d5d35de625f6ad68e3bed0da090 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 11:52:25 +0000 Subject: [PATCH 09/28] [Lint] Block_M for alltoall --- .../intranode/example_alltoall_route2x4.py | 75 ++++---- src/op/distributed.cc | 6 +- src/op/remote_copy.cc | 27 ++- src/op/remote_copy.h | 8 +- src/op/sync.cc | 5 +- src/op/sync.h | 3 +- src/tl_templates/cuda/sync.h | 179 ++++++++++-------- tilelang/language/builtin.py | 46 +++-- 
tilelang/language/utils.py | 1 + 9 files changed, 203 insertions(+), 147 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index a87649b44..c0e813690 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -8,6 +8,7 @@ tilelang.disable_cache() + class Direction(IntEnum): NORTH = 0 SOUTH = 1 @@ -16,30 +17,33 @@ class Direction(IntEnum): SELF = 4 -@tilelang.jit(pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -},debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") +) def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): @T.prim_func def main( - # For each (src, dst) pair, the real transfer size is M * N - src: T.Tensor((PE_num * M, N), "float16"), - dst: T.Tensor((PE_num * M, N), "float16"), - # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank - buffer_transfer: T.Tensor((PE_num, PE_num, M, N), "float16"), - # Signal for each buffer - signal_transfer: T.Tensor((PE_num, PE_num), "uint32"), - # Signal for finish - local_finish: T.Tensor((1), "uint32"), - global_finish: T.Tensor((1), "uint32"), - # Barrier for all blocks - barrier: T.Tensor((PE_num), "int32"), + # For each (src, dst) pair, the real transfer size is M * N + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank + buffer_transfer: T.Tensor((PE_num, PE_num, M, N), 
"float16"), + # Signal for each buffer + signal_transfer: T.Tensor((PE_num, PE_num), "uint32"), + # Signal for finish + local_finish: T.Tensor((1), "uint32"), + global_finish: T.Tensor((1), "uint32"), + # Barrier for all blocks + barrier: T.Tensor((PE_num), "int32"), ): with T.Kernel(PE_num, PE_num, T.ceildiv(M, block_M), threads=threads) as (bx, by, bz): tx = T.get_thread_binding() - + rank = T.alloc_local([1], "uint32") rank_x = T.alloc_local([1], "uint32") rank_y = T.alloc_local([1], "uint32") @@ -72,7 +76,7 @@ def main( diff[0] -= X elif diff[0] < -T.floordiv(X, 2): diff[0] += X - + if diff[0] < 0: # Send North (up): neighbor receives in its north buffer to_dir[0] = Direction.NORTH @@ -88,7 +92,7 @@ def main( diff[0] -= Y elif diff[0] < -T.floordiv(Y, 2): diff[0] += Y - + if diff[0] < 0: # Send West (left): neighbor receives in its west buffer to_dir[0] = Direction.WEST @@ -114,8 +118,7 @@ def main( rank[0], scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0] - ) + dst_pe=next_rank[0]) T.sync_threads() else: T.put_block( @@ -126,7 +129,7 @@ def main( ) T.barrier_blocks(barrier) - + # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer # Signal values: represent the src_rank with T.While(global_finish[0] < PE_num): @@ -168,7 +171,7 @@ def main( sem=T.MemorySemantic.RELEASE, ) if old_local[0] + 2 == PE_num * num_block_M[0]: - for i in T.serial(PE_num): + for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], 1, @@ -194,6 +197,7 @@ def main( return main + def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y @@ -210,31 +214,33 @@ def run_torus_alltoall(local_rank, num_ranks, args): num_local_ranks=num_ranks, group=group_size, ) - + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads) kernel.initialize(allocator=allocator) - + src = tilelang.tensor((PE_num * M, N), torch.float16, 
allocator=allocator).random_() dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() - - buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, allocator=allocator).zero_() - + + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, + allocator=allocator).zero_() + # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num, PE_num), torch.uint32, allocator=allocator).fill_(PE_num + 1) + signal_transfer = tilelang.tensor((PE_num, PE_num), torch.uint32, + allocator=allocator).fill_(PE_num + 1) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() - + torch.cuda.synchronize() dist.barrier(group_size) - + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) torch.cuda.synchronize() dist.barrier(group_size) - + print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") - + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") dist.all_to_all_single(dst_ref, src, group=group_size) torch.cuda.synchronize() @@ -247,6 +253,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): dist.destroy_process_group() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--M", type=int, default=128) @@ -255,5 +262,5 @@ def run_torus_alltoall(local_rank, num_ranks, args): parser.add_argument("--X", type=int, default=2) parser.add_argument("--Y", type=int, default=4) args = parser.parse_args() - + torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 481c4b729..84a23afa7 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -194,10 +194,8 @@ TIR_DEFINE_TL_BUILTIN(CpengineCpAsync) 
.set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(get_rank) - .set_num_inputs(0) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_rank).set_num_inputs(0).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(get_num_ranks) .set_num_inputs(0) diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index 75063139a..7f0291f0f 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -274,10 +274,11 @@ Stmt StOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { // Map integers to enum literal strings // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL - const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", + const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", "Semantic::RELAXED", "Semantic::ACQUIRE", "Semantic::RELEASE", "Semantic::ACQ_REL"}; - const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", "Scope::SYS"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", + "Scope::SYS"}; ICHECK_LT(sem, 6); ICHECK_LT(scope, 4); @@ -347,10 +348,11 @@ Stmt LdOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { // Map integers to enum literal strings // 0: WEAK, 1: VOLATILE, 2: RELAXED, 3: ACQUIRE, 4: RELEASE, 5: ACQ_REL - const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", + const char *sem_str[] = {"Semantic::WEAK", "Semantic::VOLATILE", "Semantic::RELAXED", "Semantic::ACQUIRE", "Semantic::RELEASE", "Semantic::ACQ_REL"}; - const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", "Scope::SYS"}; + const char *scope_str[] = {"Scope::CTA", "Scope::CLUSTER", "Scope::GPU", + "Scope::SYS"}; ICHECK_LT(sem, 6); ICHECK_LT(scope, 4); @@ -410,7 +412,8 @@ bool AtomAddRemoteOpNode::is_distributed() const { dst_pe.as()->value == -1); } -Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { 
+Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { (void)analyzer; (void)T; Array new_args; @@ -418,8 +421,11 @@ Stmt AtomAddRemoteOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) c // Map integers to semantic literal strings for PTX atom instruction // Unified Mapping: 2: relaxed, 3: acquire, 4: release, 5: acq_rel - const char *sem_str[] = {"weak", "volatile", "relaxed", "acquire", "release", "acq_rel"}; - const char *scope_str[] = {"cta", "cluster", "gpu", "sys"}; // Unified: 2: gpu, 3: system (mapped below) + const char *sem_str[] = {"weak", "volatile", "relaxed", + "acquire", "release", "acq_rel"}; + const char *scope_str[] = { + "cta", "cluster", "gpu", + "sys"}; // Unified: 2: gpu, 3: system (mapped below) // Build function name: tl::ptx_atom_add__ ss << "tl::ptx_atom_add_" << sem_str[sem] << "_" << scope_str[scope]; @@ -471,8 +477,10 @@ TIR_REGISTER_TL_OP(StOp, st).set_num_inputs(6).set_attr( TIR_REGISTER_TL_OP(LdOp, ld).set_num_inputs(7).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_REGISTER_TL_OP(AtomAddRemoteOp, atom_add_remote).set_num_inputs(5).set_attr( - "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_REGISTER_TL_OP(AtomAddRemoteOp, atom_add_remote) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); TVM_FFI_STATIC_INIT_BLOCK({ PutOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ GetOpNode::RegisterReflection(); }); @@ -480,6 +488,5 @@ TVM_FFI_STATIC_INIT_BLOCK({ StOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ LdOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK({ AtomAddRemoteOpNode::RegisterReflection(); }); - } // namespace tl } // namespace tvm diff --git a/src/op/remote_copy.h b/src/op/remote_copy.h index 94645a89a..d93bb45e2 100644 --- a/src/op/remote_copy.h +++ b/src/op/remote_copy.h @@ -325,7 +325,7 @@ class AtomAddRemoteOpNode : public TileOperatorNode { PrimExpr value; ///< Value to 
atomically add PrimExpr dst_pe; ///< Destination processing element (optional) int scope; ///< Memory scope (0: GPU, 1: SYS) - int sem; ///< Memory semantic (0: relaxed, 1: acquire, 2: release, 3: acq_rel) + int sem; ///< Memory semantic (0: relaxed, 1: acquire, 2: release, 3: acq_rel) bool is_distributed() const; @@ -348,7 +348,8 @@ class AtomAddRemoteOpNode : public TileOperatorNode { .def_ro("sem", &AtomAddRemoteOpNode::sem); } - bool SEqualReduce(const AtomAddRemoteOpNode *other, SEqualReducer equal) const { + bool SEqualReduce(const AtomAddRemoteOpNode *other, + SEqualReducer equal) const { return equal(dst, other->dst) && equal(value, other->value) && equal(dst_pe, other->dst_pe) && scope == other->scope && sem == other->sem; @@ -368,7 +369,8 @@ class AtomAddRemoteOpNode : public TileOperatorNode { class AtomAddRemoteOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(AtomAddRemoteOp, TileOperator, AtomAddRemoteOpNode); + TVM_DEFINE_OBJECT_REF_METHODS(AtomAddRemoteOp, TileOperator, + AtomAddRemoteOpNode); TVM_DLL AtomAddRemoteOp(Array args, BufferMap vmap); static const Op &Get(); }; diff --git a/src/op/sync.cc b/src/op/sync.cc index 7fd5c725d..12b13f9ab 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -180,7 +180,8 @@ Stmt WaitOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { new_args.push_back(expected); // Pass scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM new_args.push_back(IntImm(DataType::Int(32), scope)); - // Pass semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, 5=ACQ_REL + // Pass semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, + // 5=ACQ_REL new_args.push_back(IntImm(DataType::Int(32), semantic)); auto wait = Call(DataType::Handle(), builtin::call_extern(), new_args); @@ -205,7 +206,7 @@ TIR_REGISTER_TL_OP(BarrierBlocksOp, barrier_blocks) Integer(CallEffectKind::kOpaque)); TIR_REGISTER_TL_OP(WaitOp, wait) - .set_num_inputs(6) // relation, addr, expected, peer, scope, semantic + 
.set_num_inputs(6) // relation, addr, expected, peer, scope, semantic .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/sync.h b/src/op/sync.h index 82d15eeaa..93656cf30 100644 --- a/src/op/sync.h +++ b/src/op/sync.h @@ -51,7 +51,8 @@ class WaitOpNode : public TileOperatorNode { PrimExpr peer; ///< The peer to compare against. int relation; ///< The relation to compare against. int scope; ///< Memory scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM - int semantic; ///< Memory semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, 5=ACQ_REL + int semantic; ///< Memory semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, + ///< 4=RELEASE, 5=ACQ_REL bool is_distributed() const; diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 8f485488a..0a6801b5b 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -12,12 +12,7 @@ namespace tl { -enum class SyncScope { - CTA = 0, - CLUSTER = 1, - GPU = 2, - SYSTEM = 3 -}; +enum class SyncScope { CTA = 0, CLUSTER = 1, GPU = 2, SYSTEM = 3 }; enum class SyncSemantic { WEAK = 0, @@ -35,54 +30,54 @@ TL_DEVICE void trap() { asm("trap;\n"); } // CTA-level memory fence TL_DEVICE void memory_fence_cta(int sem) { switch (sem) { - case static_cast(SyncSemantic::ACQUIRE): - asm volatile("fence.acquire.cta;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::RELEASE): - asm volatile("fence.release.cta;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::SC): - asm volatile("fence.sc.cta;\n" ::: "memory"); - break; - default: - asm volatile("fence.acq_rel.cta;\n" ::: "memory"); - break; + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.cta;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.cta;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.cta;\n" ::: 
"memory"); + break; } } // GPU-level memory fence TL_DEVICE void memory_fence_gpu(int sem) { switch (sem) { - case static_cast(SyncSemantic::ACQUIRE): - asm volatile("fence.acquire.gpu;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::RELEASE): - asm volatile("fence.release.gpu;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::SC): - asm volatile("fence.sc.gpu;\n" ::: "memory"); - break; - default: - asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); - break; + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.gpu;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.gpu;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.gpu;\n" ::: "memory"); + break; } } // System-level memory fence TL_DEVICE void memory_fence_sys(int sem) { switch (sem) { - case static_cast(SyncSemantic::ACQUIRE): - asm volatile("fence.acquire.sys;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::RELEASE): - asm volatile("fence.release.sys;\n" ::: "memory"); - break; - case static_cast(SyncSemantic::SC): - asm volatile("fence.sc.sys;\n" ::: "memory"); - break; - default: - asm volatile("fence.acq_rel.sys;\n" ::: "memory"); - break; + case static_cast(SyncSemantic::ACQUIRE): + asm volatile("fence.acquire.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::RELEASE): + asm volatile("fence.release.sys;\n" ::: "memory"); + break; + case static_cast(SyncSemantic::SC): + asm volatile("fence.sc.sys;\n" ::: "memory"); + break; + default: + asm volatile("fence.acq_rel.sys;\n" ::: "memory"); + break; } } @@ -130,7 +125,8 @@ TL_DEVICE void init_barrier_gpu(uint32_t *barrier) { if (IS_MASTER_BLOCK() && IS_MASTER_THREAD()) { *barrier = BARRIER_MAGIC - kExpected; } - memory_fence_gpu(static_cast(SyncSemantic::ACQ_REL)); // TODO: Is fence or sync needed here? 
+ memory_fence_gpu(static_cast( + SyncSemantic::ACQ_REL)); // TODO: Is fence or sync needed here? } // Arrive at a GPU barrier (atomic increment) @@ -249,55 +245,76 @@ TL_DEVICE T ld_wait_gpu(const T *ptr, WaitSemantic semantic) { int ret = 0; if constexpr (std::is_same_v) { if (semantic == WaitSemantic::RELAXED) { - asm volatile("ld.global.relaxed.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.relaxed.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } else if (semantic == WaitSemantic::VOLATILE) { - asm volatile("ld.global.volatile.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.volatile.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } else { // Default to acquire - asm volatile("ld.global.acquire.gpu.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.acquire.gpu.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } return ret; - } else if constexpr (std::is_same_v || std::is_same_v) { + } else if constexpr (std::is_same_v || + std::is_same_v) { // Cast to int* for ld_volatile_global, then cast back const int *int_ptr = reinterpret_cast(ptr); if (semantic == WaitSemantic::RELAXED) { - asm volatile("ld.global.relaxed.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + asm volatile("ld.global.relaxed.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); } else if (semantic == WaitSemantic::VOLATILE) { - asm volatile("ld.global.volatile.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + asm volatile("ld.global.volatile.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); } else { // Default to acquire - asm volatile("ld.global.acquire.gpu.u32 %0, [%1];\n" : "=r"(ret) : "l"(int_ptr)); + asm volatile("ld.global.acquire.gpu.u32 %0, [%1];\n" + : "=r"(ret) + : "l"(int_ptr)); } return static_cast(ret); } else { - return *reinterpret_cast(ptr); + return *reinterpret_cast(ptr); } } -// Load with acquire.sys semantics (SYSTEM scope, required for proper cross-PE sync) +// Load with acquire.sys semantics (SYSTEM 
scope, required for proper cross-PE +// sync) template TL_DEVICE T ld_wait_sys(const T *ptr, WaitSemantic semantic) { - if constexpr (std::is_same_v || std::is_same_v || + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { unsigned int ret = 0; if (semantic == WaitSemantic::RELAXED) { - asm volatile("ld.global.relaxed.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.relaxed.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } else if (semantic == WaitSemantic::VOLATILE) { - asm volatile("ld.global.volatile.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.volatile.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } else { // Default to acquire - asm volatile("ld.global.acquire.sys.s32 %0, [%1];\n" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.global.acquire.sys.s32 %0, [%1];\n" + : "=r"(ret) + : "l"(ptr)); } return static_cast(ret); } else { // Fallback to volatile for other types - return *reinterpret_cast(ptr); + return *reinterpret_cast(ptr); } } // Generic load dispatcher based on scope and semantic template -TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, WaitSemantic semantic = WaitSemantic::ACQUIRE) { +TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, + WaitSemantic semantic = WaitSemantic::ACQUIRE) { if (scope == WaitScope::SYSTEM) { return ld_wait_sys(ptr, semantic); } else { @@ -306,68 +323,80 @@ TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, WaitSemantic semantic } template -TL_DEVICE void wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +TL_DEVICE void wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) != val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) != + val) ; } -template -TL_DEVICE void wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +template +TL_DEVICE void wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) == val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) == + val) ; } -template -TL_DEVICE void wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +template +TL_DEVICE void wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) < val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) < + val) ; } -template -TL_DEVICE void wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +template +TL_DEVICE void wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) > val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) > + val) ; } -template -TL_DEVICE void wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +template +TL_DEVICE void wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) <= val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) <= + val) ; } -template -TL_DEVICE void wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { +template +TL_DEVICE void wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, + int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v

, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) >= val) + while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) >= + val) ; } diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 922526e48..e285bc655 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -7,7 +7,7 @@ from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability from tvm import tir -from typing import Any, Literal +from typing import Any import tilelang.language as T from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad @@ -730,13 +730,21 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) -def atom_add(barrier: PrimExpr, value: PrimExpr, scope: MemoryScope = MemoryScope.GPU, sem: MemorySemantic = MemorySemantic.RELAXED): +def atom_add(barrier: PrimExpr, + value: PrimExpr, + scope: MemoryScope = MemoryScope.GPU, + sem: MemorySemantic = MemorySemantic.RELAXED): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. 
""" scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] - sem_str = {MemorySemantic.RELAXED: "relaxed", MemorySemantic.ACQUIRE: "acquire", MemorySemantic.RELEASE: "release", MemorySemantic.ACQ_REL: "acq_rel"}[sem] - return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, sem_str, - scope_str) + sem_str = { + MemorySemantic.RELAXED: "relaxed", + MemorySemantic.ACQUIRE: "acquire", + MemorySemantic.RELEASE: "release", + MemorySemantic.ACQ_REL: "acq_rel" + }[sem] + return tir.call_intrin("uint32", tir.op.Op.get("tl.atom_add"), address_of(barrier), value, + sem_str, scope_str) def ld( @@ -765,8 +773,8 @@ def ld( """ na = 1 if na else 0 nc = 1 if nc else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem.value, scope.value, na, - nc, src_pe) + return tir.call_intrin("handle", tir.op.Op.get("tl.ld"), address_of(src), value, sem.value, + scope.value, na, nc, src_pe) def st( @@ -792,8 +800,8 @@ def st( tir.Call: A handle to the store operation. """ na = 1 if na else 0 - return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem.value, scope.value, na, - dst_pe) + return tir.call_intrin("handle", tir.op.Op.get("tl.st"), address_of(dst), value, sem.value, + scope.value, na, dst_pe) def atom_add_remote( @@ -817,26 +825,28 @@ def atom_add_remote( tir.Call: Returns the old value before the atomic add (uint32). 
""" scope_str = {MemoryScope.GPU: "gpu", MemoryScope.SYSTEM: "sys"}[scope] - sem_str = {MemorySemantic.RELAXED: "relaxed", MemorySemantic.ACQUIRE: "acquire", MemorySemantic.RELEASE: "release", MemorySemantic.ACQ_REL: "acq_rel"}[sem] - + sem_str = { + MemorySemantic.RELAXED: "relaxed", + MemorySemantic.ACQUIRE: "acquire", + MemorySemantic.RELEASE: "release", + MemorySemantic.ACQ_REL: "acq_rel" + }[sem] + # Build the intrinsic function name func_name = f"tl::ptx_atom_add_{sem_str}_{scope_str}" - + # If dst_pe is specified and not -1, compute remote address is_remote = not (isinstance(dst_pe, tir.IntImm) and dst_pe.value == -1) - + if is_remote: # Compute remote address: remote_base_ptr(dst_pe) + (address_of(dst) - remote_base_ptr(get_rank())) local_rank = tir.Call("int64", tir.op.Op.get("tl.get_rank"), []) local_base_ptr = tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [local_rank]) offset_to_base = tir.Sub( tir.Call("handle", tir.op.Op.get("tl.get_uintptr_t"), [address_of(dst)]), - local_base_ptr - ) + local_base_ptr) remote_ptr = tir.Add( - tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [dst_pe]), - offset_to_base - ) + tir.Call("handle", tir.op.Op.get("tl.get_remote_base_ptr"), [dst_pe]), offset_to_base) # Call the PTX intrinsic directly with remote address return tir.call_extern("uint32", func_name, remote_ptr, value) else: diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 8e0170d3c..11400d86d 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -6,6 +6,7 @@ from tilelang import language as T from enum import Enum + def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): """ Create a tile memory-region descriptor for a BufferLoad. 
From fc98be06bc5016b293c57a039d5a3df47da230a5 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 11:53:04 +0000 Subject: [PATCH 10/28] [Lint] --- tilelang/language/distributed/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tilelang/language/distributed/common.py b/tilelang/language/distributed/common.py index 543bc4cb7..fcd773840 100644 --- a/tilelang/language/distributed/common.py +++ b/tilelang/language/distributed/common.py @@ -147,7 +147,7 @@ def wait_ne(ptr: PrimExpr, scope: MemoryScope = MemoryScope.SYSTEM, semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr != expected - + Args: ptr: The memory address to wait on expected: The value to compare against @@ -165,7 +165,7 @@ def wait_ge(ptr: PrimExpr, scope: MemoryScope = MemoryScope.SYSTEM, semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr >= expected - + Args: ptr: The memory address to wait on expected: The value to compare against @@ -183,7 +183,7 @@ def wait_le(ptr: PrimExpr, scope: MemoryScope = MemoryScope.SYSTEM, semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr <= expected - + Args: ptr: The memory address to wait on expected: The value to compare against @@ -201,7 +201,7 @@ def wait_gt(ptr: PrimExpr, scope: MemoryScope = MemoryScope.SYSTEM, semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr > expected - + Args: ptr: The memory address to wait on expected: The value to compare against @@ -219,7 +219,7 @@ def wait_lt(ptr: PrimExpr, scope: MemoryScope = MemoryScope.SYSTEM, semantic: MemorySemantic = MemorySemantic.ACQUIRE): """Wait until *ptr < expected - + Args: ptr: The memory address to wait on expected: The value to compare against From 8404728fabb5a0e52c5184d10ee5a380b959a692 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Wed, 4 Feb 2026 11:53:34 +0000 Subject: [PATCH 11/28] [Lint] --- tilelang/language/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/tilelang/language/utils.py b/tilelang/language/utils.py index 11400d86d..e0d90b354 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -163,7 +163,7 @@ def linear_index(*args: PrimExpr) -> PrimExpr: class MemoryScope(Enum): """Memory scope for wait operations. - + - CTA: Uses ld.volatile.cta (faster, suitable for same-CTA synchronization) - CLUSTER: Uses ld.volatile.cluster (faster, suitable for same-cluster synchronization) - GPU: Uses ld.volatile.gpu (faster, suitable for same-GPU synchronization) @@ -177,7 +177,7 @@ class MemoryScope(Enum): class MemorySemantic(Enum): """Memory semantic for memory operations. - + - WEAK: Uses ld.weak (no synchronization) - VOLATILE: Uses ld.volatile (faster, suitable for same-GPU synchronization) - RELAXED: Uses ld.relaxed (faster, suitable for same-GPU synchronization) From 7b00d85cb529b87bb694be729f261f1935d51db5 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 5 Feb 2026 08:02:34 +0000 Subject: [PATCH 12/28] [BugFix] Add fence for inner CTA memory op --- .../distributed/intranode/example_alltoall.py | 101 ++++++++++++++++++ .../intranode/example_alltoall_route2x4.py | 34 ++++-- src/tl_templates/cuda/ldst.h | 2 +- 3 files changed, 128 insertions(+), 9 deletions(-) create mode 100644 examples/distributed/intranode/example_alltoall.py diff --git a/examples/distributed/intranode/example_alltoall.py b/examples/distributed/intranode/example_alltoall.py new file mode 100644 index 000000000..87734cbfe --- /dev/null +++ b/examples/distributed/intranode/example_alltoall.py @@ -0,0 +1,101 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse + + +def alltoall(PE_num, M, N, block_M, block_N, threads): + assert block_N == N + + @T.prim_func + def main( + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + barrier: T.Tensor((PE_num), "int32"), + ): + # Currently not 
support tiled copy + with T.Kernel(PE_num, T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by, bz): + tx = T.get_thread_binding(0) + rank = T.alloc_local([1], "int32") + num_ranks = T.alloc_local([1], "int32") + + dst_rank = bx + rank[0] = T.get_rank() + num_ranks[0] = T.get_num_ranks() + + T.put_block( + src=T.address_of(src[dst_rank * M + by * block_M, 0]), + dst=T.address_of(dst[rank[0] * M + by * block_M, 0]), + size=block_M * block_N, + dst_pe=dst_rank, + ) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + + + return main + + +def run_alltoall(local_rank, num_ranks, args): + PE_num = args.PE_num + M = args.M + N = args.N + block_M = 32 + block_N = N + threads = 128 + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**35, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + kernel = tilelang.compile(alltoall(PE_num, M, N, block_M, block_N, threads)) + kernel.initialize(allocator=allocator) + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + + torch.cuda.synchronize() + dist.barrier(group_size) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + kernel(src, dst, barrier) + end.record() + torch.cuda.synchronize() + dist.barrier(group_size) + elapsed_time = start.elapsed_time(end) + print(f"Rank {local_rank} Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s") + + # Torch Reference + torch.cuda.synchronize() + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + # 对比结果 + if 
torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! ✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}") + print(f"dst: {dst}") + print(f"dst_ref: {dst_ref}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) + + args = parser.parse_args() + torch.multiprocessing.spawn(run_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) + \ No newline at end of file diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index c0e813690..3a71ef7f5 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -6,7 +6,7 @@ import argparse from enum import IntEnum -tilelang.disable_cache() +# tilelang.disable_cache() class Direction(IntEnum): @@ -22,7 +22,7 @@ class Direction(IntEnum): tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, - # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/") + debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): @@ -64,6 +64,7 @@ def main( num_block_M[0] = T.ceildiv(M, block_M) src_rank[0] = bx dst_rank[0] = by + # Prepare for routing dst_rank_x = T.floordiv(dst_rank[0], Y) dst_rank_y = T.floormod(dst_rank[0], Y) @@ -108,7 +109,6 @@ def main( T.put_block( T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M, 0]), - # T.address_of(dst[rank[0] * M + by * block_M, 0]), block_M * N, next_rank[0], ) @@ -127,8 +127,10 @@ def 
main( block_M * N, -1, ) + T.fence_cta(sem=T.MemorySemantic.RELEASE) - T.barrier_blocks(barrier) + # T.barrier_blocks(barrier) + T.fence_sys(sem=T.MemorySemantic.RELEASE) # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer # Signal values: represent the src_rank @@ -163,6 +165,7 @@ def main( block_M * N, -1, ) + T.fence_cta(sem=T.MemorySemantic.RELEASE) if tx == 0: old_local[0] = T.atom_add( local_finish[0], @@ -193,7 +196,8 @@ def main( ) T.sync_threads() - T.barrier_blocks(barrier) + # T.barrier_blocks(barrier) + T.fence_sys(sem=T.MemorySemantic.RELEASE) return main @@ -207,7 +211,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( - size=2**32, + size=2**35, device="cuda", is_distributed=True, local_rank=local_rank, @@ -250,14 +254,28 @@ def run_torus_alltoall(local_rank, num_ranks, args): else: max_diff = (dst - dst_ref).abs().max() print(f"Rank {local_rank} Verification Failed! ❌ Max diff: {max_diff}") + # Find differences + diff_mask = (dst != dst_ref) + diff_count = diff_mask.sum().item() + + if diff_count > 0: + diff_indices = torch.nonzero(diff_mask, as_tuple=False) + print(f"Rank {local_rank} found {diff_count} differences at locations:") + # Print first few differences + num_to_print = min(10, diff_count) + for i in range(num_to_print): + idx = diff_indices[i] + print(f" Position {idx.tolist()}: dst={dst[idx[0], idx[1]].item():.6f}, dst_ref={dst_ref[idx[0], idx[1]].item():.6f}") + if diff_count > num_to_print: + print(f" ... 
and {diff_count - num_to_print} more differences") dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=128) - parser.add_argument("--N", type=int, default=128) + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) parser.add_argument("--PE_num", type=int, default=8) parser.add_argument("--X", type=int, default=2) parser.add_argument("--Y", type=int, default=4) diff --git a/src/tl_templates/cuda/ldst.h b/src/tl_templates/cuda/ldst.h index 5bccbcb90..247027187 100644 --- a/src/tl_templates/cuda/ldst.h +++ b/src/tl_templates/cuda/ldst.h @@ -3,7 +3,7 @@ #include "common.h" // Memory semantic and scope enums -enum class Semantic { WEAK, VOLATILE, ACQUIRE, RELEASE, RELAXED }; +enum class Semantic { WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL }; enum class Scope { CTA, CLUSTER, GPU, SYS }; #ifndef TL_ALWAYS_FALSE_V_DEFINED From 067511c528d7be9ef9661cc24dab18a0dfa3edce Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 5 Feb 2026 08:04:23 +0000 Subject: [PATCH 13/28] [BugFix] Fence and debug --- examples/distributed/intranode/example_alltoall_route2x4.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 3a71ef7f5..eadedf38f 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -22,7 +22,7 @@ class Direction(IntEnum): tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, - debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" + # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): @@ -103,6 +103,8 @@ def main( to_dir[0] = 
Direction.EAST next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + # Phase 1: Initial send from src to the target neighbor if src_rank[0] == rank[0]: if dst_rank[0] != rank[0]: @@ -226,7 +228,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, - allocator=allocator).zero_() + allocator=allocator).fill_(-1) # Signals for each buffer slot in each direction signal_transfer = tilelang.tensor((PE_num, PE_num), torch.uint32, From 31a4643bf54defca88011c1ee6bf7729eecf97fe Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 5 Feb 2026 08:05:24 +0000 Subject: [PATCH 14/28] [Lint] --- .../distributed/intranode/example_alltoall.py | 17 +++++++++-------- .../intranode/example_alltoall_route2x4.py | 6 ++++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall.py b/examples/distributed/intranode/example_alltoall.py index 87734cbfe..839e41dcb 100644 --- a/examples/distributed/intranode/example_alltoall.py +++ b/examples/distributed/intranode/example_alltoall.py @@ -11,13 +11,14 @@ def alltoall(PE_num, M, N, block_M, block_N, threads): @T.prim_func def main( - src: T.Tensor((PE_num * M, N), "float16"), - dst: T.Tensor((PE_num * M, N), "float16"), - barrier: T.Tensor((PE_num), "int32"), + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + barrier: T.Tensor((PE_num), "int32"), ): # Currently not support tiled copy - with T.Kernel(PE_num, T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by, bz): - tx = T.get_thread_binding(0) + with T.Kernel( + PE_num, T.ceildiv(M, block_M), T.ceildiv(N, block_N), + threads=threads) as (bx, by, bz): rank = T.alloc_local([1], "int32") num_ranks = T.alloc_local([1], "int32") @@ -32,7 +33,6 @@ def main( dst_pe=dst_rank, ) 
T.fence_sys(sem=T.MemorySemantic.RELEASE) - return main @@ -70,7 +70,9 @@ def run_alltoall(local_rank, num_ranks, args): torch.cuda.synchronize() dist.barrier(group_size) elapsed_time = start.elapsed_time(end) - print(f"Rank {local_rank} Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s") + print( + f"Rank {local_rank} Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s" + ) # Torch Reference torch.cuda.synchronize() @@ -98,4 +100,3 @@ def run_alltoall(local_rank, num_ranks, args): args = parser.parse_args() torch.multiprocessing.spawn(run_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) - \ No newline at end of file diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index eadedf38f..28e191e23 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -259,7 +259,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): # Find differences diff_mask = (dst != dst_ref) diff_count = diff_mask.sum().item() - + if diff_count > 0: diff_indices = torch.nonzero(diff_mask, as_tuple=False) print(f"Rank {local_rank} found {diff_count} differences at locations:") @@ -267,7 +267,9 @@ def run_torus_alltoall(local_rank, num_ranks, args): num_to_print = min(10, diff_count) for i in range(num_to_print): idx = diff_indices[i] - print(f" Position {idx.tolist()}: dst={dst[idx[0], idx[1]].item():.6f}, dst_ref={dst_ref[idx[0], idx[1]].item():.6f}") + print( + f" Position {idx.tolist()}: dst={dst[idx[0], idx[1]].item():.6f}, dst_ref={dst_ref[idx[0], idx[1]].item():.6f}" + ) if diff_count > num_to_print: print(f" ... 
and {diff_count - num_to_print} more differences") From 67065ab26210b99ae2d57c54d1e5479faf55b32c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 5 Feb 2026 10:03:18 +0000 Subject: [PATCH 15/28] [BugFix] Restore the signal to avoid duplicated sum of finish barrier --- .../intranode/example_alltoall_route2x4.py | 88 +++++++++++++++---- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 28e191e23..35d8d31f1 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -25,6 +25,7 @@ class Direction(IntEnum): # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): + num_blocks_M = M // block_M @T.prim_func def main( @@ -34,7 +35,7 @@ def main( # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank buffer_transfer: T.Tensor((PE_num, PE_num, M, N), "float16"), # Signal for each buffer - signal_transfer: T.Tensor((PE_num, PE_num), "uint32"), + signal_transfer: T.Tensor((PE_num, PE_num, num_blocks_M), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), @@ -103,7 +104,7 @@ def main( to_dir[0] = Direction.EAST next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) - T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.fence_gpu(sem=T.MemorySemantic.RELEASE) # Phase 1: Initial send from src to the target neighbor if src_rank[0] == rank[0]: @@ -116,7 +117,7 @@ def main( ) if tx == 0: T.st( - signal_transfer[rank[0], dst_rank[0]], + signal_transfer[rank[0], dst_rank[0], bz], rank[0], scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, @@ -129,7 +130,15 @@ def main( block_M * N, -1, ) + if tx == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, 
+ scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.print(bz, msg="block_idx") # T.barrier_blocks(barrier) T.fence_sys(sem=T.MemorySemantic.RELEASE) @@ -138,12 +147,12 @@ def main( # Signal values: represent the src_rank with T.While(global_finish[0] < PE_num): if tx == 0: - T.wait_le(signal_transfer[bx, dst_rank[0]], PE_num, scope=T.MemoryScope.SYSTEM) + T.wait_le(signal_transfer[bx, dst_rank[0], bz], PE_num, scope=T.MemoryScope.SYSTEM) T.sync_threads() - if signal_transfer[bx, dst_rank[0]] < PE_num: + if signal_transfer[bx, dst_rank[0], bz] < PE_num: # Only handle the transfer signal - if to_dir[0] != Direction.SELF: + if dst_rank[0] != rank[0]: T.put_block( T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), @@ -152,7 +161,7 @@ def main( ) if tx == 0: T.st( - signal_transfer[bx, dst_rank[0]], + signal_transfer[bx, dst_rank[0], bz], bx, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, @@ -175,7 +184,9 @@ def main( scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - if old_local[0] + 2 == PE_num * num_block_M[0]: + # T.print(signal_transfer[bx, rank[0], bz], msg="signal_transfer") + # T.print(old_local[0], msg="old_local") + if old_local[0] + 1 == PE_num * num_block_M[0]: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], @@ -189,14 +200,26 @@ def main( for remote_pe in T.serial(PE_num): for src_rank in T.serial(PE_num): for dst_rank in T.serial(PE_num): - T.st( - signal_transfer[src_rank, dst_rank], - PE_num, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) + for dst_block in T.serial(num_blocks_M): + T.st( + signal_transfer[src_rank, dst_rank, dst_block], + PE_num, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) T.sync_threads() + + # Restore the signal to avoid duplicated sum of finish barrier + if tx == 0: + 
T.st( + signal_transfer[bx, dst_rank[0], bz], + PE_num + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + dst_pe=-1, + ) + T.sync_threads() # T.barrier_blocks(barrier) T.fence_sys(sem=T.MemorySemantic.RELEASE) @@ -208,8 +231,9 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M, N + block_M, block_N = M // 8, N threads = 128 + num_blocks_M = M // block_M local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -231,7 +255,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): allocator=allocator).fill_(-1) # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num, PE_num), torch.uint32, + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks_M), torch.uint32, allocator=allocator).fill_(PE_num + 1) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) @@ -273,6 +297,35 @@ def run_torus_alltoall(local_rank, num_ranks, args): if diff_count > num_to_print: print(f" ... and {diff_count - num_to_print} more differences") + if args.benchmark: + # Warmup + for _ in range(5): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + num_iters = 10 + start_event.record() + for _ in range(num_iters): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + end_event.record() + torch.cuda.synchronize() + dist.barrier(group_size) + + if local_rank == 0: + elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters + # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements + # and receives (PE_num - 1) * M * N elements. 
+ # For bandwidth calculation, we usually use the amount of data sent per rank. + total_data_bytes = (PE_num - 1) * M * N * 2 # float16 = 2 bytes + bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) + print(f"Benchmark Results:") + print(f" Average Latency: {elapsed_time_ms:.4f} ms") + print(f" Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") + dist.destroy_process_group() @@ -283,6 +336,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): parser.add_argument("--PE_num", type=int, default=8) parser.add_argument("--X", type=int, default=2) parser.add_argument("--Y", type=int, default=4) + parser.add_argument("--benchmark", action="store_true", help="Run benchmark") args = parser.parse_args() torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) From e2d8ee39f3e205de80e18e3f7519d1e72c0cf1df Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Thu, 5 Feb 2026 11:17:01 +0000 Subject: [PATCH 16/28] [Lint] --- .../intranode/example_alltoall_route2x4.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 35d8d31f1..1e8ffc077 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -147,7 +147,8 @@ def main( # Signal values: represent the src_rank with T.While(global_finish[0] < PE_num): if tx == 0: - T.wait_le(signal_transfer[bx, dst_rank[0], bz], PE_num, scope=T.MemoryScope.SYSTEM) + T.wait_le( + signal_transfer[bx, dst_rank[0], bz], PE_num, scope=T.MemoryScope.SYSTEM) T.sync_threads() if signal_transfer[bx, dst_rank[0], bz] < PE_num: @@ -202,14 +203,15 @@ def main( for dst_rank in T.serial(PE_num): for dst_block in T.serial(num_blocks_M): T.st( - signal_transfer[src_rank, dst_rank, dst_block], + signal_transfer[src_rank, dst_rank, + dst_block], PE_num, scope=T.MemoryScope.SYSTEM, 
sem=T.MemorySemantic.RELEASE, dst_pe=remote_pe, ) T.sync_threads() - + # Restore the signal to avoid duplicated sum of finish barrier if tx == 0: T.st( @@ -255,7 +257,8 @@ def run_torus_alltoall(local_rank, num_ranks, args): allocator=allocator).fill_(-1) # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks_M), torch.uint32, + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks_M), + torch.uint32, allocator=allocator).fill_(PE_num + 1) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) @@ -322,7 +325,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): # For bandwidth calculation, we usually use the amount of data sent per rank. total_data_bytes = (PE_num - 1) * M * N * 2 # float16 = 2 bytes bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) - print(f"Benchmark Results:") + print("Benchmark Results:") print(f" Average Latency: {elapsed_time_ms:.4f} ms") print(f" Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") From b00bdd8244f8895187b9d6d14d1b2a0726025561 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 6 Feb 2026 02:19:15 +0000 Subject: [PATCH 17/28] [BugFix] Warp-level scheduling with active blocks and correct synchronization --- .../intranode/example_alltoall_route2x4.py | 223 +++++++++--------- 1 file changed, 113 insertions(+), 110 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 1e8ffc077..eeadf34e4 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -24,8 +24,15 @@ class Direction(IntEnum): }, # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) -def torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads): - num_blocks_M = M // 
block_M +def torus_alltoall_xy(PE_num, X, Y, M, N, tile_M, tile_N, threads): + num_SM = 148 + num_tiles = M // tile_M + num_blocks = min(num_SM // (PE_num * PE_num), num_tiles) + num_warps = threads // 32 + block_M = M // num_blocks + tiles_per_block = block_M // tile_M + + assert tiles_per_block <= num_warps, "Each warp should handle the signal and transfer of one tile" @T.prim_func def main( @@ -35,15 +42,16 @@ def main( # buffer[src_rank, dst_rank, *, *]: This PE save a slot for transferring data chunks for the real destination rank buffer_transfer: T.Tensor((PE_num, PE_num, M, N), "float16"), # Signal for each buffer - signal_transfer: T.Tensor((PE_num, PE_num, num_blocks_M), "uint32"), + signal_transfer: T.Tensor((PE_num, PE_num, num_blocks, num_warps), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), # Barrier for all blocks barrier: T.Tensor((PE_num), "int32"), ): - with T.Kernel(PE_num, PE_num, T.ceildiv(M, block_M), threads=threads) as (bx, by, bz): + with T.Kernel(PE_num, PE_num, num_blocks, threads=threads) as (bx, by, bz): tx = T.get_thread_binding() + warp_idx = tx // 32 rank = T.alloc_local([1], "uint32") rank_x = T.alloc_local([1], "uint32") @@ -55,16 +63,18 @@ def main( dst_rank = T.alloc_local([1], "uint32") old_local = T.alloc_local([1], "uint32") old_global = T.alloc_local([1], "uint32") - num_block_M = T.alloc_local([1], "uint32") + num_tiles = T.alloc_local([1], "uint32") + flag = T.alloc_local([1], "bool") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) rank_y[0] = T.floormod(rank[0], Y) next_rank[0] = rank[0] - - num_block_M[0] = T.ceildiv(M, block_M) + + num_tiles[0] = M // tile_M src_rank[0] = bx dst_rank[0] = by + flag[0] = False # Prepare for routing dst_rank_x = T.floordiv(dst_rank[0], Y) @@ -109,121 +119,111 @@ def main( # Phase 1: Initial send from src to the target neighbor if src_rank[0] == rank[0]: if dst_rank[0] != rank[0]: - T.put_block( - 
T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), - T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M, 0]), - block_M * N, - next_rank[0], - ) - if tx == 0: - T.st( - signal_transfer[rank[0], dst_rank[0], bz], - rank[0], - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0]) - T.sync_threads() - else: - T.put_block( - T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), - T.address_of(dst[rank[0] * M + bz * block_M, 0]), - block_M * N, - -1, - ) - if tx == 0: - old_local[0] = T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, - ) - T.fence_cta(sem=T.MemorySemantic.RELEASE) - T.print(bz, msg="block_idx") - - # T.barrier_blocks(barrier) - T.fence_sys(sem=T.MemorySemantic.RELEASE) - - # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer - # Signal values: represent the src_rank - with T.While(global_finish[0] < PE_num): - if tx == 0: - T.wait_le( - signal_transfer[bx, dst_rank[0], bz], PE_num, scope=T.MemoryScope.SYSTEM) - T.sync_threads() - - if signal_transfer[bx, dst_rank[0], bz] < PE_num: - # Only handle the transfer signal - if dst_rank[0] != rank[0]: - T.put_block( - T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), - T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), - block_M * N, + if warp_idx < tiles_per_block: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), + T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, next_rank[0], ) - if tx == 0: + if tx % 32 == 0: T.st( - signal_transfer[bx, dst_rank[0], bz], - bx, + signal_transfer[rank[0], dst_rank[0], bz, warp_idx], + 1, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0], - ) - T.sync_threads() - else: - # Current rank is the real destination of this chunk of data, the real source rank is the buffer index 
- T.put_block( - T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M, 0]), - T.address_of(dst[bx * M + bz * block_M, 0]), - block_M * N, + dst_pe=next_rank[0]) + T.sync_warp() + else: + if warp_idx < tiles_per_block: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, -1, ) - T.fence_cta(sem=T.MemorySemantic.RELEASE) - if tx == 0: + if tx % 32 == 0: old_local[0] = T.atom_add( local_finish[0], 1, scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - # T.print(signal_transfer[bx, rank[0], bz], msg="signal_transfer") - # T.print(old_local[0], msg="old_local") - if old_local[0] + 1 == PE_num * num_block_M[0]: - for i in T.serial(PE_num): - old_global[0] = T.atom_add_remote( - global_finish[0], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=i, - ) - if old_global[0] + 1 == PE_num: - # Send termination signals to wake up all waiting blocks on all PEs - for remote_pe in T.serial(PE_num): - for src_rank in T.serial(PE_num): - for dst_rank in T.serial(PE_num): - for dst_block in T.serial(num_blocks_M): - T.st( - signal_transfer[src_rank, dst_rank, - dst_block], - PE_num, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) - T.sync_threads() + T.sync_warp() + T.fence_cta(sem=T.MemorySemantic.RELEASE) - # Restore the signal to avoid duplicated sum of finish barrier - if tx == 0: + T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + + # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer + if tx % 32 == 0: + T.wait_gt( + signal_transfer[bx, dst_rank[0], bz, warp_idx], 0, scope=T.MemoryScope.SYSTEM) + T.sync_warp() + + if signal_transfer[bx, dst_rank[0], bz, warp_idx] == 1 and flag[0] == False: + flag[0] = True + # Handle the transfer signal + if dst_rank[0] != rank[0]: + T.put_warp( + 
T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + dst_pe=next_rank[0], + ) + if tx % 32 == 0: T.st( - signal_transfer[bx, dst_rank[0], bz], - PE_num + 1, + signal_transfer[bx, dst_rank[0], bz, warp_idx], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + else: + # Current rank is the real destination of this chunk of data, the real source rank is the buffer index + T.put_warp( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(dst[bx * M + bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + -1, + ) + T.sync_warp() + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, - dst_pe=-1, ) - T.sync_threads() + if old_local[0] + 1 == PE_num * num_tiles[0]: + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=i, + ) + if old_global[0] + 1 == PE_num: + # Send termination signals to wake up all waiting blocks on all PEs + for remote_pe in T.serial(PE_num): + for src_rank_idx in T.serial(PE_num): + for dst_rank_idx in T.serial(PE_num): + for bz_idx in T.serial(num_blocks): + for dst_tile in T.serial(num_warps): + T.st( + signal_transfer[src_rank_idx, dst_rank_idx, + bz_idx, dst_tile], + 2, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) + T.sync_warp() + + T.sync_threads() - # T.barrier_blocks(barrier) T.fence_sys(sem=T.MemorySemantic.RELEASE) return main @@ -233,9 +233,12 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M // 8, N - threads = 128 - num_blocks_M = M // block_M + tile_M, tile_N = M // 16, N + threads = 256 + 
+ num_tiles = M // tile_M + num_blocks = min(num_tiles, 148 // (PE_num * PE_num)) + num_warps = threads // 32 local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -247,19 +250,19 @@ def run_torus_alltoall(local_rank, num_ranks, args): group=group_size, ) - kernel = torus_alltoall_xy(PE_num, X, Y, M, N, block_M, block_N, threads) + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, tile_M, tile_N, threads) kernel.initialize(allocator=allocator) src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, - allocator=allocator).fill_(-1) + allocator=allocator).fill_(0) # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks_M), + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), torch.uint32, - allocator=allocator).fill_(PE_num + 1) + allocator=allocator).fill_(0) local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() From 2822ced45faa31e8125a8dd792fd3da179109346 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 6 Feb 2026 05:51:14 +0000 Subject: [PATCH 18/28] [Example] Add benchmark options --- .../distributed/intranode/example_alltoall.py | 20 +++++++++++++---- .../intranode/example_alltoall_route2x4.py | 22 +++++++++---------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall.py b/examples/distributed/intranode/example_alltoall.py index 839e41dcb..d863da83f 100644 --- a/examples/distributed/intranode/example_alltoall.py +++ b/examples/distributed/intranode/example_alltoall.py @@ -43,7 +43,7 @@ def 
run_alltoall(local_rank, num_ranks, args): N = args.N block_M = 32 block_N = N - threads = 128 + threads = 256 local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -62,16 +62,26 @@ def run_alltoall(local_rank, num_ranks, args): torch.cuda.synchronize() dist.barrier(group_size) + + # Warmup + for _ in range(args.warmup): + kernel(src, dst, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) + start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() - kernel(src, dst, barrier) + for _ in range(args.iter): + kernel(src, dst, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) end.record() torch.cuda.synchronize() dist.barrier(group_size) - elapsed_time = start.elapsed_time(end) + elapsed_time = start.elapsed_time(end) / args.iter print( - f"Rank {local_rank} Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s" + f"Rank {local_rank} Average Kernel execution time: {elapsed_time:.3f} ms, Bandwidth: {2 * PE_num * M * N / (elapsed_time * 1e6):.3f} GB/s" ) # Torch Reference @@ -97,6 +107,8 @@ def run_alltoall(local_rank, num_ranks, args): parser.add_argument("--PE_num", type=int, default=8) parser.add_argument("--M", type=int, default=8192) parser.add_argument("--N", type=int, default=7168) + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") args = parser.parse_args() torch.multiprocessing.spawn(run_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index eeadf34e4..39f86c39f 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -239,6 +239,8 
@@ def run_torus_alltoall(local_rank, num_ranks, args): num_tiles = M // tile_M num_blocks = min(num_tiles, 148 // (PE_num * PE_num)) num_warps = threads // 32 + # Modify tile_M to use all warps + tile_M = M // (num_blocks * num_warps) local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -317,20 +319,18 @@ def run_torus_alltoall(local_rank, num_ranks, args): start_event.record() for _ in range(num_iters): kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) end_event.record() torch.cuda.synchronize() - dist.barrier(group_size) - if local_rank == 0: - elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters - # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements - # and receives (PE_num - 1) * M * N elements. - # For bandwidth calculation, we usually use the amount of data sent per rank. - total_data_bytes = (PE_num - 1) * M * N * 2 # float16 = 2 bytes - bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) - print("Benchmark Results:") - print(f" Average Latency: {elapsed_time_ms:.4f} ms") - print(f" Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") + elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters + # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements + # and receives (PE_num - 1) * M * N elements. + # For bandwidth calculation, we usually use the amount of data sent per rank. 
+ total_data_bytes = PE_num * M * N * 2 # float16 = 2 bytes + bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) + print(f"Rank {local_rank} Average Latency: {elapsed_time_ms:.4f} ms, Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") dist.destroy_process_group() From 81af526bc35188dabcbd210214cc1ef92d19a0bf Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Fri, 6 Feb 2026 08:40:38 +0000 Subject: [PATCH 19/28] [Enhancement] Fully utilize blocks to send/recv data --- .../intranode/example_alltoall_route2x4.py | 110 +++++++++--------- 1 file changed, 56 insertions(+), 54 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 39f86c39f..70b46bd73 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -22,18 +22,13 @@ class Direction(IntEnum): tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, - # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" + debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) -def torus_alltoall_xy(PE_num, X, Y, M, N, tile_M, tile_N, threads): - num_SM = 148 - num_tiles = M // tile_M - num_blocks = min(num_SM // (PE_num * PE_num), num_tiles) +def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): num_warps = threads // 32 block_M = M // num_blocks - tiles_per_block = block_M // tile_M + tile_M = block_M // num_warps - assert tiles_per_block <= num_warps, "Each warp should handle the signal and transfer of one tile" - @T.prim_func def main( # For each (src, dst) pair, the real transfer size is M * N @@ -116,52 +111,58 @@ def main( T.fence_gpu(sem=T.MemorySemantic.RELEASE) - # Phase 1: Initial send from src to the target neighbor - if src_rank[0] == rank[0]: - if dst_rank[0] != rank[0]: - if warp_idx < tiles_per_block: - T.put_warp( - 
T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), - T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), - tile_M * N, - next_rank[0], - ) - if tx % 32 == 0: - T.st( - signal_transfer[rank[0], dst_rank[0], bz, warp_idx], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0]) - T.sync_warp() - else: - if warp_idx < tiles_per_block: - T.put_warp( - T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), - T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M, 0]), - tile_M * N, - -1, - ) - if tx % 32 == 0: - old_local[0] = T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, - ) - T.sync_warp() - T.fence_cta(sem=T.MemorySemantic.RELEASE) + # Phase 1: Fully use all blocks to initially send from src to the target neighbor + # Split the tile_M to each block - T.fence_sys(sem=T.MemorySemantic.RELEASE) + chunk_M = T.ceildiv(tile_M, PE_num) + chunk_start = bx * chunk_M + chunk_size = T.min(chunk_M, tile_M - chunk_start) + + # if src_rank[0] == rank[0]: + if dst_rank[0] != rank[0]: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + next_rank[0], + ) + if tx % 32 == 0: + T.atom_add_remote( + signal_transfer[rank[0], dst_rank[0], bz, warp_idx], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + # T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + -1, + ) + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + 
sem=T.MemorySemantic.RELEASE, + ) + T.sync_warp() + # T.fence_cta(sem=T.MemorySemantic.RELEASE) + + # T.fence_sys(sem=T.MemorySemantic.RELEASE) T.sync_threads() # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer if tx % 32 == 0: - T.wait_gt( - signal_transfer[bx, dst_rank[0], bz, warp_idx], 0, scope=T.MemoryScope.SYSTEM) + T.wait_ge( + signal_transfer[bx, dst_rank[0], bz, warp_idx], PE_num, scope=T.MemoryScope.SYSTEM) T.sync_warp() - if signal_transfer[bx, dst_rank[0], bz, warp_idx] == 1 and flag[0] == False: + if signal_transfer[bx, dst_rank[0], bz, warp_idx] == PE_num and flag[0] == False: flag[0] = True # Handle the transfer signal if dst_rank[0] != rank[0]: @@ -174,7 +175,7 @@ def main( if tx % 32 == 0: T.st( signal_transfer[bx, dst_rank[0], bz, warp_idx], - 1, + PE_num, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], @@ -196,7 +197,7 @@ def main( scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - if old_local[0] + 1 == PE_num * num_tiles[0]: + if old_local[0] + 1 == (PE_num - 1) * num_tiles[0] + PE_num * num_tiles[0]: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( global_finish[0], @@ -215,7 +216,7 @@ def main( T.st( signal_transfer[src_rank_idx, dst_rank_idx, bz_idx, dst_tile], - 2, + PE_num + 1, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, dst_pe=remote_pe, @@ -230,14 +231,15 @@ def main( def run_torus_alltoall(local_rank, num_ranks, args): + NUM_SM = 148 PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - tile_M, tile_N = M // 16, N + block_M, block_N = M // 8, N threads = 256 - num_tiles = M // tile_M - num_blocks = min(num_tiles, 148 // (PE_num * PE_num)) + num_blocks = M // block_M + num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) num_warps = threads // 32 # Modify tile_M to use all warps tile_M = M // (num_blocks * num_warps) @@ -252,7 +254,7 @@ def run_torus_alltoall(local_rank, num_ranks, 
args): group=group_size, ) - kernel = torus_alltoall_xy(PE_num, X, Y, M, N, tile_M, tile_N, threads) + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads) kernel.initialize(allocator=allocator) src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() From 5f584e5625775eb1198a8233e339fa4df7e54500 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Sun, 8 Feb 2026 06:30:17 +0000 Subject: [PATCH 20/28] [Routing] Optimize for balanced routing direction --- .../distributed/intranode/example_alltoall_route2x4.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 70b46bd73..33fd0b9de 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -81,7 +81,7 @@ def main( diff[0] = dst_rank_x - rank_x[0] if diff[0] > T.floordiv(X, 2): diff[0] -= X - elif diff[0] < -T.floordiv(X, 2): + elif diff[0] <= -T.floordiv(X, 2): diff[0] += X if diff[0] < 0: @@ -97,7 +97,7 @@ def main( diff[0] = dst_rank_y - rank_y[0] if diff[0] > T.floordiv(Y, 2): diff[0] -= Y - elif diff[0] < -T.floordiv(Y, 2): + elif diff[0] <= -T.floordiv(Y, 2): diff[0] += Y if diff[0] < 0: @@ -135,7 +135,7 @@ def main( dst_pe=next_rank[0], ) T.sync_warp() - # T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.fence_sys(sem=T.MemorySemantic.RELEASE) else: T.put_warp( T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), @@ -151,9 +151,9 @@ def main( sem=T.MemorySemantic.RELEASE, ) T.sync_warp() - # T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.fence_cta(sem=T.MemorySemantic.RELEASE) - # T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.fence_sys(sem=T.MemorySemantic.RELEASE) T.sync_threads() # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer From 
1d4bbd3c74343ce600c8895e06cf09943ad4bc2c Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 9 Feb 2026 07:32:56 +0000 Subject: [PATCH 21/28] [BugFix] Reinitialize the signal before benchmark --- .../intranode/example_alltoall_route2x4.py | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index 33fd0b9de..a16a8d2de 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -6,7 +6,7 @@ import argparse from enum import IntEnum -# tilelang.disable_cache() +tilelang.disable_cache() class Direction(IntEnum): @@ -30,7 +30,7 @@ def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): tile_M = block_M // num_warps @T.prim_func - def main( + def main_route( # For each (src, dst) pair, the real transfer size is M * N src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), @@ -180,6 +180,8 @@ def main( sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "transfer rank") T.sync_warp() else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index @@ -197,6 +199,8 @@ def main( scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "dst rank") if old_local[0] + 1 == (PE_num - 1) * num_tiles[0] + PE_num * num_tiles[0]: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( @@ -227,7 +231,7 @@ def main( T.fence_sys(sem=T.MemorySemantic.RELEASE) - return main + return main_route def run_torus_alltoall(local_rank, num_ranks, args): @@ -235,14 +239,12 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M // 8, N + block_M, block_N = M // 2, N threads = 256 num_blocks = M 
// block_M num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) num_warps = threads // 32 - # Modify tile_M to use all warps - tile_M = M // (num_blocks * num_warps) local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -260,6 +262,10 @@ def run_torus_alltoall(local_rank, num_ranks, args): src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, allocator=allocator).fill_(0) @@ -281,10 +287,6 @@ def run_torus_alltoall(local_rank, num_ranks, args): print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") - dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") - dist.all_to_all_single(dst_ref, src, group=group_size) - torch.cuda.synchronize() - if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): print(f"Rank {local_rank} Verification Passed! 
✅") else: @@ -295,32 +297,39 @@ def run_torus_alltoall(local_rank, num_ranks, args): diff_count = diff_mask.sum().item() if diff_count > 0: - diff_indices = torch.nonzero(diff_mask, as_tuple=False) - print(f"Rank {local_rank} found {diff_count} differences at locations:") - # Print first few differences - num_to_print = min(10, diff_count) - for i in range(num_to_print): - idx = diff_indices[i] + # Check each rank's symmetric buffer at every M boundary + print(f"Rank {local_rank} found {diff_count} differences") + print(f"Rank {local_rank} checking symmetric buffer positions at M={M} boundaries:") + for rank_idx in range(PE_num): + start_idx = rank_idx * M + # Check first element of each rank's buffer print( - f" Position {idx.tolist()}: dst={dst[idx[0], idx[1]].item():.6f}, dst_ref={dst_ref[idx[0], idx[1]].item():.6f}" + f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={buffer_transfer[rank_idx, local_rank, 0, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" ) - if diff_count > num_to_print: - print(f" ... 
and {diff_count - num_to_print} more differences") if args.benchmark: # Warmup - for _ in range(5): + for _ in range(args.warmup): kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) torch.cuda.synchronize() dist.barrier(group_size) + # Reinitialize + buffer_transfer.zero_() + signal_transfer.zero_() + local_finish.zero_() + global_finish.zero_() + start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) - num_iters = 10 + num_iters = args.iter start_event.record() for _ in range(num_iters): + # torch.cuda.profiler.start() + # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + # torch.cuda.profiler.stop() torch.cuda.synchronize() dist.barrier(group_size) end_event.record() @@ -345,6 +354,8 @@ def run_torus_alltoall(local_rank, num_ranks, args): parser.add_argument("--X", type=int, default=2) parser.add_argument("--Y", type=int, default=4) parser.add_argument("--benchmark", action="store_true", help="Run benchmark") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") args = parser.parse_args() torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) From b0651994e84f733b4676204642768ca5c6d77061 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Mon, 9 Feb 2026 10:20:07 +0000 Subject: [PATCH 22/28] [Feature] Add return value of wait op --- .../example_allgather_gemm_overlapped.py | 2 +- .../example_alltoall_route2x4_opt.py | 356 ++++++++++++++++++ src/op/sync.cc | 5 +- src/op/sync.h | 8 +- src/target/codegen_cuda.cc | 14 + src/tl_templates/cuda/sync.h | 36 +- tilelang/language/distributed/common.py | 37 +- 7 files changed, 429 insertions(+), 29 deletions(-) create mode 100644 
examples/distributed/intranode/example_alltoall_route2x4_opt.py diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/example_allgather_gemm_overlapped.py index cebf58ed1..61aff0dd9 100644 --- a/examples/distributed/example_allgather_gemm_overlapped.py +++ b/examples/distributed/example_allgather_gemm_overlapped.py @@ -93,7 +93,7 @@ def main( tid = T.get_thread_binding(0) T.clear(C_local) if tid == 0: - T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1) + T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1, dtype="uint32") for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(A[pid_m * block_M, k * block_K], A_shared) T.copy(B[k * block_K, pid_n * block_N], B_shared) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py new file mode 100644 index 000000000..8222a9921 --- /dev/null +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -0,0 +1,356 @@ +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +import torch +import torch.distributed as dist +import argparse +from enum import IntEnum + +tilelang.disable_cache() + + +class Direction(IntEnum): + NORTH = 0 + SOUTH = 1 + WEST = 2 + EAST = 3 + SELF = 4 + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" +) +def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): + num_warps = threads // 32 + block_M = M // num_blocks + # tile_M = block_M // num_warps + # Number of slots + num_slots = PE_num + + @T.prim_func + def main_route( + # For each (src, dst) pair, the real transfer size is M * N + src: T.Tensor((PE_num * M, N), "float16"), + dst: T.Tensor((PE_num * M, N), "float16"), + # buffer[dst_rank, slots, *, *]: This PE save 
slots for transferring data chunks for the real destination rank + buffer_transfer: T.Tensor((PE_num, num_slots, M, N), "float16"), + # Signal for each dst buffer + signal_transfer: T.Tensor((PE_num, num_blocks, num_warps), "uint32"), + # Signal for finish + local_finish: T.Tensor((1), "uint32"), + global_finish: T.Tensor((1), "uint32"), + # Barrier for all blocks + barrier: T.Tensor((PE_num), "int32"), + ): + with T.Kernel(PE_num, num_blocks, threads=threads) as (bx, bz): + tx = T.get_thread_binding() + warp_idx = tx // 32 + + rank = T.alloc_local([1], "uint32") + rank_x = T.alloc_local([1], "uint32") + rank_y = T.alloc_local([1], "uint32") + next_rank = T.alloc_local([1], "uint32") + to_dir = T.alloc_local([1], "uint32") + diff = T.alloc_local([1], "int32") + dst_rank = T.alloc_local([1], "uint32") + old_local = T.alloc_local([1], "uint32") + old_global = T.alloc_local([1], "uint32") + num_tiles = T.alloc_local([1], "uint32") + + rank[0] = T.get_rank() + rank_x[0] = T.floordiv(rank[0], Y) + rank_y[0] = T.floormod(rank[0], Y) + next_rank[0] = rank[0] + + dst_rank[0] = bx + + # Prepare for routing + dst_rank_x = T.floordiv(dst_rank[0], Y) + dst_rank_y = T.floormod(dst_rank[0], Y) + to_dir[0] = Direction.SELF + # XY-routing: first route along X-axis, then Y-axis + if dst_rank_x != rank_x[0]: + # Calculate shortest path in Torus X-dimension + diff[0] = dst_rank_x - rank_x[0] + if diff[0] > T.floordiv(X, 2): + diff[0] -= X + elif diff[0] <= -T.floordiv(X, 2): + diff[0] += X + + if diff[0] < 0: + # Send North (up): neighbor receives in its north buffer + to_dir[0] = Direction.NORTH + next_rank[0] = T.floormod(rank_x[0] + X - 1, X) * Y + rank_y[0] + else: + # Send South (down): neighbor receives in its south buffer + to_dir[0] = Direction.SOUTH + next_rank[0] = T.floormod(rank_x[0] + 1, X) * Y + rank_y[0] + elif dst_rank_y != rank_y[0]: + # Calculate shortest path in Torus Y-dimension + diff[0] = dst_rank_y - rank_y[0] + if diff[0] > T.floordiv(Y, 2): + diff[0] -= Y + 
elif diff[0] <= -T.floordiv(Y, 2): + diff[0] += Y + + if diff[0] < 0: + # Send West (left): neighbor receives in its west buffer + to_dir[0] = Direction.WEST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + Y - 1, Y) + else: + # Send East (right): neighbor receives in its east buffer + to_dir[0] = Direction.EAST + next_rank[0] = rank_x[0] * Y + T.floormod(rank_y[0] + 1, Y) + + # Phase 1: Fully use all blocks to initially send from src to the target neighbor + # Split the tile_M to each block + + chunk_M = T.ceildiv(tile_M, PE_num) + chunk_start = bx * chunk_M + chunk_size = T.min(chunk_M, tile_M - chunk_start) + + # if src_rank[0] == rank[0]: + if dst_rank[0] != rank[0]: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + next_rank[0], + ) + if tx % 32 == 0: + T.atom_add_remote( + signal_transfer[rank[0], dst_rank[0], bz, warp_idx], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + # T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), + chunk_size * N, + -1, + ) + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + T.sync_warp() + # T.fence_cta(sem=T.MemorySemantic.RELEASE) + + T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + + # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer + if tx % 32 == 0: + T.wait_ge( + signal_transfer[bx, dst_rank[0], bz, warp_idx], PE_num, scope=T.MemoryScope.SYSTEM) + T.sync_warp() + + if signal_transfer[bx, dst_rank[0], bz, warp_idx] == PE_num and 
flag[0] == False: + flag[0] = True + # Handle the transfer signal + if dst_rank[0] != rank[0]: + T.put_warp( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + dst_pe=next_rank[0], + ) + if tx % 32 == 0: + T.st( + signal_transfer[bx, dst_rank[0], bz, warp_idx], + PE_num, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "transfer rank") + T.sync_warp() + else: + # Current rank is the real destination of this chunk of data, the real source rank is the buffer index + T.put_warp( + T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), + T.address_of(dst[bx * M + bz * block_M + warp_idx * tile_M, 0]), + tile_M * N, + -1, + ) + T.sync_warp() + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "dst rank") + if old_local[0] + 1 == (PE_num - 1) * num_tiles[0] + PE_num * num_tiles[0]: + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=i, + ) + if old_global[0] + 1 == PE_num: + # Send termination signals to wake up all waiting blocks on all PEs + for remote_pe in T.serial(PE_num): + for src_rank_idx in T.serial(PE_num): + for dst_rank_idx in T.serial(PE_num): + for bz_idx in T.serial(num_blocks): + for dst_tile in T.serial(num_warps): + T.st( + signal_transfer[src_rank_idx, dst_rank_idx, + bz_idx, dst_tile], + PE_num + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) + T.sync_warp() + + T.sync_threads() + + T.fence_sys(sem=T.MemorySemantic.RELEASE) + + return main_route + + +def run_torus_alltoall(local_rank, num_ranks, args): + NUM_SM = 148 + 
PE_num = args.PE_num + X, Y = args.X, args.Y + M, N = args.M, args.N + block_M, block_N = M // 2, N + threads = 256 + + num_blocks = M // block_M + num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) + num_warps = threads // 32 + + local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) + allocator = tilelang.get_allocator( + size=2**35, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_ranks, + group=group_size, + ) + + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads) + kernel.initialize(allocator=allocator) + + src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() + dst = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).zero_() + + dst_ref = torch.zeros((PE_num * M, N), dtype=torch.float16, device="cuda") + dist.all_to_all_single(dst_ref, src, group=group_size) + torch.cuda.synchronize() + + buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, + allocator=allocator).fill_(0) + + # Signals for each buffer slot in each direction + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), + torch.uint32, + allocator=allocator).fill_(0) + local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) + barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + + torch.cuda.synchronize() + dist.barrier(group_size) + + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + + torch.cuda.synchronize() + dist.barrier(group_size) + + print(f"Rank {local_rank} TileLang AllToAll XY Routing Finished.") + + if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): + print(f"Rank {local_rank} Verification Passed! ✅") + else: + max_diff = (dst - dst_ref).abs().max() + print(f"Rank {local_rank} Verification Failed! 
❌ Max diff: {max_diff}") + # Find differences + diff_mask = (dst != dst_ref) + diff_count = diff_mask.sum().item() + + if diff_count > 0: + # Check each rank's symmetric buffer at every M boundary + print(f"Rank {local_rank} found {diff_count} differences") + print(f"Rank {local_rank} checking symmetric buffer positions at M={M} boundaries:") + for rank_idx in range(PE_num): + start_idx = rank_idx * M + # Check first element of each rank's buffer + print( + f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={buffer_transfer[rank_idx, local_rank, 0, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" + ) + + if args.benchmark: + # Warmup + for _ in range(args.warmup): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + torch.cuda.synchronize() + dist.barrier(group_size) + + # Reinitialize + buffer_transfer.zero_() + signal_transfer.zero_() + local_finish.zero_() + global_finish.zero_() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + num_iters = args.iter + start_event.record() + for _ in range(num_iters): + # torch.cuda.profiler.start() + # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): + kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + # torch.cuda.profiler.stop() + torch.cuda.synchronize() + dist.barrier(group_size) + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) / num_iters + # All-to-all total data moved: each rank sends (PE_num - 1) * M * N elements + # and receives (PE_num - 1) * M * N elements. + # For bandwidth calculation, we usually use the amount of data sent per rank. 
+ total_data_bytes = PE_num * M * N * 2 # float16 = 2 bytes + bandwidth_gbps = (total_data_bytes / 1e9) / (elapsed_time_ms / 1e3) + print(f"Rank {local_rank} Average Latency: {elapsed_time_ms:.4f} ms, Effective Bandwidth: {bandwidth_gbps:.4f} GB/s") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=8192) + parser.add_argument("--N", type=int, default=7168) + parser.add_argument("--PE_num", type=int, default=8) + parser.add_argument("--X", type=int, default=2) + parser.add_argument("--Y", type=int, default=4) + parser.add_argument("--benchmark", action="store_true", help="Run benchmark") + parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") + parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") + args = parser.parse_args() + + torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) diff --git a/src/op/sync.cc b/src/op/sync.cc index 12b13f9ab..f8d585476 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -145,6 +145,7 @@ WaitOp::WaitOp(Array args, BufferMap vmap) { node->scope = (args.size() > 4) ? args[4].as()->value : 3; // semantic parameter is optional, default to ACQUIRE (3) for safety node->semantic = (args.size() > 5) ? args[5].as()->value : 3; + node->dtype = (args.size() > 6) ? args[6].as()->value : "int32"; data_ = std::move(node); (void)vmap; } @@ -183,8 +184,8 @@ Stmt WaitOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { // Pass semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, 4=RELEASE, // 5=ACQ_REL new_args.push_back(IntImm(DataType::Int(32), semantic)); - - auto wait = Call(DataType::Handle(), builtin::call_extern(), new_args); + auto datatype = dtype == "int32" ? 
DataType::Int(32) : DataType::UInt(32); + auto wait = Call(datatype, builtin::call_extern(), new_args); return Evaluate(wait); } diff --git a/src/op/sync.h b/src/op/sync.h index 93656cf30..a4c78b0ef 100644 --- a/src/op/sync.h +++ b/src/op/sync.h @@ -53,6 +53,7 @@ class WaitOpNode : public TileOperatorNode { int scope; ///< Memory scope: 0=CTA, 1=CLUSTER, 2=GPU, 3=SYSTEM int semantic; ///< Memory semantic: 0=WEAK, 1=VOLATILE, 2=RELAXED, 3=ACQUIRE, ///< 4=RELEASE, 5=ACQ_REL + std::string dtype; ///< The data type of the memory address, must be int32 or uint32 bool is_distributed() const; @@ -73,13 +74,15 @@ class WaitOpNode : public TileOperatorNode { .def_ro("peer", &WaitOpNode::peer) .def_ro("relation", &WaitOpNode::relation) .def_ro("scope", &WaitOpNode::scope) - .def_ro("semantic", &WaitOpNode::semantic); + .def_ro("semantic", &WaitOpNode::semantic) + .def_ro("dtype", &WaitOpNode::dtype); } bool SEqualReduce(const WaitOpNode *other, SEqualReducer equal) const { return equal(addr, other->addr) && equal(expected, other->expected) && equal(peer, other->peer) && relation == other->relation && - scope == other->scope && semantic == other->semantic; + scope == other->scope && semantic == other->semantic && + dtype == other->dtype; } void SHashReduce(SHashReducer hash_reduce) const { @@ -89,6 +92,7 @@ class WaitOpNode : public TileOperatorNode { hash_reduce(relation); hash_reduce(scope); hash_reduce(semantic); + hash_reduce(dtype); } static constexpr bool _type_has_method_sequal_reduce = true; diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 333048b4b..3a500f813 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1477,6 +1477,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); int num_mma = Downcast(op->args[0])->value; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; + } else if (op->op.same_as(Op::Get("tl.wait"))) { + // relation, addr, 
expected, peer, scope, semantic + ICHECK_GE(op->args.size(), 4); + const char *relation_str[] = {"eq", "ne", "ge", "le", "gt", "lt"}; + int relation = Downcast(op->args[0])->value; + std::string addr = this->PrintExpr(op->args[1]); + std::string expected = this->PrintExpr(op->args[2]); + // peer is args[3], but wait_ne in sync.h doesn't take peer if it's already a pointer + // The Lower() in sync.cc handles distributed by converting to remote pointer + int scope = (op->args.size() > 4) ? Downcast(op->args[4])->value : 3; + int semantic = (op->args.size() > 5) ? Downcast(op->args[5])->value : 3; + + os << "tl::wait_" << relation_str[relation] << "(" << addr << ", " << expected + << ", " << scope << ", " << semantic << ");\n"; } else if (op->op.same_as(tl::pack_b16())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index 0a6801b5b..26e5ecc7f 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -323,81 +323,93 @@ TL_DEVICE T ld_wait_generic(const T *ptr, WaitScope scope, } template -TL_DEVICE void wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_eq(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) != + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) != val) ; + return ret; } template -TL_DEVICE void wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_ne(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) == + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) == val) ; + return ret; } template -TL_DEVICE void wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_ge(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) < + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) < val) ; + return ret; } template -TL_DEVICE void wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_le(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) > + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) > val) ; + return ret; } template -TL_DEVICE void wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_gt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) <= + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) <= val) ; + return ret; } template -TL_DEVICE void wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, +TL_DEVICE T wait_lt(P ptr, T val, int scope = (int)WaitScope::SYSTEM, int semantic = (int)WaitSemantic::ACQUIRE) { static_assert(std::is_same_v || std::is_pointer_v
<P>
, "P must be a pointer or uint64_t"); T *flag_ptr = reinterpret_cast(ptr); + T ret; // Spin-loop #pragma unroll 1 - while (ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic) >= + while ((ret = ld_wait_generic(flag_ptr, (WaitScope)scope, (WaitSemantic)semantic)) >= val) ; + return ret; } } // namespace tl diff --git a/tilelang/language/distributed/common.py b/tilelang/language/distributed/common.py index fcd773840..b311bfcde 100644 --- a/tilelang/language/distributed/common.py +++ b/tilelang/language/distributed/common.py @@ -127,7 +127,8 @@ def wait_eq(barrier: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *barrier == expected* for GPU-level synchronization. # todo: have different semantic compared to 3 fns below currently Args: @@ -136,8 +137,9 @@ def wait_eq(barrier: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.EQ.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.EQ.value, address_of(barrier), expected, peer, scope.value, semantic.value) @@ -145,7 +147,8 @@ def wait_ne(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *ptr != expected Args: @@ -154,8 +157,9 @@ def wait_ne(ptr: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, 
VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.NE.value, address_of(ptr), expected, peer, scope.value, semantic.value) @@ -163,7 +167,8 @@ def wait_ge(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *ptr >= expected Args: @@ -172,8 +177,10 @@ def wait_ge(ptr: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.GE.value, address_of(ptr), expected, peer, scope.value, semantic.value) @@ -181,7 +188,8 @@ def wait_le(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *ptr <= expected Args: @@ -190,8 +198,9 @@ def wait_le(ptr: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LE.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), 
BinaryRelation.LE.value, address_of(ptr), expected, peer, scope.value, semantic.value) @@ -199,7 +208,8 @@ def wait_gt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *ptr > expected Args: @@ -208,8 +218,9 @@ def wait_gt(ptr: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.GT.value, address_of(ptr), expected, peer, scope.value, semantic.value) @@ -217,7 +228,8 @@ def wait_lt(ptr: PrimExpr, expected: PrimExpr, peer: PrimExpr | None = -1, scope: MemoryScope = MemoryScope.SYSTEM, - semantic: MemorySemantic = MemorySemantic.ACQUIRE): + semantic: MemorySemantic = MemorySemantic.ACQUIRE, + dtype = "int32"): """Wait until *ptr < expected Args: @@ -226,6 +238,7 @@ def wait_lt(ptr: PrimExpr, peer: The PE to wait on (-1 for local) scope: Memory scope (GPU=volatile, SYSTEM=acquire.sys for cross-PE sync) semantic: Memory semantic (WEAK, VOLATILE, RELAXED, ACQUIRE, RELEASE, ACQ_REL) + dtype: The data type of the memory address, must be int32 or uint32 """ - return tir.call_intrin("handle", tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, + return tir.call_intrin(dtype, tir.op.Op.get("tl.wait"), BinaryRelation.LT.value, address_of(ptr), expected, peer, scope.value, semantic.value) From 3845d36316e051f71eb06596a44f3d1fbf5246e5 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 10 Feb 2026 06:16:56 +0000 Subject: [PATCH 23/28] [Routing] New version of routing --- .../example_alltoall_route2x4_opt.py | 303 
+++++++++++------- 1 file changed, 194 insertions(+), 109 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index 8222a9921..6447410a2 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -25,21 +25,23 @@ class Direction(IntEnum): debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): - num_warps = threads // 32 block_M = M // num_blocks - # tile_M = block_M // num_warps # Number of slots num_slots = PE_num @T.prim_func - def main_route( + def main_route_opt( # For each (src, dst) pair, the real transfer size is M * N src: T.Tensor((PE_num * M, N), "float16"), dst: T.Tensor((PE_num * M, N), "float16"), # buffer[dst_rank, slots, *, *]: This PE save slots for transferring data chunks for the real destination rank buffer_transfer: T.Tensor((PE_num, num_slots, M, N), "float16"), + # slot_counter[dst_rank, num_blocks]: Counter for each slot in each block + slot_counter: T.Tensor((PE_num, num_blocks), "uint32"), # Signal for each dst buffer - signal_transfer: T.Tensor((PE_num, num_blocks, num_warps), "uint32"), + signal_transfer: T.Tensor((PE_num, num_blocks), "int32"), + # Src idx during transfer + src_transfer: T.Tensor((PE_num, num_slots, num_blocks), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), @@ -48,7 +50,6 @@ def main_route( ): with T.Kernel(PE_num, num_blocks, threads=threads) as (bx, bz): tx = T.get_thread_binding() - warp_idx = tx // 32 rank = T.alloc_local([1], "uint32") rank_x = T.alloc_local([1], "uint32") @@ -57,9 +58,14 @@ def main_route( to_dir = T.alloc_local([1], "uint32") diff = T.alloc_local([1], "int32") dst_rank = T.alloc_local([1], "uint32") + old_counter = T.alloc_local([1], "uint32") + 
old_counter_shared = T.alloc_shared([PE_num, num_blocks], "uint32") + old_signal = T.alloc_local([1], "int32") + new_signal = T.alloc_local([1], "int32") + cur_counter = T.alloc_local([1], "uint32") + cur_counter_shared = T.alloc_shared([PE_num, num_blocks], "uint32") old_local = T.alloc_local([1], "uint32") old_global = T.alloc_local([1], "uint32") - num_tiles = T.alloc_local([1], "uint32") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) @@ -78,7 +84,7 @@ def main_route( diff[0] = dst_rank_x - rank_x[0] if diff[0] > T.floordiv(X, 2): diff[0] -= X - elif diff[0] <= -T.floordiv(X, 2): + elif diff[0] <= -T.ceildiv(X, 2): diff[0] += X if diff[0] < 0: @@ -94,7 +100,7 @@ def main_route( diff[0] = dst_rank_y - rank_y[0] if diff[0] > T.floordiv(Y, 2): diff[0] -= Y - elif diff[0] <= -T.floordiv(Y, 2): + elif diff[0] <= -T.ceildiv(Y, 2): diff[0] += Y if diff[0] < 0: @@ -109,124 +115,182 @@ def main_route( # Phase 1: Fully use all blocks to initially send from src to the target neighbor # Split the tile_M to each block - chunk_M = T.ceildiv(tile_M, PE_num) - chunk_start = bx * chunk_M - chunk_size = T.min(chunk_M, tile_M - chunk_start) - - # if src_rank[0] == rank[0]: if dst_rank[0] != rank[0]: - T.put_warp( - T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), - T.address_of(buffer_transfer[rank[0], dst_rank[0], bz * block_M + warp_idx * tile_M + chunk_start, 0]), - chunk_size * N, + if tx == 0: + old_counter[0] = T.atom_add_remote( + slot_counter[dst_rank[0], bz], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + old_counter_shared[dst_rank[0], bz] = old_counter[0] + T.sync_threads() + T.put_block( + T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz], bz * block_M, 0]), + block_M * N, next_rank[0], ) - if tx % 32 == 0: + if tx == 0: + # Write src idx to src_transfer + T.st( + src_transfer[dst_rank[0], 
old_counter[0], bz], + rank[0], + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_threads() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + if tx == 0: + # Write signal always after remote data is ready T.atom_add_remote( - signal_transfer[rank[0], dst_rank[0], bz, warp_idx], + signal_transfer[dst_rank[0], bz], 1, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - T.sync_warp() - # T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + # if tx == 0: + # T.print(rank[0], "rank") + # T.print(next_rank[0], "next rank") + # T.print(dst_rank[0], "dst rank") + # T.print(old_signal[0], "old signal") + # T.print(old_signal_shared[dst_rank[0], bz], "old signal shared") + # T.print(signal_transfer[dst_rank[0], bz], "signal transfer") else: - T.put_warp( - T.address_of(src[dst_rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), - T.address_of(dst[rank[0] * M + bz * block_M + warp_idx * tile_M + chunk_start, 0]), - chunk_size * N, + T.put_block( + T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), + T.address_of(dst[rank[0] * M + bz * block_M, 0]), + block_M * N, -1, ) - if tx % 32 == 0: + if tx == 0: old_local[0] = T.atom_add( local_finish[0], 1, scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - T.sync_warp() - # T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + T.fence_cta(sem=T.MemorySemantic.RELEASE) T.fence_sys(sem=T.MemorySemantic.RELEASE) - T.sync_threads() - - # Phase 2: Each block handles one final dst data in one direction buffer of current rank and check whether to transfer - if tx % 32 == 0: - T.wait_ge( - signal_transfer[bx, dst_rank[0], bz, warp_idx], PE_num, scope=T.MemoryScope.SYSTEM) - T.sync_warp() - - if signal_transfer[bx, dst_rank[0], bz, warp_idx] == PE_num and flag[0] == False: - flag[0] = True - # Handle the transfer signal - if dst_rank[0] != rank[0]: - T.put_warp( - 
T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), - T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), - tile_M * N, - dst_pe=next_rank[0], - ) - if tx % 32 == 0: - T.st( - signal_transfer[bx, dst_rank[0], bz, warp_idx], - PE_num, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, + + # Phase 2: Each block handles one final dst data in each slot of buffer of current rank and check whether to transfer + old_signal[0] = 0 + with T.While(signal_transfer[dst_rank[0], bz] >= old_signal[0]): + new_signal[0] = T.wait_ne(signal_transfer[dst_rank[0], bz], old_signal[0], scope=T.MemoryScope.GPU) + # if tx == 0: + # T.print(rank[0], "rank") + # T.print(new_signal[0], "new signal") + + # Termination signal is -1 + if new_signal[0] < old_signal[0]: + T.loop_break() + # We send all intermediate buffer according to the signal + for slot_idx in T.serial(old_signal[0], new_signal[0]): + src_idx = src_transfer[dst_rank[0], slot_idx, bz] + # Handle the transfer signal + if dst_rank[0] != rank[0]: + if tx == 0: + cur_counter[0] = T.atom_add_remote( + slot_counter[dst_rank[0], bz], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + cur_counter_shared[dst_rank[0], bz] = cur_counter[0] + T.sync_threads() + T.put_block( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz], bz * block_M, 0]), + block_M * N, dst_pe=next_rank[0], ) - # if bz == 0 and tx == 0: - # T.print(rank[0], "transfer rank") - T.sync_warp() - else: - # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.put_warp( - T.address_of(buffer_transfer[bx, dst_rank[0], bz * block_M + warp_idx * tile_M, 0]), - T.address_of(dst[bx * M + bz * block_M + warp_idx * tile_M, 0]), - tile_M * N, - -1, - ) - T.sync_warp() - if tx % 32 == 0: - old_local[0] = 
T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, + if tx == 0: + # Write src idx to src_transfer + T.st( + src_transfer[dst_rank[0], cur_counter[0], bz], + src_idx, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_threads() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + if tx == 0: + # Write signal always after remote data is ready + T.atom_add_remote( + signal_transfer[dst_rank[0], bz], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_threads() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + # if tx == 0: + # T.print(rank[0], "rank") + # T.print(slot_idx, "slot idx") + # T.print(next_rank[0], "transfer to rank") + else: + # Current rank is the real destination of this chunk of data, the real source rank is the buffer index + # if tx == 0: + # T.print(rank[0], "rank") + # T.print(src_transfer[dst_rank[0], slot_idx, bz], "src transfer") + # T.print(slot_idx, "slot idx") + T.put_block( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), + T.address_of(dst[src_transfer[dst_rank[0], slot_idx, bz] * M + bz * block_M, 0]), + block_M * N, + -1, ) - # if bz == 0 and tx == 0: - # T.print(rank[0], "dst rank") - if old_local[0] + 1 == (PE_num - 1) * num_tiles[0] + PE_num * num_tiles[0]: - for i in T.serial(PE_num): - old_global[0] = T.atom_add_remote( - global_finish[0], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=i, - ) - if old_global[0] + 1 == PE_num: - # Send termination signals to wake up all waiting blocks on all PEs - for remote_pe in T.serial(PE_num): - for src_rank_idx in T.serial(PE_num): + T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.sync_threads() + if tx == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + # if bz == 0 and tx == 0: + # T.print(rank[0], "dst rank") + if old_local[0] + 1 == PE_num * 
num_blocks: + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=i, + ) + if old_global[0] + 1 == PE_num: + # Send termination signals to wake up all waiting blocks on all PEs + for remote_pe in T.serial(PE_num): for dst_rank_idx in T.serial(PE_num): for bz_idx in T.serial(num_blocks): - for dst_tile in T.serial(num_warps): - T.st( - signal_transfer[src_rank_idx, dst_rank_idx, - bz_idx, dst_tile], - PE_num + 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) - T.sync_warp() - - T.sync_threads() + T.st( + signal_transfer[dst_rank_idx, bz_idx], + -1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) + T.sync_threads() + + # Update old signal for next iteration + old_signal[0] = new_signal[0] + T.sync_threads() T.fence_sys(sem=T.MemorySemantic.RELEASE) - return main_route + return main_route_opt def run_torus_alltoall(local_rank, num_ranks, args): @@ -234,16 +298,15 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M // 2, N - threads = 256 + block_M, block_N = M, N + threads = 128 num_blocks = M // block_M num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) - num_warps = threads // 32 local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( - size=2**35, + size=2**34, device="cuda", is_distributed=True, local_rank=local_rank, @@ -262,12 +325,18 @@ def run_torus_alltoall(local_rank, num_ranks, args): torch.cuda.synchronize() buffer_transfer = tilelang.tensor((PE_num, PE_num, M, N), torch.float16, - allocator=allocator).fill_(0) + allocator=allocator).fill_(-1) # Signals for each buffer slot in each direction - signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), - torch.uint32, + slot_counter = tilelang.tensor((PE_num, 
num_blocks), torch.uint32, allocator=allocator).fill_(0) + signal_transfer = tilelang.tensor((PE_num, num_blocks), + torch.int32, allocator=allocator).fill_(0) + + src_transfer = tilelang.tensor((PE_num, PE_num, num_blocks), + torch.uint32, + allocator=allocator).fill_(PE_num) + local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() @@ -275,7 +344,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): torch.cuda.synchronize() dist.barrier(group_size) - kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) torch.cuda.synchronize() dist.barrier(group_size) @@ -299,18 +368,34 @@ def run_torus_alltoall(local_rank, num_ranks, args): start_idx = rank_idx * M # Check first element of each rank's buffer print( - f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={buffer_transfer[rank_idx, local_rank, 0, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" + f"Rank {local_rank} Buffer[{rank_idx}][0,0]: dst={dst[start_idx, 0].item():.6f}, dst_ref={dst_ref[start_idx, 0].item():.6f}" + ) + for slot in range(PE_num): + print( + f"Rank {local_rank} slot={slot}: buffer={buffer_transfer[local_rank, slot, 0, 0].item():.6f}, src_transfer={src_transfer[local_rank, slot, 0].item():.6f}" + ) + # Print first 10 differences with their coordinates + print(f"Rank {local_rank} first 10 differences:") + diff_indices = torch.nonzero(diff_mask, as_tuple=False) + for i in range(min(10, diff_indices.shape[0])): + idx = diff_indices[i] + row, col = idx[0].item(), idx[1].item() + print( + f"Rank {local_rank} Diff[{i}] at ({row}, {col}): dst={dst[row, col].item():.6f}, dst_ref={dst_ref[row, col].item():.6f}" ) + # print(f"Rank {local_rank} slot_counter: 
{slot_counter}, signal_transfer: {signal_transfer}, src_transfer: {src_transfer}") if args.benchmark: # Warmup for _ in range(args.warmup): - kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, local_finish, global_finish, barrier) torch.cuda.synchronize() dist.barrier(group_size) # Reinitialize buffer_transfer.zero_() + src_transfer.fill_(PE_num) + slot_counter.zero_() signal_transfer.zero_() local_finish.zero_() global_finish.zero_() @@ -323,7 +408,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): for _ in range(num_iters): # torch.cuda.profiler.start() # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): - kernel(src, dst, buffer_transfer, signal_transfer, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, local_finish, global_finish, barrier) # torch.cuda.profiler.stop() torch.cuda.synchronize() dist.barrier(group_size) From 6f570f443fafdaea75d55d61c628aadea840af34 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 10 Feb 2026 06:48:17 +0000 Subject: [PATCH 24/28] [BugFix] Transfer source index before put data --- .../example_alltoall_route2x4_opt.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index 6447410a2..b6b939bdc 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -124,15 +124,6 @@ def main_route_opt( sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - old_counter_shared[dst_rank[0], bz] = old_counter[0] - T.sync_threads() - T.put_block( - T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), - T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz], bz * block_M, 0]), - block_M * N, - 
next_rank[0], - ) - if tx == 0: # Write src idx to src_transfer T.st( src_transfer[dst_rank[0], old_counter[0], bz], @@ -141,7 +132,14 @@ def main_route_opt( sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) + old_counter_shared[dst_rank[0], bz] = old_counter[0] T.sync_threads() + T.put_block( + T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz], bz * block_M, 0]), + block_M * N, + next_rank[0], + ) T.fence_sys(sem=T.MemorySemantic.RELEASE) if tx == 0: # Write signal always after remote data is ready @@ -204,15 +202,6 @@ def main_route_opt( sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - cur_counter_shared[dst_rank[0], bz] = cur_counter[0] - T.sync_threads() - T.put_block( - T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), - T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz], bz * block_M, 0]), - block_M * N, - dst_pe=next_rank[0], - ) - if tx == 0: # Write src idx to src_transfer T.st( src_transfer[dst_rank[0], cur_counter[0], bz], @@ -221,7 +210,14 @@ def main_route_opt( sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) + cur_counter_shared[dst_rank[0], bz] = cur_counter[0] T.sync_threads() + T.put_block( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz], bz * block_M, 0]), + block_M * N, + dst_pe=next_rank[0], + ) T.fence_sys(sem=T.MemorySemantic.RELEASE) if tx == 0: # Write signal always after remote data is ready From 5abf3f1ffb689537581fad5db44ea3ad82a64586 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 10 Feb 2026 08:24:09 +0000 Subject: [PATCH 25/28] [BugFix] Interface for benchmark --- .../intranode/example_alltoall_route2x4.py | 4 +-- .../example_alltoall_route2x4_opt.py | 25 +++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git 
a/examples/distributed/intranode/example_alltoall_route2x4.py b/examples/distributed/intranode/example_alltoall_route2x4.py index a16a8d2de..0c1d73aa2 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4.py +++ b/examples/distributed/intranode/example_alltoall_route2x4.py @@ -239,8 +239,8 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M // 2, N - threads = 256 + block_M, block_N = M // 4, N + threads = 128 num_blocks = M // block_M num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index b6b939bdc..df159f4d6 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -294,11 +294,12 @@ def run_torus_alltoall(local_rank, num_ranks, args): PE_num = args.PE_num X, Y = args.X, args.Y M, N = args.M, args.N - block_M, block_N = M, N - threads = 128 + blocks = args.blocks + block_M, block_N = M // blocks, N + threads = 512 - num_blocks = M // block_M - num_blocks = min(num_blocks, NUM_SM // (PE_num * PE_num)) + num_blocks = blocks + num_blocks = min(num_blocks, NUM_SM // PE_num) local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( @@ -384,12 +385,19 @@ def run_torus_alltoall(local_rank, num_ranks, args): if args.benchmark: # Warmup for _ in range(args.warmup): - kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, local_finish, global_finish, barrier) - torch.cuda.synchronize() + # Reinitialize + buffer_transfer.fill_(-1) + src_transfer.fill_(PE_num) + slot_counter.zero_() + signal_transfer.zero_() + local_finish.zero_() + global_finish.zero_() + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) + 
torch.cuda.synchronize() dist.barrier(group_size) # Reinitialize - buffer_transfer.zero_() + buffer_transfer.fill_(-1) src_transfer.fill_(PE_num) slot_counter.zero_() signal_transfer.zero_() @@ -404,7 +412,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): for _ in range(num_iters): # torch.cuda.profiler.start() # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): - kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) # torch.cuda.profiler.stop() torch.cuda.synchronize() dist.barrier(group_size) @@ -432,6 +440,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): parser.add_argument("--benchmark", action="store_true", help="Run benchmark") parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations") parser.add_argument("--iter", type=int, default=10, help="Number of benchmark iterations") + parser.add_argument("--blocks", type=int, default=1, help="Number of blocks") args = parser.parse_args() torch.multiprocessing.spawn(run_torus_alltoall, args=(args.PE_num, args), nprocs=args.PE_num) From a9dae4c57a3eddcfa557c043f924091b38fc94cb Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Tue, 10 Feb 2026 09:30:33 +0000 Subject: [PATCH 26/28] [Misc] Remove log --- .../example_alltoall_route2x4_opt.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index df159f4d6..06cd32f49 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -152,13 +152,6 @@ def main_route_opt( ) T.sync_threads() T.fence_sys(sem=T.MemorySemantic.RELEASE) - # if tx == 0: - # T.print(rank[0], "rank") - # T.print(next_rank[0], "next rank") - # 
T.print(dst_rank[0], "dst rank") - # T.print(old_signal[0], "old signal") - # T.print(old_signal_shared[dst_rank[0], bz], "old signal shared") - # T.print(signal_transfer[dst_rank[0], bz], "signal transfer") else: T.put_block( T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), @@ -182,9 +175,6 @@ def main_route_opt( old_signal[0] = 0 with T.While(signal_transfer[dst_rank[0], bz] >= old_signal[0]): new_signal[0] = T.wait_ne(signal_transfer[dst_rank[0], bz], old_signal[0], scope=T.MemoryScope.GPU) - # if tx == 0: - # T.print(rank[0], "rank") - # T.print(new_signal[0], "new signal") # Termination signal is -1 if new_signal[0] < old_signal[0]: @@ -230,16 +220,8 @@ def main_route_opt( ) T.sync_threads() T.fence_sys(sem=T.MemorySemantic.RELEASE) - # if tx == 0: - # T.print(rank[0], "rank") - # T.print(slot_idx, "slot idx") - # T.print(next_rank[0], "transfer to rank") else: # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - # if tx == 0: - # T.print(rank[0], "rank") - # T.print(src_transfer[dst_rank[0], slot_idx, bz], "src transfer") - # T.print(slot_idx, "slot idx") T.put_block( T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), T.address_of(dst[src_transfer[dst_rank[0], slot_idx, bz] * M + bz * block_M, 0]), @@ -255,8 +237,6 @@ def main_route_opt( scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - # if bz == 0 and tx == 0: - # T.print(rank[0], "dst rank") if old_local[0] + 1 == PE_num * num_blocks: for i in T.serial(PE_num): old_global[0] = T.atom_add_remote( @@ -295,7 +275,6 @@ def run_torus_alltoall(local_rank, num_ranks, args): X, Y = args.X, args.Y M, N = args.M, args.N blocks = args.blocks - block_M, block_N = M // blocks, N threads = 512 num_blocks = blocks From 0d1dc1cc0c1db2fd2edab6cb0802e818f2569026 Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Sat, 14 Feb 2026 01:53:29 +0000 Subject: [PATCH 27/28] [BugFix] Warp level communication with robust per-slot signal --- 
.../example_alltoall_route2x4_opt.py | 282 ++++++++++-------- 1 file changed, 154 insertions(+), 128 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index 06cd32f49..7fa1305b8 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -6,7 +6,7 @@ import argparse from enum import IntEnum -tilelang.disable_cache() +# tilelang.disable_cache() class Direction(IntEnum): @@ -22,10 +22,11 @@ class Direction(IntEnum): tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, - debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" + # debug_root_path="/home/zhengju.tang/tilescale/examples/distributed/debug/" ) -def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads): +def torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, num_warps, threads): block_M = M // num_blocks + tile_M = block_M // num_warps # Number of slots num_slots = PE_num @@ -36,12 +37,13 @@ def main_route_opt( dst: T.Tensor((PE_num * M, N), "float16"), # buffer[dst_rank, slots, *, *]: This PE save slots for transferring data chunks for the real destination rank buffer_transfer: T.Tensor((PE_num, num_slots, M, N), "float16"), - # slot_counter[dst_rank, num_blocks]: Counter for each slot in each block - slot_counter: T.Tensor((PE_num, num_blocks), "uint32"), - # Signal for each dst buffer - signal_transfer: T.Tensor((PE_num, num_blocks), "int32"), + # slot_counter[dst_rank, num_blocks, num_warps]: Counter for allocating slot indices + slot_counter: T.Tensor((PE_num, num_blocks, num_warps), "uint32"), + # Per-slot ready flag: 0=not ready, 1=data ready, -1=termination (no more slots) + # signal_transfer[dst_rank, slot, num_blocks, num_warps] + signal_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "int32"), # Src idx 
during transfer - src_transfer: T.Tensor((PE_num, num_slots, num_blocks), "uint32"), + src_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "uint32"), # Signal for finish local_finish: T.Tensor((1), "uint32"), global_finish: T.Tensor((1), "uint32"), @@ -50,6 +52,7 @@ def main_route_opt( ): with T.Kernel(PE_num, num_blocks, threads=threads) as (bx, bz): tx = T.get_thread_binding() + warp_id = tx // 32 rank = T.alloc_local([1], "uint32") rank_x = T.alloc_local([1], "uint32") @@ -59,13 +62,13 @@ def main_route_opt( diff = T.alloc_local([1], "int32") dst_rank = T.alloc_local([1], "uint32") old_counter = T.alloc_local([1], "uint32") - old_counter_shared = T.alloc_shared([PE_num, num_blocks], "uint32") - old_signal = T.alloc_local([1], "int32") - new_signal = T.alloc_local([1], "int32") + old_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") cur_counter = T.alloc_local([1], "uint32") - cur_counter_shared = T.alloc_shared([PE_num, num_blocks], "uint32") + cur_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") old_local = T.alloc_local([1], "uint32") old_global = T.alloc_local([1], "uint32") + slot_flag = T.alloc_local([1], "int32") + slot_flag_shared = T.alloc_shared([num_warps], "int32") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) @@ -115,10 +118,11 @@ def main_route_opt( # Phase 1: Fully use all blocks to initially send from src to the target neighbor # Split the tile_M to each block + chunk_M = T.min(block_M - warp_id * tile_M, tile_M) if dst_rank[0] != rank[0]: - if tx == 0: + if tx % 32 == 0: old_counter[0] = T.atom_add_remote( - slot_counter[dst_rank[0], bz], + slot_counter[dst_rank[0], bz, warp_id], 1, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, @@ -126,145 +130,165 @@ def main_route_opt( ) # Write src idx to src_transfer T.st( - src_transfer[dst_rank[0], old_counter[0], bz], + src_transfer[dst_rank[0], old_counter[0], bz, warp_id], rank[0], scope=T.MemoryScope.SYSTEM, 
sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - old_counter_shared[dst_rank[0], bz] = old_counter[0] - T.sync_threads() - T.put_block( - T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), - T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz], bz * block_M, 0]), - block_M * N, + old_counter_shared[dst_rank[0], bz, warp_id] = old_counter[0] + T.sync_warp() + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz, warp_id], bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, next_rank[0], ) T.fence_sys(sem=T.MemorySemantic.RELEASE) - if tx == 0: - # Write signal always after remote data is ready - T.atom_add_remote( - signal_transfer[dst_rank[0], bz], + # Ensure ALL lanes' put_warp writes + fence are complete before lane 0 sends signal + T.sync_warp() + if tx % 32 == 0: + # Set per-slot ready flag: this specific slot is now ready + T.st( + signal_transfer[dst_rank[0], old_counter_shared[dst_rank[0], bz, warp_id], bz, warp_id], 1, scope=T.MemoryScope.SYSTEM, sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - T.sync_threads() + T.sync_warp() T.fence_sys(sem=T.MemorySemantic.RELEASE) else: - T.put_block( - T.address_of(src[dst_rank[0] * M + bz * block_M, 0]), - T.address_of(dst[rank[0] * M + bz * block_M, 0]), - block_M * N, + T.put_warp( + T.address_of(src[dst_rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + T.address_of(dst[rank[0] * M + bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, -1, ) - if tx == 0: + if tx % 32 == 0: old_local[0] = T.atom_add( local_finish[0], 1, scope=T.MemoryScope.GPU, sem=T.MemorySemantic.RELEASE, ) - T.sync_threads() + T.sync_warp() T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.sync_threads() T.fence_sys(sem=T.MemorySemantic.RELEASE) - # Phase 2: Each block handles one final dst data in each slot of buffer of current rank and check whether to transfer - old_signal[0] = 0 - with 
T.While(signal_transfer[dst_rank[0], bz] >= old_signal[0]): - new_signal[0] = T.wait_ne(signal_transfer[dst_rank[0], bz], old_signal[0], scope=T.MemoryScope.GPU) + # Phase 2: Poll per-slot ready flags sequentially and process each slot + # Each warp checks slot 0, 1, 2, ... for its (dst_rank, bz, warp_id). + # flag=0 means not ready yet (spin), flag=1 means data ready, flag=-1 means no more slots. + for slot_idx in T.serial(num_slots): + # Wait for this slot's flag to become non-zero + if tx % 32 == 0: + slot_flag[0] = T.wait_ne( + signal_transfer[dst_rank[0], slot_idx, bz, warp_id], + 0, + scope=T.MemoryScope.SYSTEM, + ) + slot_flag_shared[warp_id] = slot_flag[0] + T.sync_warp() - # Termination signal is -1 - if new_signal[0] < old_signal[0]: + # Termination: flag == -1 means no more slots to process + if slot_flag_shared[warp_id] < 0: T.loop_break() - # We send all intermediate buffer according to the signal - for slot_idx in T.serial(old_signal[0], new_signal[0]): - src_idx = src_transfer[dst_rank[0], slot_idx, bz] - # Handle the transfer signal - if dst_rank[0] != rank[0]: - if tx == 0: - cur_counter[0] = T.atom_add_remote( - slot_counter[dst_rank[0], bz], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0], - ) - # Write src idx to src_transfer - T.st( - src_transfer[dst_rank[0], cur_counter[0], bz], - src_idx, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0], - ) - cur_counter_shared[dst_rank[0], bz] = cur_counter[0] - T.sync_threads() - T.put_block( - T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), - T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz], bz * block_M, 0]), - block_M * N, + + src_idx = src_transfer[dst_rank[0], slot_idx, bz, warp_id] + # Handle the transfer + if dst_rank[0] != rank[0]: + # Forward to next hop + if tx % 32 == 0: + cur_counter[0] = T.atom_add_remote( + slot_counter[dst_rank[0], bz, warp_id], + 1, + 
scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, dst_pe=next_rank[0], ) - T.fence_sys(sem=T.MemorySemantic.RELEASE) - if tx == 0: - # Write signal always after remote data is ready - T.atom_add_remote( - signal_transfer[dst_rank[0], bz], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=next_rank[0], - ) - T.sync_threads() - T.fence_sys(sem=T.MemorySemantic.RELEASE) - else: - # Current rank is the real destination of this chunk of data, the real source rank is the buffer index - T.put_block( - T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M, 0]), - T.address_of(dst[src_transfer[dst_rank[0], slot_idx, bz] * M + bz * block_M, 0]), - block_M * N, - -1, + # Write src idx to src_transfer on next hop + T.st( + src_transfer[dst_rank[0], cur_counter[0], bz, warp_id], + src_idx, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], ) - T.fence_cta(sem=T.MemorySemantic.RELEASE) - T.sync_threads() - if tx == 0: - old_local[0] = T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, - ) - if old_local[0] + 1 == PE_num * num_blocks: - for i in T.serial(PE_num): - old_global[0] = T.atom_add_remote( - global_finish[0], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=i, - ) - if old_global[0] + 1 == PE_num: - # Send termination signals to wake up all waiting blocks on all PEs - for remote_pe in T.serial(PE_num): - for dst_rank_idx in T.serial(PE_num): - for bz_idx in T.serial(num_blocks): - T.st( - signal_transfer[dst_rank_idx, bz_idx], - -1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) - T.sync_threads() - - # Update old signal for next iteration - old_signal[0] = new_signal[0] - T.sync_threads() - - T.fence_sys(sem=T.MemorySemantic.RELEASE) + cur_counter_shared[dst_rank[0], bz, warp_id] = cur_counter[0] + T.sync_warp() + T.put_warp( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz 
* block_M + warp_id * tile_M, 0]), + T.address_of(buffer_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz, warp_id], bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + dst_pe=next_rank[0], + ) + T.fence_sys(sem=T.MemorySemantic.RELEASE) + T.sync_warp() + if tx % 32 == 0: + # Set per-slot ready flag on next hop + T.st( + signal_transfer[dst_rank[0], cur_counter_shared[dst_rank[0], bz, warp_id], bz, warp_id], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=next_rank[0], + ) + T.sync_warp() + T.fence_sys(sem=T.MemorySemantic.RELEASE) + else: + # Final destination: copy from buffer to dst + T.put_warp( + T.address_of(buffer_transfer[dst_rank[0], slot_idx, bz * block_M + warp_id * tile_M, 0]), + T.address_of(dst[src_transfer[dst_rank[0], slot_idx, bz, warp_id] * M + bz * block_M + warp_id * tile_M, 0]), + chunk_M * N, + -1, + ) + T.fence_cta(sem=T.MemorySemantic.RELEASE) + T.sync_warp() + if tx % 32 == 0: + old_local[0] = T.atom_add( + local_finish[0], + 1, + scope=T.MemoryScope.GPU, + sem=T.MemorySemantic.RELEASE, + ) + if old_local[0] + 1 == PE_num * num_blocks * num_warps: + # All local data received, notify all PEs + for i in T.serial(PE_num): + old_global[0] = T.atom_add_remote( + global_finish[0], + 1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=i, + ) + if old_global[0] + 1 == PE_num: + # Last PE to finish: send termination flags to all waiting slots on all PEs + # For each (dst_rank, bz, warp_id), we need to set the flag of the + # next unused slot to -1 so the receiver knows to stop polling. + for remote_pe in T.serial(PE_num): + for dst_rank_idx in T.serial(PE_num): + for bz_idx in T.serial(num_blocks): + for warp_idx in T.serial(num_warps): + # The next unused slot index is slot_counter[dst_rank, bz, warp] + # We write -1 to signal_transfer[dst_rank, slot_counter_val, bz, warp] + # But reading slot_counter from remote is complex. 
+ # Instead, write -1 to ALL remaining slots (from 0 to num_slots-1). + # Slots already processed (flag=1) won't be re-checked. + # The receiver will hit -1 at the first unwritten slot. + for slot_i in T.serial(num_slots): + T.st( + signal_transfer[dst_rank_idx, slot_i, bz_idx, warp_idx], + -1, + scope=T.MemoryScope.SYSTEM, + sem=T.MemorySemantic.RELEASE, + dst_pe=remote_pe, + ) + T.sync_warp() + + T.sync_warp() return main_route_opt @@ -275,7 +299,9 @@ def run_torus_alltoall(local_rank, num_ranks, args): X, Y = args.X, args.Y M, N = args.M, args.N blocks = args.blocks - threads = 512 + threads = 256 + assert threads % 32 == 0, "threads must be divisible by 32" + num_warps = threads // 32 num_blocks = blocks num_blocks = min(num_blocks, NUM_SM // PE_num) @@ -290,7 +316,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): group=group_size, ) - kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, threads) + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, num_warps, threads) kernel.initialize(allocator=allocator) src = tilelang.tensor((PE_num * M, N), torch.float16, allocator=allocator).random_() @@ -304,12 +330,13 @@ def run_torus_alltoall(local_rank, num_ranks, args): allocator=allocator).fill_(-1) # Signals for each buffer slot in each direction - slot_counter = tilelang.tensor((PE_num, num_blocks), torch.uint32, allocator=allocator).fill_(0) - signal_transfer = tilelang.tensor((PE_num, num_blocks), + slot_counter = tilelang.tensor((PE_num, num_blocks, num_warps), torch.uint32, allocator=allocator).fill_(0) + # Per-slot ready flags: 0=not ready, 1=ready, -1=termination + signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), torch.int32, allocator=allocator).fill_(0) - src_transfer = tilelang.tensor((PE_num, PE_num, num_blocks), + src_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), torch.uint32, allocator=allocator).fill_(PE_num) @@ -348,7 +375,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): ) 
for slot in range(PE_num): print( - f"Rank {local_rank} slot={slot}: buffer={buffer_transfer[local_rank, slot, 0, 0].item():.6f}, src_transfer={src_transfer[local_rank, slot, 0].item():.6f}" + f"Rank {local_rank} slot={slot}: buffer={buffer_transfer[local_rank, slot, 0, 0].item():.6f}, src_transfer={src_transfer[local_rank, slot, 0, 0].item():.6f}" ) # Print first 10 differences with their coordinates print(f"Rank {local_rank} first 10 differences:") @@ -359,7 +386,6 @@ def run_torus_alltoall(local_rank, num_ranks, args): print( f"Rank {local_rank} Diff[{i}] at ({row}, {col}): dst={dst[row, col].item():.6f}, dst_ref={dst_ref[row, col].item():.6f}" ) - # print(f"Rank {local_rank} slot_counter: {slot_counter}, signal_transfer: {signal_transfer}, src_transfer: {src_transfer}") if args.benchmark: # Warmup From baf1fc4aee3ac9fe9068bb01bead58e1e62920bb Mon Sep 17 00:00:00 2001 From: tzj-fxz Date: Sat, 14 Feb 2026 11:40:19 +0000 Subject: [PATCH 28/28] [Routing] AOT routing and signal slot assignment --- .../distributed/intranode/example_alltoall.py | 4 +- .../example_alltoall_route2x4_opt.py | 193 +++++++++++------- 2 files changed, 116 insertions(+), 81 deletions(-) diff --git a/examples/distributed/intranode/example_alltoall.py b/examples/distributed/intranode/example_alltoall.py index d863da83f..b2bba1493 100644 --- a/examples/distributed/intranode/example_alltoall.py +++ b/examples/distributed/intranode/example_alltoall.py @@ -47,7 +47,7 @@ def run_alltoall(local_rank, num_ranks, args): local_rank, num_ranks, group_size = init_dist(local_rank, num_ranks) allocator = tilelang.get_allocator( - size=2**35, + size=2**34, device="cuda", is_distributed=True, local_rank=local_rank, @@ -66,6 +66,7 @@ def run_alltoall(local_rank, num_ranks, args): # Warmup for _ in range(args.warmup): kernel(src, dst, barrier) + dst.zero_() torch.cuda.synchronize() dist.barrier(group_size) @@ -90,7 +91,6 @@ def run_alltoall(local_rank, num_ranks, args): dist.all_to_all_single(dst_ref, src, 
group=group_size) torch.cuda.synchronize() - # 对比结果 if torch.allclose(dst, dst_ref, atol=1e-2, rtol=1e-2): print(f"Rank {local_rank} Verification Passed! ✅") else: diff --git a/examples/distributed/intranode/example_alltoall_route2x4_opt.py b/examples/distributed/intranode/example_alltoall_route2x4_opt.py index 7fa1305b8..e3ce0f471 100644 --- a/examples/distributed/intranode/example_alltoall_route2x4_opt.py +++ b/examples/distributed/intranode/example_alltoall_route2x4_opt.py @@ -17,6 +17,81 @@ class Direction(IntEnum): SELF = 4 +def compute_next_hop(src_rank, dst_rank, X, Y): + """Compute the next hop from src_rank towards dst_rank using XY-routing on a 2D torus.""" + if src_rank == dst_rank: + return src_rank # self, no hop needed + + src_x, src_y = src_rank // Y, src_rank % Y + dst_x, dst_y = dst_rank // Y, dst_rank % Y + + if dst_x != src_x: + # Route along X-axis first + diff = dst_x - src_x + if diff > X // 2: + diff -= X + elif diff <= -X // 2: + diff += X + + if diff < 0: + # North + next_x = (src_x + X - 1) % X + return next_x * Y + src_y + else: + # South + next_x = (src_x + 1) % X + return next_x * Y + src_y + else: + # Route along Y-axis + diff = dst_y - src_y + if diff > Y // 2: + diff -= Y + elif diff <= -Y // 2: + diff += Y + + if diff < 0: + # West + next_y = (src_y + Y - 1) % Y + return src_x * Y + next_y + else: + # East + next_y = (src_y + 1) % Y + return src_x * Y + next_y + + +def compute_expected_slots(PE_num, X, Y): + """ + Pre-compute the number of incoming slots each PE will receive for each dst_rank, + tracing the full XY-routing path through the torus. + + Returns a list of lists of shape (PE_num, PE_num) where expected_slots[me][dst_rank] is + the number of data chunks that PE 'me' will receive in its buffer for destination 'dst_rank'. 
+ + This includes both: + - Final reception: dst_rank == me, each other PE sends one chunk → PE_num - 1 slots + - Forwarding: dst_rank != me, data passes through me on the way to dst_rank + + For each (src, final_dst) pair where src != final_dst, we trace the full multi-hop + XY-routing path. Every intermediate PE and the final destination PE each receive one slot. + """ + # expected_slots[receiver_pe][dst_rank] = count of incoming slots + expected = [[0] * PE_num for _ in range(PE_num)] + + for src in range(PE_num): + for final_dst in range(PE_num): + if src == final_dst: + continue + # Trace the full path from src to final_dst + current = src + while current != final_dst: + next_hop = compute_next_hop(current, final_dst, X, Y) + # next_hop receives one slot for dst_rank=final_dst + expected[next_hop][final_dst] += 1 + current = next_hop + + return expected + + @tilelang.jit( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, @@ -39,16 +114,14 @@ def main_route_opt( buffer_transfer: T.Tensor((PE_num, num_slots, M, N), "float16"), # slot_counter[dst_rank, num_blocks, num_warps]: Counter for allocating slot indices slot_counter: T.Tensor((PE_num, num_blocks, num_warps), "uint32"), - # Per-slot ready flag: 0=not ready, 1=data ready, -1=termination (no more slots) + # Per-slot ready flag: 0=not ready, 1=data ready # signal_transfer[dst_rank, slot, num_blocks, num_warps] signal_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "int32"), # Src idx during transfer src_transfer: T.Tensor((PE_num, num_slots, num_blocks, num_warps), "uint32"), - # Signal for finish - local_finish: T.Tensor((1), "uint32"), - global_finish: T.Tensor((1), "uint32"), - # Barrier for all blocks - barrier: T.Tensor((PE_num), "int32"), + # Pre-computed expected incoming slot counts per (PE, dst_rank) + # expected_slots[pe, dst_rank]: how many slots PE 'pe' will receive for dst_rank + expected_slots: T.Tensor((PE_num, PE_num), "int32"), ): with T.Kernel(PE_num, 
num_blocks, threads=threads) as (bx, bz): tx = T.get_thread_binding() @@ -65,10 +138,9 @@ def main_route_opt( old_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") cur_counter = T.alloc_local([1], "uint32") cur_counter_shared = T.alloc_shared([PE_num, num_blocks, num_warps], "uint32") - old_local = T.alloc_local([1], "uint32") - old_global = T.alloc_local([1], "uint32") slot_flag = T.alloc_local([1], "int32") slot_flag_shared = T.alloc_shared([num_warps], "int32") + num_expected = T.alloc_local([1], "int32") rank[0] = T.get_rank() rank_x[0] = T.floordiv(rank[0], Y) @@ -77,6 +149,9 @@ def main_route_opt( dst_rank[0] = bx + # Read pre-computed expected slot count for this PE and dst_rank + num_expected[0] = expected_slots[rank[0], dst_rank[0]] + # Prepare for routing dst_rank_x = T.floordiv(dst_rank[0], Y) dst_rank_y = T.floormod(dst_rank[0], Y) @@ -145,8 +220,6 @@ def main_route_opt( next_rank[0], ) T.fence_sys(sem=T.MemorySemantic.RELEASE) - # Ensure ALL lanes' put_warp writes + fence are complete before lane 0 sends signal - T.sync_warp() if tx % 32 == 0: # Set per-slot ready flag: this specific slot is now ready T.st( @@ -165,24 +238,18 @@ def main_route_opt( chunk_M * N, -1, ) - if tx % 32 == 0: - old_local[0] = T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, - ) T.sync_warp() T.fence_cta(sem=T.MemorySemantic.RELEASE) - T.sync_threads() - T.fence_sys(sem=T.MemorySemantic.RELEASE) - - # Phase 2: Poll per-slot ready flags sequentially and process each slot - # Each warp checks slot 0, 1, 2, ... for its (dst_rank, bz, warp_id). - # flag=0 means not ready yet (spin), flag=1 means data ready, flag=-1 means no more slots. + # Phase 2: Poll per-slot ready flags sequentially and process each slot. + # Use pre-computed expected_slots count to know exactly how many slots to process, + # eliminating the need for termination signals. 
for slot_idx in T.serial(num_slots): - # Wait for this slot's flag to become non-zero + # Skip if no more expected slots or this block doesn't receive for this dst_rank + if slot_idx >= num_expected[0]: + T.loop_break() + + # Wait for this slot's data to become ready (flag != 0) if tx % 32 == 0: slot_flag[0] = T.wait_ne( signal_transfer[dst_rank[0], slot_idx, bz, warp_id], @@ -192,9 +259,6 @@ def main_route_opt( slot_flag_shared[warp_id] = slot_flag[0] T.sync_warp() - # Termination: flag == -1 means no more slots to process - if slot_flag_shared[warp_id] < 0: - T.loop_break() src_idx = src_transfer[dst_rank[0], slot_idx, bz, warp_id] # Handle the transfer @@ -247,46 +311,6 @@ def main_route_opt( ) T.fence_cta(sem=T.MemorySemantic.RELEASE) T.sync_warp() - if tx % 32 == 0: - old_local[0] = T.atom_add( - local_finish[0], - 1, - scope=T.MemoryScope.GPU, - sem=T.MemorySemantic.RELEASE, - ) - if old_local[0] + 1 == PE_num * num_blocks * num_warps: - # All local data received, notify all PEs - for i in T.serial(PE_num): - old_global[0] = T.atom_add_remote( - global_finish[0], - 1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=i, - ) - if old_global[0] + 1 == PE_num: - # Last PE to finish: send termination flags to all waiting slots on all PEs - # For each (dst_rank, bz, warp_id), we need to set the flag of the - # next unused slot to -1 so the receiver knows to stop polling. - for remote_pe in T.serial(PE_num): - for dst_rank_idx in T.serial(PE_num): - for bz_idx in T.serial(num_blocks): - for warp_idx in T.serial(num_warps): - # The next unused slot index is slot_counter[dst_rank, bz, warp] - # We write -1 to signal_transfer[dst_rank, slot_counter_val, bz, warp] - # But reading slot_counter from remote is complex. - # Instead, write -1 to ALL remaining slots (from 0 to num_slots-1). - # Slots already processed (flag=1) won't be re-checked. - # The receiver will hit -1 at the first unwritten slot. 
- for slot_i in T.serial(num_slots): - T.st( - signal_transfer[dst_rank_idx, slot_i, bz_idx, warp_idx], - -1, - scope=T.MemoryScope.SYSTEM, - sem=T.MemorySemantic.RELEASE, - dst_pe=remote_pe, - ) - T.sync_warp() T.sync_warp() @@ -316,6 +340,13 @@ def run_torus_alltoall(local_rank, num_ranks, args): group=group_size, ) + # Pre-compute expected slot counts based on XY-routing topology + expected_slots_host = compute_expected_slots(PE_num, X, Y) + if local_rank == 0: + print(f"Expected slots per PE (rows=receiver, cols=dst_rank):") + for pe in range(PE_num): + print(f" PE {pe}: {expected_slots_host[pe]}") + kernel = torus_alltoall_xy(PE_num, X, Y, M, N, num_blocks, num_warps, threads) kernel.initialize(allocator=allocator) @@ -331,7 +362,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): # Signals for each buffer slot in each direction slot_counter = tilelang.tensor((PE_num, num_blocks, num_warps), torch.uint32, allocator=allocator).fill_(0) - # Per-slot ready flags: 0=not ready, 1=ready, -1=termination + # Per-slot ready flags: 0=not ready, 1=ready signal_transfer = tilelang.tensor((PE_num, PE_num, num_blocks, num_warps), torch.int32, allocator=allocator).fill_(0) @@ -340,14 +371,14 @@ def run_torus_alltoall(local_rank, num_ranks, args): torch.uint32, allocator=allocator).fill_(PE_num) - local_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) - global_finish = tilelang.tensor((1), torch.uint32, allocator=allocator).fill_(0) - barrier = tilelang.tensor((PE_num), torch.int32, allocator=allocator).zero_() + # Pre-computed expected slots tensor (same for all PEs) + expected_slots_tensor = tilelang.tensor((PE_num, PE_num), torch.int32, allocator=allocator) + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) torch.cuda.synchronize() dist.barrier(group_size) - kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) + kernel(src, dst, 
buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) torch.cuda.synchronize() dist.barrier(group_size) @@ -391,23 +422,27 @@ def run_torus_alltoall(local_rank, num_ranks, args): # Warmup for _ in range(args.warmup): # Reinitialize + torch.cuda.synchronize() + dist.barrier(group_size) buffer_transfer.fill_(-1) src_transfer.fill_(PE_num) slot_counter.zero_() signal_transfer.zero_() - local_finish.zero_() - global_finish.zero_() - kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) + dst.zero_() + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) torch.cuda.synchronize() - dist.barrier(group_size) + dist.barrier(group_size) # Reinitialize buffer_transfer.fill_(-1) src_transfer.fill_(PE_num) slot_counter.zero_() signal_transfer.zero_() - local_finish.zero_() - global_finish.zero_() + dst.zero_() + expected_slots_tensor.copy_(torch.tensor(expected_slots_host, dtype=torch.int32, device="cuda")) + torch.cuda.synchronize() + dist.barrier(group_size) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -417,7 +452,7 @@ def run_torus_alltoall(local_rank, num_ranks, args): for _ in range(num_iters): # torch.cuda.profiler.start() # with torch.cuda.nvtx.range("alltoall_xy_routing_benchmark"): - kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, local_finish, global_finish, barrier) + kernel(src, dst, buffer_transfer, slot_counter, signal_transfer, src_transfer, expected_slots_tensor) # torch.cuda.profiler.stop() torch.cuda.synchronize() dist.barrier(group_size)