From 5fb7df6aebb0a105ae5ff63daeba74c8e2e4f68b Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Thu, 30 Oct 2025 17:51:20 +0800 Subject: [PATCH 01/14] add draft test --- .../primitives/example_tilescale_copy.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 examples/distributed/primitives/example_tilescale_copy.py diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py new file mode 100644 index 000000000..f436572e2 --- /dev/null +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -0,0 +1,84 @@ +import os +import tilelang +import tilelang.language as T +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +from tilelang.distributed import init_dist + +tilelang.disable_cache() +os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log + + +def kernel_(M, block_M, threads): + + @T.prim_func + def main( + dst: T.Tensor((M), "float32"), + src: T.Tensor((M), "float32"), + ): + with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + # We can use T.copy just as in TileLang, except for setting {src/dst}_pe + T.copy( + dst[bx * block_M: (bx + 1) * block_M], + src[bx * block_M: (bx + 1) * block_M], + dst_pe=rank[0] ^ 1, + disable_tma=True # Ensure testing SIMT remote copy + ) + + return main + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + M = args.M if args else 65536 + BLOCK_M = 4096 + threads = 128 + assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other" + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + allocator = tilelang.get_allocator( + size=2**25, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group) + kernel = tilelang.compile(kernel_(M, num_ranks, BLOCK_M, threads)) + kernel.initialize(allocator=allocator) + if local_rank == 0: + print(kernel.get_kernel_source()) + + src = tilelang.tensor((M), torch.float32, allocator=allocator).normal_() + dst = tilelang.tensor((M), torch.float32, allocator=allocator) + + torch.cuda.synchronize() + torch.distributed.barrier(group) + kernel(dst, src) + torch.cuda.synchronize() + torch.distributed.barrier(group) + + dst_torchs = [torch.empty_like(src) for _ in range(num_local_ranks)] + dist.all_gather(dst_torchs, src, group) + dst_torch = dst_torchs[local_rank ^ 1] + + if torch.allclose(dst_torch, dst, atol=1e-6, rtol=1e-6): + print(f"rank {local_rank} check passed.✅") + else: + print(f"rank {local_rank} check failed.❌") + print(f"dst_torch: {dst_torch}, dst: {dst}") + raise ValueError("Test failed") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--M', type=int, default=65536, help='M dimension') + args = parser.parse_args() + num_processes = 2 + + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) From b822ea9d6ee26470c4fd0a02b914f55e26089844 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Thu, 30 Oct 2025 20:40:07 +0800 Subject: [PATCH 02/14] draft --- src/op/copy.cc | 18 ++- src/op/copy.h | 6 + src/op/distributed.cc | 5 + src/op/distributed.h | 5 + src/target/codegen_cuda.cc | 3 + src/tl_templates/cuda/distributed.h | 8 +- src/transform/declare_symm_buffer.cc | 170 +++++++++++++++++++++++++++ tilelang/engine/phase.py | 4 + tilelang/language/__init__.py | 1 + tilelang/language/copy.py | 19 ++- tilelang/language/proxy.py | 17 ++- tilelang/transform/__init__.py | 11 ++ 12 files changed, 257 insertions(+), 10 deletions(-) create mode 100644 src/transform/declare_symm_buffer.cc diff --git a/src/op/copy.cc b/src/op/copy.cc index 754dd7336..da7da7f2f 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -155,6 +155,18 @@ Copy::Copy(Array args, BufferMap vmap) { if (args.size() >= 5) { node->eviction_policy = args[4].as()->value; } + // remote copy params + if (args.size() >= 6) { + node->src_pe = args[5]; + } + if (args.size() >= 7) { + node->dst_pe = args[6]; + } + if (args.size() >= 8) { + node->is_remote_copy = Downcast(args[7]); + } + // TODO: check symm buffer is on global + data_ = std::move(node); } @@ -1940,11 +1952,11 @@ Array TMAIm2ColDesc::EncodeCallArgs() const { // Register the Copy operation with TVM's TIR system // This makes the copy operation available for use in TVM programs -// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, -// eviction_policy +// - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, +// eviction_policy, src_pe, dst_pe, is_remote_copy // - Marked as opaque since it has side effects (memory writes) TIR_REGISTER_TL_OP(Copy, copy) - .set_num_inputs(5) + .set_num_inputs(8) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/copy.h b/src/op/copy.h index 00d07f169..cc6cfa735 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -92,6 +92,12 @@ class CopyNode : public TileOperatorNode { IntImm coalesced_width; // Width (in elements) for coalesced memory access Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + // Params for remote copy + Bool is_remote_copy = Bool(false); // Whether to enable remote copy + PrimExpr src_pe; // Source PE for remote copy + PrimExpr dst_pe; // Destination PE for remote copy + Buffer symm_buffer; // Symmetric buffer for remote copy + mutable ParallelOp par_op_; // Optional associated parallelization operator enum class EvictionPolicy : uint8_t { diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 84a23afa7..768d2160b 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -207,6 +207,11 @@ TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_local_base_ptr) +.set_num_inputs(0) +.set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(get_uintptr_t) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/distributed.h b/src/op/distributed.h index 5170cc7ea..df16be695 100644 --- a/src/op/distributed.h +++ b/src/op/distributed.h @@ -226,6 +226,11 @@ const Op &get_num_ranks(); */ const Op &get_remote_base_ptr(); +/*! + * \brief tvm intrinsics for getting the local base pointer + */ + const Op &get_local_base_ptr(); + /*! * \brief tvm intrinsics for getting the uintptr_t of a pointer */ diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index e93b6fc4e..62dc975f0 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1543,6 +1543,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->use_distributed_ = true; std::string pe_str = this->PrintExpr(op->args[0]); os << "tl::get_remote_base_ptr(" << pe_str << ")"; + } else if (op->op.same_as(tl::get_local_base_ptr())) { + this->use_distributed_ = true; + os << "tl::get_local_base_ptr()"; } else if (op->op.same_as(tl::get_uintptr_t())) { os << "tl::get_uintptr_t(" << this->PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::tvm_fill_fragment())) { diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index 7eca3a708..68a43abfc 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -10,8 +10,12 @@ TL_DEVICE uint64_t get_rank() { return meta_data[0]; } TL_DEVICE uint64_t get_num_ranks() { return meta_data[1]; } -TL_DEVICE uint64_t get_remote_base_ptr(uint64_t rank) { - return meta_data[2 + rank]; +TL_DEVICE void* get_remote_base_ptr(uint64_t rank) { + return (void*)meta_data[2 + rank]; +} + +TL_DEVICE uint64_t get_local_base_ptr() { + return meta_data[2 + get_rank()]; } template TL_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc new file mode 100644 index 000000000..6cef38d1f --- /dev/null +++ b/src/transform/declare_symm_buffer.cc @@ -0,0 +1,170 @@ +// TileScale pass + +/*! + * \file declare_symm_buffer.cc + * \brief Declare the symmetry buffer to prepare for operators that need buffers on peer's symm heap + */ + + #include + #include + #include + #include + #include + #include + #include + #include + + #include + +#include "../op/copy.h" +#include "../op/distributed.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using tvm::transform::PassContext; + +PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { + PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); + PrimExpr local_base_ptr = + Call(DataType::Handle(), tl::get_local_base_ptr(), {}); + PrimExpr offset_to_base = + Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), + local_base_ptr); + PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) + offset_to_base; + return result; +} + +/*! + * \brief Declare the symmetry buffer to prepare for operators that need buffers on peer's symm heap + */ +class SymmBufferDeclarer : public StmtExprMutator { +public: + static PrimFunc Apply(PrimFunc f) { + if (!f->body.defined()) { + return f; + } + + SymmBufferDeclarer declarer; + + // Extract symm buffer info and replace them + // The LetStmt insertion will happen inside VisitStmt_(const BlockNode*) + f.CopyOnWrite()->body = declarer.VisitStmt(f->body); + + return f; + }; + +private: + // Override BlockNode visitor to insert LetStmt inside blocks, not at PrimFunc level + Stmt VisitStmt_(const BlockNode *op) final { + // First, recursively visit children to collect let_bindings + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + // Insert let bindings inside the block body (not at PrimFunc level) + // We do this after visiting to ensure all let_bindings are collected + if (!let_bindings_.empty() && !let_bindings_inserted_) { + // Insert inside any non-root block to avoid PrimFunc-level insertion + // The "tilelang_root" or similar computation blocks are ideal + if (op->name_hint != "root") { + let_bindings_inserted_ = true; + Stmt body = block->body; + // Wrap the block body with all let bindings + for (const auto& kv : let_bindings_) { + body = LetStmt(GetRef(kv.first), kv.second, body); + } + BlockNode* n = block.CopyOnWrite(); + n->body = body; + } + } + + return block; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + // LOG(INFO) << "Found call"; + auto parsed_op = ParseOperator(GetRef(op), buffer_data_to_buffer_); + if (parsed_op.defined() && parsed_op.as()) { + // LOG(INFO) << "Found copy"; + if (parsed_op.as()->is_remote_copy) { + // LOG(INFO) << "Found remote copy"; + if (parsed_op.as()->dst_pe.defined()) // TODO: add check here + // && parsed_op.as()->dst_pe.as()->value != -1) + { + LOG(INFO) << "Found remote push"; + Buffer dst = parsed_op.as()->dst; + Array dst_range = parsed_op.as()->dst_range; + + // 1. Calculate symm dst ptr + PrimExpr symm_dst_ptr_expr = CalculateSymmPtr(dst->data, parsed_op.as()->dst_pe); + LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr; + + // 2. Record a let stmt to assign PrimExpr to Var + String storage_scope = dst->data->type_annotation.as()->storage_scope; + Var symm_dst_var = Var(dst->name+"_symm", PointerType(PrimType(dst->dtype), storage_scope)); + PrimExpr casted_ptr = Cast(DataType::Handle(), + symm_dst_ptr_expr); + let_bindings_[symm_dst_var.get()] = casted_ptr; + + // 3. Create modified dst buffer with symm var + dst.CopyOnWrite()->data = symm_dst_var; + + // 4. Rebuild the destination region call with the modified buffer + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, extent_1, ...] + Array dst_region_mins; + Array dst_region_extents; + for (const Range& r : dst_range) { + dst_region_mins.push_back(r->min); + dst_region_extents.push_back(r->extent); + } + BufferLoad dst_load(dst, dst_region_mins); + + Array dst_region_args; + dst_region_args.push_back(dst_load); + dst_region_args.push_back(IntImm(DataType::Int(32), op->args[1].as()->args[1].as()->value)); // access_mask + for (const PrimExpr& extent : dst_region_extents) { + dst_region_args.push_back(extent); + } + + // Create new Call for the destination region + Call dst_region_call = Call(op->args[1].as()->dtype, + op->args[1].as()->op, + dst_region_args, + op->args[1].as()->span); + + // 5. Rebuild the Copy call with modified args + Array new_copy_args; + new_copy_args.push_back(op->args[0]); // src region (unchanged) + new_copy_args.push_back(dst_region_call); // modified dst region + // Copy remaining args + for (size_t i = 2; i < op->args.size(); i++) { + new_copy_args.push_back(op->args[i]); + } + + return Call(op->dtype, op->op, new_copy_args, op->span); + } + } + } + return StmtExprMutator::VisitExpr_(op); + } + + Map buffer_data_to_buffer_; + std::unordered_map let_bindings_; + bool let_bindings_inserted_ = false; +}; + +tvm::transform::Pass DeclareSymmBuffer() { + auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { + f = SymmBufferDeclarer::Apply(std::move(f)); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.DeclareSymmBuffer", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.DeclareSymmBuffer", DeclareSymmBuffer); +}); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 6c34eae08..a1306ea94 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -104,6 +104,10 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LayoutReducer()(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) + # Declare symmetric buffer on peer's global memory + print(f"before:\n{mod}") + mod = tilelang.transform.DeclareSymmBuffer()(mod) + print(f"after:\n{mod}") # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 27f5432fc..023271a7d 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -16,6 +16,7 @@ from .proxy import ( ptr, # noqa: F401 make_tensor, # noqa: F401 + make_tensor_like, # noqa: F401 Buffer, # noqa: F401 Tensor, # noqa: F401 StridedTensor, # noqa: F401 diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 84444b8c6..46e1ef58b 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Literal +from typing import Literal, Optional from tilelang import language as T from tilelang.utils.language import get_buffer_region_from_load from tvm import ir, tir @@ -10,16 +10,23 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, dst: tir.Buffer | tir.BufferLoad, + src_pe: Optional[tir.PrimExpr] = -1, + dst_pe: Optional[tir.PrimExpr] = -1, coalesced_width: int | None = None, disable_tma: bool = False, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, + ): """Copy data between memory regions. Args: src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region dst (Union[tir.Buffer, tir.BufferLoad]): Destination memory region + src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1, which means copy from local + dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1, which means copy to local. coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None. - + disable_tma (bool, optional): Whether to disable TMA. Defaults to False. + eviction_policy (Optional[Literal["evict_normal", "evict_first", "evict_last"]], optional): Eviction policy. Defaults to None. + Raises: TypeError: If copy extents cannot be deduced from arguments @@ -83,8 +90,12 @@ def _to_region(data, access_type): eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] + + assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank" + is_remote_copy = src_pe is not None or dst_pe is not None + return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, - disable_tma, eviction_policy) + disable_tma, eviction_policy, src_pe, dst_pe, is_remote_copy) def c2d_im2col(img: tir.Buffer, diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 539c1d94c..fade0599f 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING +from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING from typing_extensions import Self from tvm import tir @@ -300,3 +300,18 @@ def make_tensor(ptr: Var, dtype: str = "float32", strides: tuple[PrimExpr, ...] = None) -> tir.Buffer: return Tensor.from_ptr(ptr, shape, dtype, strides) + + +def make_tensor_like( + tensor, + ptr: Var, + shape: Optional[tuple[PrimExpr, ...]] = None, + dtype: Optional[str] = None, + strides: Optional[tuple[PrimExpr, ...]] = None +) -> tir.Buffer: + return Tensor.from_ptr( + ptr if ptr is not None else tensor.data, + shape if shape is not None else tensor.shape, + dtype if dtype is not None else tensor.dtype, + strides if strides is not None else tensor.strides + ) \ No newline at end of file diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 808c97dc6..88cdb121e 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -484,3 +484,14 @@ def LayoutReducer(): The transform pass object produced by the FFI backend. """ return _ffi_api.LayoutReducer() # type: ignore + + +# TileScale passes + +def DeclareSymmBuffer(): + """ + Declare symmetric buffer on peer's global memory. + + This pass prepares for distributed operators, such as remote copy. + """ + return _ffi_api.DeclareSymmBuffer() # type: ignore \ No newline at end of file From 784dd18f3f856da8c550157724094cb831ff5766 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Thu, 30 Oct 2025 22:29:34 +0800 Subject: [PATCH 03/14] support SIMT push and fix a bug --- .../distributed/primitives/example_tilescale_copy.py | 12 ++++++------ tilelang/engine/phase.py | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index f436572e2..a3f385519 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -18,15 +18,15 @@ def main( dst: T.Tensor((M), "float32"), src: T.Tensor((M), "float32"), ): - with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx): + with T.Kernel((M//block_M), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") rank[0] = T.get_rank() # We can use T.copy just as in TileLang, except for setting {src/dst}_pe T.copy( - dst[bx * block_M: (bx + 1) * block_M], - src[bx * block_M: (bx + 1) * block_M], - dst_pe=rank[0] ^ 1, + dst[bx * block_M:(bx + 1) * block_M], + src[bx * block_M:(bx + 1) * block_M], + dst_pe=1-rank[0], disable_tma=True # Ensure testing SIMT remote copy ) @@ -47,7 +47,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): local_rank=local_rank, num_local_ranks=num_local_ranks, group=group) - kernel = tilelang.compile(kernel_(M, num_ranks, BLOCK_M, threads)) + kernel = tilelang.compile(kernel_(M, BLOCK_M, threads)) kernel.initialize(allocator=allocator) if local_rank == 0: print(kernel.get_kernel_source()) @@ -70,7 +70,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): else: print(f"rank {local_rank} check failed.❌") print(f"dst_torch: {dst_torch}, dst: {dst}") - raise ValueError("Test failed") + # raise ValueError("Test failed") dist.destroy_process_group() diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index a1306ea94..d3bafa598 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -217,5 +217,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # Transform threadblock to persistent threadblock mod = tilelang.transform.PersistThreadblock()(mod) + print(mod) return mod From a434339a459ccf1a192d923a28d246440a971e49 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Fri, 31 Oct 2025 13:46:23 +0800 Subject: [PATCH 04/14] refactor and support pull --- .../primitives/example_tilescale_copy.py | 142 +++++++++++++-- src/op/copy.cc | 23 ++- src/op/copy.h | 5 +- src/op/distributed.cc | 2 +- src/op/distributed.h | 2 +- src/target/codegen_cuda.cc | 5 +- src/tl_templates/cuda/common.h | 4 + src/tl_templates/cuda/distributed.h | 41 ++++- src/transform/declare_symm_buffer.cc | 171 +++++++++++------- src/transform/lower_hopper_intrin.cc | 59 +++++- tilelang/engine/phase.py | 3 - tilelang/jit/adapter/utils.py | 126 +++++++++++++ tilelang/jit/adapter/wrapper.py | 15 +- tilelang/language/copy.py | 7 +- 14 files changed, 493 insertions(+), 112 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index a3f385519..6d96ab4db 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -11,35 +11,140 @@ os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log -def kernel_(M, block_M, threads): +@tilelang.jit +def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile', rank=None): @T.prim_func - def main( - dst: T.Tensor((M), "float32"), - src: T.Tensor((M), "float32"), + def simt_push_buffer( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), ): - with T.Kernel((M//block_M), threads=threads) as (bx): + with T.Kernel((1), threads=threads) as (bx): rank = T.alloc_local([1], "uint64") rank[0] = T.get_rank() - # We can use T.copy just as in TileLang, except for setting {src/dst}_pe T.copy( - dst[bx * block_M:(bx + 1) * block_M], - src[bx * block_M:(bx + 1) * block_M], + src, + dst, dst_pe=1-rank[0], disable_tma=True # Ensure testing SIMT remote copy ) - return main + @T.prim_func + def simt_push_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + disable_tma=True # Ensure testing SIMT remote copy + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + dst_pe=1-rank[0], + disable_tma=True # Ensure testing SIMT remote copy + ) + + @T.prim_func + def simt_pull_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + src_pe=1-rank[0], + disable_tma=True # Ensure testing SIMT remote copy + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + disable_tma=True # Ensure testing SIMT remote copy + ) + + # TMA kernel requires run-time aware peer rank + @T.prim_func + def tma_load_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + # TMA load + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + src_pe=1-rank, + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + disable_tma=True # Ensure testing SIMT remote copy + ) + + @T.prim_func + def tma_store_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + disable_tma=True # Ensure testing SIMT remote copy + ) + + # TMA store + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + dst_pe=1-rank + ) + + return { + 'simt_push_buffer': simt_push_buffer, + 'simt_push_tile': simt_push_tile, + 'simt_pull_tile': simt_pull_tile, + 'tma_load_tile': tma_load_tile, + 'tma_store_tile': tma_store_tile + }[kernel] def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - M = args.M if args else 65536 - BLOCK_M = 4096 + M = args.M + N = args.N + BLOCK_M = 64 + BLOCK_N = 128 threads = 128 assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other" - rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + _, _, group = init_dist(local_rank, num_local_ranks) allocator = tilelang.get_allocator( size=2**25, device="cuda", @@ -47,13 +152,14 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): local_rank=local_rank, num_local_ranks=num_local_ranks, group=group) - kernel = tilelang.compile(kernel_(M, BLOCK_M, threads)) + + kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel, rank=local_rank) # only TMA kernels need compile-time aware peer rank kernel.initialize(allocator=allocator) if local_rank == 0: print(kernel.get_kernel_source()) - src = tilelang.tensor((M), torch.float32, allocator=allocator).normal_() - dst = tilelang.tensor((M), torch.float32, allocator=allocator) + src = tilelang.tensor((M, N), torch.float32, allocator=allocator).normal_() + dst = tilelang.tensor((M, N), torch.float32, allocator=allocator) torch.cuda.synchronize() torch.distributed.barrier(group) @@ -70,14 +176,16 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): else: print(f"rank {local_rank} check failed.❌") print(f"dst_torch: {dst_torch}, dst: {dst}") - # raise ValueError("Test failed") + raise ValueError("Test failed") dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--M', type=int, default=65536, help='M dimension') + parser.add_argument('--M', type=int, default=1024, help='M dimension') + parser.add_argument('--N', type=int, default=1024, help='N dimension') + parser.add_argument('--kernel', type=str, default='simt_push_tile', help='kernel to use') args = parser.parse_args() num_processes = 2 diff --git a/src/op/copy.cc b/src/op/copy.cc index da7da7f2f..2fccd512d 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -155,21 +155,36 @@ Copy::Copy(Array args, BufferMap vmap) { if (args.size() >= 5) { node->eviction_policy = args[4].as()->value; } - // remote copy params + + // Parse remote copy params if (args.size() >= 6) { node->src_pe = args[5]; } if (args.size() >= 7) { node->dst_pe = args[6]; } - if (args.size() >= 8) { - node->is_remote_copy = Downcast(args[7]); + + if (node->is_remote_push()) { + ICHECK(node->dst.scope()=="global") << "Can only copy to peer's global memory, but got " << node->dst.scope(); + } else if (node->is_remote_pull()) { + ICHECK(node->src.scope()=="global") << "Can only pull from peer's global memory, but got " << node->src.scope(); } - // TODO: check symm buffer is on global data_ = std::move(node); } +bool CopyNode::is_remote_push() const { + return !(dst_pe->IsInstance() && dst_pe.as()->value == -1); +} + +bool CopyNode::is_remote_pull() const { + return !(src_pe->IsInstance() && src_pe.as()->value == -1); +} + +bool CopyNode::is_remote_copy() const { + return is_remote_push() || is_remote_pull(); +} + /** * @brief Create a shallow clone of this CopyNode as a TileOperator. * diff --git a/src/op/copy.h b/src/op/copy.h index cc6cfa735..ad6037227 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -93,11 +93,14 @@ class CopyNode : public TileOperatorNode { Bool disable_tma = Bool(false); // Whether to disable TMA acceleration // Params for remote copy - Bool is_remote_copy = Bool(false); // Whether to enable remote copy PrimExpr src_pe; // Source PE for remote copy PrimExpr dst_pe; // Destination PE for remote copy Buffer symm_buffer; // Symmetric buffer for remote copy + bool is_remote_copy() const; + bool is_remote_push() const; + bool is_remote_pull() const; + mutable ParallelOp par_op_; // Optional associated parallelization operator enum class EvictionPolicy : uint8_t { diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 768d2160b..20635479e 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -207,7 +207,7 @@ TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(get_local_base_ptr) +TIR_DEFINE_TL_BUILTIN(get_local_base) .set_num_inputs(0) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/distributed.h b/src/op/distributed.h index df16be695..aaae98aef 100644 --- a/src/op/distributed.h +++ b/src/op/distributed.h @@ -229,7 +229,7 @@ const Op &get_remote_base_ptr(); /*! * \brief tvm intrinsics for getting the local base pointer */ - const Op &get_local_base_ptr(); + const Op &get_local_base(); /*! * \brief tvm intrinsics for getting the uintptr_t of a pointer diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 62dc975f0..97dc0fa33 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -297,6 +297,7 @@ std::string CodeGenTileLangCUDA::Finish() { if (use_distributed_) { decl_stream << "uint64_t __constant__ meta_data[1024];\n"; + decl_stream << "uint64_t* host_meta_data = nullptr;\n"; // An alias of host_table } decl_stream << "#ifdef ENABLE_BF16\n"; decl_stream << "#include \n"; @@ -1543,9 +1544,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->use_distributed_ = true; std::string pe_str = this->PrintExpr(op->args[0]); os << "tl::get_remote_base_ptr(" << pe_str << ")"; - } else if (op->op.same_as(tl::get_local_base_ptr())) { + } else if (op->op.same_as(tl::get_local_base())) { this->use_distributed_ = true; - os << "tl::get_local_base_ptr()"; + os << "tl::get_local_base()"; } else if (op->op.same_as(tl::get_uintptr_t())) { os << "tl::get_uintptr_t(" << this->PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::tvm_fill_fragment())) { diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index dfbc062cf..ad3a3e81e 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -32,6 +32,10 @@ using int4_t = int4; #define TL_DEVICE __forceinline__ __device__ #define TL_DEVICE_NOINLINE __noinline__ __device__ +#define TL_HOST __forceinline__ __host__ +#define TL_HOST_NOINLINE __noinline__ _host__ +#define TL_HOST_DEVICE __forceinline__ __host__ __device__ +#define TL_HOST_DEVICE_NOINLINE __noinline__ __host__ __device__ #define TL_PATCH #define TILELANG_CHECK(stmt) \ diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index 68a43abfc..a250ca010 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -1,24 +1,49 @@ #pragma once #include "common.h" +#include namespace tl { -extern "C" extern __device__ uint64_t meta_data[1024]; +extern "C" __device__ uint64_t meta_data[1024]; +extern "C" __host__ uint64_t* host_meta_data; -TL_DEVICE uint64_t get_rank() { return meta_data[0]; } - -TL_DEVICE uint64_t get_num_ranks() { return meta_data[1]; } +TL_HOST_DEVICE uint64_t get_rank() { +#ifdef __CUDA_ARCH__ + return meta_data[0]; +#else + return host_meta_data[0]; +#endif +} -TL_DEVICE void* get_remote_base_ptr(uint64_t rank) { - return (void*)meta_data[2 + rank]; +TL_HOST_DEVICE uint64_t get_num_ranks() { +#ifdef __CUDA_ARCH__ + return meta_data[1]; +#else + return host_meta_data[1]; +#endif } -TL_DEVICE uint64_t get_local_base_ptr() { +TL_HOST_DEVICE void* get_remote_base_ptr(uint64_t rank) { + #ifdef __CUDA_ARCH__ + return (void*)meta_data[2 + rank]; + #else + return (void*)host_meta_data[2 + rank]; + #endif + } + + +// NOTE(wt): Be careful about the return types here! +// get_local_base() returns u64 since I could not find a way cast u64 to ptr in tir +TL_HOST_DEVICE uint64_t get_local_base() { +#ifdef __CUDA_ARCH__ return meta_data[2 + get_rank()]; +#else + return host_meta_data[2 + get_rank()]; +#endif } -template TL_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { +template TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { return reinterpret_cast(ptr); } diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index 6cef38d1f..db934cab0 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -5,16 +5,16 @@ * \brief Declare the symmetry buffer to prepare for operators that need buffers on peer's symm heap */ - #include - #include - #include - #include - #include - #include - #include - #include +#include +#include +#include +#include +#include +#include +#include +#include - #include +#include #include "../op/copy.h" #include "../op/distributed.h" @@ -25,10 +25,13 @@ namespace tl { using namespace tir; using tvm::transform::PassContext; +static int name_suffix_id = 0; // Avoid name collision for symm buffers, start from 0 + +/* Create a PrimExpr to calculate the symmetry pointer given a local ptr and target PE */ PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_local_base_ptr(), {}); + Call(DataType::Handle(), tl::get_local_base(), {}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr); @@ -37,7 +40,8 @@ PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { } /*! - * \brief Declare the symmetry buffer to prepare for operators that need buffers on peer's symm heap + * \brief Declare the symmetry buffer to prepare for operators + * that need buffers on peer's symm heap */ class SymmBufferDeclarer : public StmtExprMutator { public: @@ -49,63 +53,35 @@ class SymmBufferDeclarer : public StmtExprMutator { SymmBufferDeclarer declarer; // Extract symm buffer info and replace them - // The LetStmt insertion will happen inside VisitStmt_(const BlockNode*) + // The LetStmt insertion will happen in VisitStmt_ before each copy f.CopyOnWrite()->body = declarer.VisitStmt(f->body); return f; }; private: - // Override BlockNode visitor to insert LetStmt inside blocks, not at PrimFunc level - Stmt VisitStmt_(const BlockNode *op) final { - // First, recursively visit children to collect let_bindings - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - - // Insert let bindings inside the block body (not at PrimFunc level) - // We do this after visiting to ensure all let_bindings are collected - if (!let_bindings_.empty() && !let_bindings_inserted_) { - // Insert inside any non-root block to avoid PrimFunc-level insertion - // The "tilelang_root" or similar computation blocks are ideal - if (op->name_hint != "root") { - let_bindings_inserted_ = true; - Stmt body = block->body; - // Wrap the block body with all let bindings - for (const auto& kv : let_bindings_) { - body = LetStmt(GetRef(kv.first), kv.second, body); - } - BlockNode* n = block.CopyOnWrite(); - n->body = body; - } - } - - return block; - } + Stmt VisitStmt_(const EvaluateNode *op) final { + // Check if this Evaluate contains a Call + if (const CallNode *call_op = op->value.as()) { + auto parsed_op = ParseOperator(GetRef(call_op), buffer_data_to_buffer_); + if (parsed_op.defined() && parsed_op.as()) { + // LOG(INFO) << "Found copy"; + + if (parsed_op.as()->is_remote_push()) { + // LOG(INFO) << "Found remote push"; - PrimExpr VisitExpr_(const CallNode *op) final { - // LOG(INFO) << "Found call"; - auto parsed_op = ParseOperator(GetRef(op), buffer_data_to_buffer_); - if (parsed_op.defined() && parsed_op.as()) { - // LOG(INFO) << "Found copy"; - if (parsed_op.as()->is_remote_copy) { - // LOG(INFO) << "Found remote copy"; - if (parsed_op.as()->dst_pe.defined()) // TODO: add check here - // && parsed_op.as()->dst_pe.as()->value != -1) - { - LOG(INFO) << "Found remote push"; Buffer dst = parsed_op.as()->dst; Array dst_range = parsed_op.as()->dst_range; // 1. Calculate symm dst ptr PrimExpr symm_dst_ptr_expr = CalculateSymmPtr(dst->data, parsed_op.as()->dst_pe); - LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr; + // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr; - // 2. Record a let stmt to assign PrimExpr to Var + // 2. Create a let binding String storage_scope = dst->data->type_annotation.as()->storage_scope; - Var symm_dst_var = Var(dst->name+"_symm", PointerType(PrimType(dst->dtype), storage_scope)); - PrimExpr casted_ptr = Cast(DataType::Handle(), - symm_dst_ptr_expr); - let_bindings_[symm_dst_var.get()] = casted_ptr; - + Var symm_dst_var = Var(dst->name+"_symm_"+std::to_string(name_suffix_id++), + PointerType(PrimType(dst->dtype), storage_scope)); + // 3. Create modified dst buffer with symm var dst.CopyOnWrite()->data = symm_dst_var; @@ -121,36 +97,101 @@ class SymmBufferDeclarer : public StmtExprMutator { Array dst_region_args; dst_region_args.push_back(dst_load); - dst_region_args.push_back(IntImm(DataType::Int(32), op->args[1].as()->args[1].as()->value)); // access_mask + dst_region_args.push_back(IntImm(DataType::Int(32), call_op->args[1].as()->args[1].as()->value)); // access_mask for (const PrimExpr& extent : dst_region_extents) { dst_region_args.push_back(extent); } // Create new Call for the destination region - Call dst_region_call = Call(op->args[1].as()->dtype, - op->args[1].as()->op, + Call dst_region_call = Call(call_op->args[1].as()->dtype, + call_op->args[1].as()->op, dst_region_args, - op->args[1].as()->span); + call_op->args[1].as()->span); // 5. Rebuild the Copy call with modified args Array new_copy_args; - new_copy_args.push_back(op->args[0]); // src region (unchanged) + new_copy_args.push_back(call_op->args[0]); // src region (unchanged) new_copy_args.push_back(dst_region_call); // modified dst region // Copy remaining args - for (size_t i = 2; i < op->args.size(); i++) { - new_copy_args.push_back(op->args[i]); + for (size_t i = 2; i < call_op->args.size(); i++) { + new_copy_args.push_back(call_op->args[i]); + } + + // Create the modified copy call + Call new_copy_call = Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + + // Wrap it in an Evaluate statement + Stmt modified_stmt = Evaluate(new_copy_call); + + // Wrap with LetStmt that defines the symm pointer + return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt); + } else if (parsed_op.as()->is_remote_pull()) { + // LOG(INFO) << "Found remote pull"; + + Buffer src = parsed_op.as()->src; + Array src_range = parsed_op.as()->src_range; + + // 1. Calculate symm src ptr + PrimExpr symm_src_ptr_expr = CalculateSymmPtr(src->data, parsed_op.as()->src_pe); + // LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr; + + // 2. Create a let binding + String storage_scope = src->data->type_annotation.as()->storage_scope; + Var symm_src_var = Var(src->name+"_symm_"+std::to_string(name_suffix_id++), + PointerType(PrimType(src->dtype), storage_scope)); + + // 3. Create modified src buffer with symm var + src.CopyOnWrite()->data = symm_src_var; + + // 4. Rebuild the source region call with the modified buffer + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, extent_1, ...] + Array src_region_mins; + Array src_region_extents; + for (const Range& r : src_range) { + src_region_mins.push_back(r->min); + src_region_extents.push_back(r->extent); + } + BufferLoad src_load(src, src_region_mins); + + Array src_region_args; + src_region_args.push_back(src_load); + src_region_args.push_back(IntImm(DataType::Int(32), call_op->args[1].as()->args[1].as()->value)); // access_mask + for (const PrimExpr& extent : src_region_extents) { + src_region_args.push_back(extent); + } + + // Create new Call for the source region + Call src_region_call = Call(call_op->args[0].as()->dtype, + call_op->args[0].as()->op, + src_region_args, + call_op->args[0].as()->span); + + // 5. Rebuild the Copy call with modified args + Array new_copy_args; + new_copy_args.push_back(src_region_call); // modified src region + new_copy_args.push_back(call_op->args[1]); // dst region (unchanged) + // Copy remaining args + for (size_t i = 2; i < call_op->args.size(); i++) { + new_copy_args.push_back(call_op->args[i]); } - return Call(op->dtype, op->op, new_copy_args, op->span); + // Create the modified copy call + Call new_copy_call = Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + + // Wrap it in an Evaluate statement + Stmt modified_stmt = Evaluate(new_copy_call); + + // Wrap with LetStmt that defines the symm pointer + return LetStmt(symm_src_var, symm_src_ptr_expr, modified_stmt); } } } - return StmtExprMutator::VisitExpr_(op); + + // Default: use parent's visitor + return StmtExprMutator::VisitStmt_(op); } Map buffer_data_to_buffer_; - std::unordered_map let_bindings_; - bool let_bindings_inserted_ = false; }; tvm::transform::Pass DeclareSymmBuffer() { diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index b514627d7..1728ec279 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -39,8 +39,22 @@ class LowerHopperIntrin : public StmtExprMutator { CHECK(0) << call->op; } init_desc_args.push_back(var); - init_desc_args.insert(init_desc_args.end(), call->args.begin(), - call->args.end()); + + // Inline let-bound variables in the descriptor arguments + for (const auto &arg : call->args) { + if (auto arg_var = arg.as()) { + auto it = substituter.tma_let_bindings_.find(arg_var.value()); + if (it != substituter.tma_let_bindings_.end()) { + // Replace variable with its let-bound expression + init_desc_args.push_back(it->second); + } else { + init_desc_args.push_back(arg); + } + } else { + init_desc_args.push_back(arg); + } + } + // add to function attribute Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); @@ -49,6 +63,7 @@ class LowerHopperIntrin : public StmtExprMutator { init_desc_arg_map.Set(var, init_desc_args); } f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map); + return f; } @@ -99,9 +114,43 @@ class LowerHopperIntrin : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + Stmt body = this->VisitStmt(op->body); + + // Check if this variable is related to TMA (used in descriptor creation) + if (tma_related_vars_.count(op->var)) { + tma_let_bindings_[op->var] = value; + } + + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + PrimExpr VisitExpr_(const CallNode *call) final { - if (call->op.same_as(create_tma_descriptor()) || - call->op.same_as(create_tma_im2col_descriptor())) { + if (call->op.same_as(create_tma_descriptor())) { + Var var; + auto iter = desc_map_.find(GetRef(call)); + if (iter != desc_map_.end()) { + var = iter->second; + } else { + String name = call->args[2].as().value()->name_hint; + var = Var(name + "_desc", + PointerType(PrimType(cuTensorMapType()), "grid_constant")); + desc_map_[GetRef(call)] = var; + prefetch_calls_.push_back( + Evaluate(Call(DataType::Handle(), builtin::call_extern(), + {StringImm("tl::prefetch_tma_descriptor"), var}))); + // Mark the base pointer variable as TMA-related + if (auto base_var = call->args[2].as()) { + tma_related_vars_.insert(base_var.value()); + } + } + return var; + } else if (call->op.same_as(create_tma_im2col_descriptor())) { Var var; auto iter = desc_map_.find(GetRef(call)); if (iter != desc_map_.end()) { @@ -135,6 +184,8 @@ class LowerHopperIntrin : public StmtExprMutator { Array prefetch_calls_; Array init_mbarrier_calls_; std::unordered_map desc_map_; + std::unordered_set tma_related_vars_; + std::unordered_map tma_let_bindings_; LowerHopperIntrin(bool disable_shuffle_elect) : disable_shuffle_elect_(disable_shuffle_elect) {} bool disable_shuffle_elect_; diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index d3bafa598..f20755dcd 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -105,9 +105,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) # Declare symmetric buffer on peer's global memory - print(f"before:\n{mod}") mod = tilelang.transform.DeclareSymmBuffer()(mod) - print(f"after:\n{mod}") # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map @@ -217,6 +215,5 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # Transform threadblock to persistent threadblock mod = tilelang.transform.PersistThreadblock()(mod) - print(mod) return mod diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index efc965e1b..2ce09495c 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -216,3 +216,129 @@ def _visitor(node): tvm.tir.stmt_functor.post_order_visit(expr, _visitor) return next(iter(node_to_result_map[expr]), "") + + +def tilescale_pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str: + """ + Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. + Compare to pythonic_expr, this function is used for TileScale with support for parsing specified calls. + + Args: + expr: The TVM PrimExpr to convert. + + Returns: + A string representation of the expression. + """ + if not isinstance(expr, tvm.tir.PrimExpr): + return str(expr) + + # 1. Define operator precedence (higher value means higher precedence) + # Based on Python's operator precedence + PRECEDENCE = { + tvm.tir.Call: 20, # Includes min, max + tvm.tir.Cast: 20, # Treated like a function call + tvm.tir.Mul: 13, + tvm.tir.FloorDiv: 13, + tvm.tir.Div: 13, # For tvm.tir.Div if it appears + tvm.tir.FloorMod: 13, + tvm.tir.Add: 12, + tvm.tir.Sub: 12, + tvm.tir.LT: 10, + tvm.tir.LE: 10, + tvm.tir.GT: 10, + tvm.tir.GE: 10, + tvm.tir.EQ: 10, + tvm.tir.NE: 10, + tvm.tir.And: 5, + tvm.tir.Or: 4, + # Atoms (Var, IntImm) have the highest precedence implicitly + } + # By default, atomic expressions (variables, constants) have the highest precedence + ATOMIC_PRECEDENCE = 100 + + node_to_result_map = {} # Stores (string, precedence) for each node + + def _visitor(node): + # 2. Visitor returns (str, precedence) tuple + if node in node_to_result_map: + return + + if isinstance(node, tvm.tir.Var): + s, p = node.name, ATOMIC_PRECEDENCE + elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)): + s, p = str(node.value), ATOMIC_PRECEDENCE + elif isinstance(node, tvm.tir.Cast): + # C-style cast has high precedence + value_str, _ = node_to_result_map[node.value] + if dtype_map is None: + s = f"({node.dtype}){value_str}" + else: + s = f"({dtype_map[node.dtype]}){value_str}" + p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) + elif isinstance( + node, + (tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT, + tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)): + op_map = { + tvm.tir.Mul: "*", + tvm.tir.FloorDiv: "/", + tvm.tir.Add: "+", + tvm.tir.Sub: "-", + tvm.tir.FloorMod: "%", + tvm.tir.LT: "<", + tvm.tir.LE: "<=", + tvm.tir.GT: ">", + tvm.tir.GE: ">=", + tvm.tir.EQ: "==", + tvm.tir.NE: "!=", + tvm.tir.And: "and", + tvm.tir.Or: "or", + } + op_str = f" {op_map[type(node)]} " + my_precedence = PRECEDENCE[type(node)] + + a_str, a_precedence = node_to_result_map[node.a] + b_str, b_precedence = node_to_result_map[node.b] + + # 3. Add parentheses intelligently + # Add parentheses if the left operand's precedence is lower than the current operator + if a_precedence < my_precedence: + a_str = f"({a_str})" + # Add parentheses if the right operand's precedence is lower than or equal to the current operator + # 'Equal' is to handle non-associative operations, e.g., a - (b - c) + if b_precedence <= my_precedence: + b_str = f"({b_str})" + + s = f"{a_str}{op_str}{b_str}" + p = my_precedence + elif isinstance(node, (tvm.tir.Min, tvm.tir.Max)): + op_name = "min" if isinstance(node, tvm.tir.Min) else "max" + a_str, _ = node_to_result_map[node.a] + b_str, _ = node_to_result_map[node.b] + s = f"{op_name}({a_str}, {b_str})" + # Function calls have high precedence + p = PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + + # Parse known calls in TileScale + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_rank"): + s, p = "tl::get_rank()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_num_ranks"): + s, p = "tl::get_num_ranks()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_remote_base_ptr"): + pe_str, _ = node_to_result_map[node.args[0]] + s, p = "tl::get_remote_base_ptr(" + pe_str + ")", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_local_base"): + s, p = "tl::get_local_base()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_uintptr_t"): + ptr_str, _ = node_to_result_map[node.args[0]] + s, p = "tl::get_uintptr_t(" + ptr_str + ")", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + else: + # Fallback for unhandled expression types + s, p = str(node), 0 + + node_to_result_map[node] = (s, p) + + # Perform post-order traversal + tvm.tir.stmt_functor.post_order_visit(expr, _visitor) + + return next(iter(node_to_result_map[expr]), "") diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index c88cbc6cc..f046f82fa 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -6,7 +6,7 @@ from tvm import IRModule from tvm.target import Target from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, - is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr) + is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, tilescale_pythonic_expr) import re import logging import textwrap @@ -61,6 +61,7 @@ if (error_buf) std::snprintf(error_buf, 256, "cudaMemcpyToSymbol failed: %s", cudaGetErrorString(err)); return static_cast(err); }} + host_meta_data = (uint64_t*)host_table; return 0; }} """ @@ -283,6 +284,9 @@ def __init__(self, def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: return pythonic_expr(expr, self._TYPE_MAP) + def _tilescale_pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: + return tilescale_pythonic_expr(expr, self._TYPE_MAP) + def is_tma_descriptor_arg(self, arg_name: str) -> bool: return arg_name in self.prim_func.buffer_map @@ -325,7 +329,12 @@ def func_call_args(s, # Extract the function call arguments matching the function definition def maybe_desc(name: str, matches: list[str], i: int): match = matches[i] - if not (match == name + "_desc" or match.startswith(name + "_desc_")): + if not (match == name + "_desc" \ + or match.startswith(name + "_desc_") + or (match.startswith(name + "_symm_") and match.endswith("_desc")) + # The last cases belongs to TMA copy from symm buffer + # Check naming in src/transform/declare_symm_buffer.cc + ): return False desc_decls = [] if desc_name_map is not None: @@ -453,6 +462,8 @@ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args + globalAddress = self._tilescale_pythonic_expr(globalAddress) + is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") dtype = self._pythonic_expr(dtype) tensor_rank = int(self._pythonic_expr(tensor_rank)) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 46e1ef58b..52cfa1999 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -10,8 +10,8 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, dst: tir.Buffer | tir.BufferLoad, - src_pe: Optional[tir.PrimExpr] = -1, - dst_pe: Optional[tir.PrimExpr] = -1, + src_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, + dst_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, coalesced_width: int | None = None, disable_tma: bool = False, eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, @@ -92,10 +92,9 @@ def _to_region(data, access_type): eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank" - is_remote_copy = src_pe is not None or dst_pe is not None return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, - disable_tma, eviction_policy, src_pe, dst_pe, is_remote_copy) + disable_tma, eviction_policy, src_pe, dst_pe) def c2d_im2col(img: tir.Buffer, From ee46984f249c1be460b6aa506e6af5cac6cf0c5e Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Sat, 1 Nov 2025 08:05:46 +0000 Subject: [PATCH 05/14] lint --- .../primitives/example_tilescale_copy.py | 21 +-- src/op/copy.cc | 16 ++- src/op/copy.h | 4 +- src/op/distributed.cc | 6 +- src/op/distributed.h | 2 +- src/target/codegen_cuda.cc | 3 +- src/tl_templates/cuda/distributed.h | 23 ++-- src/transform/declare_symm_buffer.cc | 126 +++++++++++------- src/transform/lower_hopper_intrin.cc | 11 +- tilelang/jit/adapter/utils.py | 6 +- tilelang/jit/adapter/wrapper.py | 3 +- tilelang/language/copy.py | 19 +-- tilelang/language/proxy.py | 22 ++- tilelang/transform/__init__.py | 3 +- 14 files changed, 151 insertions(+), 114 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index 6d96ab4db..d0588371c 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -19,14 +19,14 @@ def simt_push_buffer( dst: T.Tensor((M, N), "float32"), src: T.Tensor((M, N), "float32"), ): - with T.Kernel((1), threads=threads) as (bx): + with T.Kernel((1), threads=threads): rank = T.alloc_local([1], "uint64") rank[0] = T.get_rank() T.copy( src, dst, - dst_pe=1-rank[0], + dst_pe=1 - rank[0], disable_tma=True # Ensure testing SIMT remote copy ) @@ -51,7 +51,7 @@ def simt_push_tile( T.copy( smem, dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - dst_pe=1-rank[0], + dst_pe=1 - rank[0], disable_tma=True # Ensure testing SIMT remote copy ) @@ -70,7 +70,7 @@ def simt_pull_tile( T.copy( src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], smem, - src_pe=1-rank[0], + src_pe=1 - rank[0], disable_tma=True # Ensure testing SIMT remote copy ) @@ -95,7 +95,8 @@ def tma_load_tile( T.copy( src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], smem, - src_pe=1-rank, + src_pe=1 - T.get_rank(), + # NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently. ) T.copy( @@ -124,7 +125,7 @@ def tma_store_tile( T.copy( smem, dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - dst_pe=1-rank + dst_pe=1 - T.get_rank() ) return { @@ -137,7 +138,7 @@ def tma_store_tile( def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): - M = args.M + M = args.M N = args.N BLOCK_M = 64 BLOCK_N = 128 @@ -152,8 +153,10 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): local_rank=local_rank, num_local_ranks=num_local_ranks, group=group) - - kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel, rank=local_rank) # only TMA kernels need compile-time aware peer rank + + kernel = get_kernel( + M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel, + rank=local_rank) # only TMA kernels need compile-time aware peer rank kernel.initialize(allocator=allocator) if local_rank == 0: print(kernel.get_kernel_source()) diff --git a/src/op/copy.cc b/src/op/copy.cc index 2fccd512d..6ca796e28 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -165,20 +165,26 @@ Copy::Copy(Array args, BufferMap vmap) { } if (node->is_remote_push()) { - ICHECK(node->dst.scope()=="global") << "Can only copy to peer's global memory, but got " << node->dst.scope(); + ICHECK(node->dst.scope() == "global") + << "Can only copy to peer's global memory, but got " + << node->dst.scope(); } else if (node->is_remote_pull()) { - ICHECK(node->src.scope()=="global") << "Can only pull from peer's global memory, but got " << node->src.scope(); + ICHECK(node->src.scope() == "global") + << "Can only pull from peer's global memory, but got " + << node->src.scope(); } - + data_ = std::move(node); } bool CopyNode::is_remote_push() const { - return !(dst_pe->IsInstance() && dst_pe.as()->value == -1); + return !(dst_pe->IsInstance() && + dst_pe.as()->value == -1); } bool CopyNode::is_remote_pull() const { - return !(src_pe->IsInstance() && src_pe.as()->value == -1); + return !(src_pe->IsInstance() && + src_pe.as()->value == -1); } bool CopyNode::is_remote_copy() const { diff --git a/src/op/copy.h b/src/op/copy.h index ad6037227..a4ea4f53a 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -93,8 +93,8 @@ class CopyNode : public TileOperatorNode { Bool disable_tma = Bool(false); // Whether to disable TMA acceleration // Params for remote copy - PrimExpr src_pe; // Source PE for remote copy - PrimExpr dst_pe; // Destination PE for remote copy + PrimExpr src_pe; // Source PE for remote copy + PrimExpr dst_pe; // Destination PE for remote copy Buffer symm_buffer; // Symmetric buffer for remote copy bool is_remote_copy() const; diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 20635479e..aa55c065b 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -208,9 +208,9 @@ TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr) Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(get_local_base) -.set_num_inputs(0) -.set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(get_uintptr_t) .set_num_inputs(1) diff --git a/src/op/distributed.h b/src/op/distributed.h index aaae98aef..66907ee53 100644 --- a/src/op/distributed.h +++ b/src/op/distributed.h @@ -229,7 +229,7 @@ const Op &get_remote_base_ptr(); /*! * \brief tvm intrinsics for getting the local base pointer */ - const Op &get_local_base(); +const Op &get_local_base(); /*! * \brief tvm intrinsics for getting the uintptr_t of a pointer diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 97dc0fa33..660125d16 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -297,7 +297,8 @@ std::string CodeGenTileLangCUDA::Finish() { if (use_distributed_) { decl_stream << "uint64_t __constant__ meta_data[1024];\n"; - decl_stream << "uint64_t* host_meta_data = nullptr;\n"; // An alias of host_table + decl_stream + << "uint64_t* host_meta_data = nullptr;\n"; // An alias of host_table } decl_stream << "#ifdef ENABLE_BF16\n"; decl_stream << "#include \n"; diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index a250ca010..b5c043e05 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -6,7 +6,7 @@ namespace tl { extern "C" __device__ uint64_t meta_data[1024]; -extern "C" __host__ uint64_t* host_meta_data; +extern "C" __host__ uint64_t *host_meta_data; TL_HOST_DEVICE uint64_t get_rank() { #ifdef __CUDA_ARCH__ @@ -24,17 +24,17 @@ TL_HOST_DEVICE uint64_t get_num_ranks() { #endif } -TL_HOST_DEVICE void* get_remote_base_ptr(uint64_t rank) { - #ifdef __CUDA_ARCH__ - return (void*)meta_data[2 + rank]; - #else - return (void*)host_meta_data[2 + rank]; - #endif - } - +TL_HOST_DEVICE void *get_remote_base_ptr(uint64_t rank) { +#ifdef __CUDA_ARCH__ + return (void *)meta_data[2 + rank]; +#else + return (void *)host_meta_data[2 + rank]; +#endif +} // NOTE(wt): Be careful about the return types here! -// get_local_base() returns u64 since I could not find a way cast u64 to ptr in tir +// get_local_base() returns u64 since I could not find a way cast u64 to ptr in +// tir TL_HOST_DEVICE uint64_t get_local_base() { #ifdef __CUDA_ARCH__ return meta_data[2 + get_rank()]; @@ -43,7 +43,8 @@ TL_HOST_DEVICE uint64_t get_local_base() { #endif } -template TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { +template +TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { return reinterpret_cast(ptr); } diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index db934cab0..90174c9dc 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -2,7 +2,8 @@ /*! * \file declare_symm_buffer.cc - * \brief Declare the symmetry buffer to prepare for operators that need buffers on peer's symm heap + * \brief Declare the symmetry buffer to prepare for operators that need buffers + * on peer's symm heap */ #include @@ -25,22 +26,23 @@ namespace tl { using namespace tir; using tvm::transform::PassContext; -static int name_suffix_id = 0; // Avoid name collision for symm buffers, start from 0 +static int name_suffix_id = + 0; // Avoid name collision for symm buffers, start from 0 -/* Create a PrimExpr to calculate the symmetry pointer given a local ptr and target PE */ +/* Create a PrimExpr to calculate the symmetry pointer given a local ptr and + * target PE */ PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); - PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_local_base(), {}); + PrimExpr local_base_ptr = Call(DataType::Handle(), tl::get_local_base(), {}); PrimExpr offset_to_base = - Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), - local_base_ptr); - PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) + offset_to_base; + Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr); + PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) + + offset_to_base; return result; } /*! - * \brief Declare the symmetry buffer to prepare for operators + * \brief Declare the symmetry buffer to prepare for operators * that need buffers on peer's symm heap */ class SymmBufferDeclarer : public StmtExprMutator { @@ -63,7 +65,8 @@ class SymmBufferDeclarer : public StmtExprMutator { Stmt VisitStmt_(const EvaluateNode *op) final { // Check if this Evaluate contains a Call if (const CallNode *call_op = op->value.as()) { - auto parsed_op = ParseOperator(GetRef(call_op), buffer_data_to_buffer_); + auto parsed_op = + ParseOperator(GetRef(call_op), buffer_data_to_buffer_); if (parsed_op.defined() && parsed_op.as()) { // LOG(INFO) << "Found copy"; @@ -74,55 +77,65 @@ class SymmBufferDeclarer : public StmtExprMutator { Array dst_range = parsed_op.as()->dst_range; // 1. Calculate symm dst ptr - PrimExpr symm_dst_ptr_expr = CalculateSymmPtr(dst->data, parsed_op.as()->dst_pe); + PrimExpr symm_dst_ptr_expr = + CalculateSymmPtr(dst->data, parsed_op.as()->dst_pe); // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr; // 2. Create a let binding - String storage_scope = dst->data->type_annotation.as()->storage_scope; - Var symm_dst_var = Var(dst->name+"_symm_"+std::to_string(name_suffix_id++), - PointerType(PrimType(dst->dtype), storage_scope)); + String storage_scope = + dst->data->type_annotation.as()->storage_scope; + Var symm_dst_var = + Var(dst->name + "_symm_" + std::to_string(name_suffix_id++), + PointerType(PrimType(dst->dtype), storage_scope)); // 3. Create modified dst buffer with symm var dst.CopyOnWrite()->data = symm_dst_var; // 4. Rebuild the destination region call with the modified buffer - // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, extent_1, ...] + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, + // extent_1, ...] Array dst_region_mins; Array dst_region_extents; - for (const Range& r : dst_range) { + for (const Range &r : dst_range) { dst_region_mins.push_back(r->min); dst_region_extents.push_back(r->extent); } BufferLoad dst_load(dst, dst_region_mins); - + Array dst_region_args; dst_region_args.push_back(dst_load); - dst_region_args.push_back(IntImm(DataType::Int(32), call_op->args[1].as()->args[1].as()->value)); // access_mask - for (const PrimExpr& extent : dst_region_extents) { + dst_region_args.push_back( + IntImm(DataType::Int(32), call_op->args[1] + .as() + ->args[1] + .as() + ->value)); // access_mask + for (const PrimExpr &extent : dst_region_extents) { dst_region_args.push_back(extent); } - + // Create new Call for the destination region - Call dst_region_call = Call(call_op->args[1].as()->dtype, - call_op->args[1].as()->op, - dst_region_args, - call_op->args[1].as()->span); + Call dst_region_call = + Call(call_op->args[1].as()->dtype, + call_op->args[1].as()->op, dst_region_args, + call_op->args[1].as()->span); // 5. Rebuild the Copy call with modified args Array new_copy_args; new_copy_args.push_back(call_op->args[0]); // src region (unchanged) - new_copy_args.push_back(dst_region_call); // modified dst region + new_copy_args.push_back(dst_region_call); // modified dst region // Copy remaining args for (size_t i = 2; i < call_op->args.size(); i++) { new_copy_args.push_back(call_op->args[i]); } - + // Create the modified copy call - Call new_copy_call = Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); - + Call new_copy_call = + Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + // Wrap it in an Evaluate statement Stmt modified_stmt = Evaluate(new_copy_call); - + // Wrap with LetStmt that defines the symm pointer return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt); } else if (parsed_op.as()->is_remote_pull()) { @@ -132,61 +145,71 @@ class SymmBufferDeclarer : public StmtExprMutator { Array src_range = parsed_op.as()->src_range; // 1. Calculate symm src ptr - PrimExpr symm_src_ptr_expr = CalculateSymmPtr(src->data, parsed_op.as()->src_pe); + PrimExpr symm_src_ptr_expr = + CalculateSymmPtr(src->data, parsed_op.as()->src_pe); // LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr; - + // 2. Create a let binding - String storage_scope = src->data->type_annotation.as()->storage_scope; - Var symm_src_var = Var(src->name+"_symm_"+std::to_string(name_suffix_id++), - PointerType(PrimType(src->dtype), storage_scope)); + String storage_scope = + src->data->type_annotation.as()->storage_scope; + Var symm_src_var = + Var(src->name + "_symm_" + std::to_string(name_suffix_id++), + PointerType(PrimType(src->dtype), storage_scope)); // 3. Create modified src buffer with symm var src.CopyOnWrite()->data = symm_src_var; // 4. Rebuild the source region call with the modified buffer - // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, extent_1, ...] + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, + // extent_1, ...] Array src_region_mins; Array src_region_extents; - for (const Range& r : src_range) { + for (const Range &r : src_range) { src_region_mins.push_back(r->min); src_region_extents.push_back(r->extent); } BufferLoad src_load(src, src_region_mins); - + Array src_region_args; src_region_args.push_back(src_load); - src_region_args.push_back(IntImm(DataType::Int(32), call_op->args[1].as()->args[1].as()->value)); // access_mask - for (const PrimExpr& extent : src_region_extents) { + src_region_args.push_back( + IntImm(DataType::Int(32), call_op->args[1] + .as() + ->args[1] + .as() + ->value)); // access_mask + for (const PrimExpr &extent : src_region_extents) { src_region_args.push_back(extent); } - + // Create new Call for the source region - Call src_region_call = Call(call_op->args[0].as()->dtype, - call_op->args[0].as()->op, - src_region_args, - call_op->args[0].as()->span); + Call src_region_call = + Call(call_op->args[0].as()->dtype, + call_op->args[0].as()->op, src_region_args, + call_op->args[0].as()->span); // 5. Rebuild the Copy call with modified args Array new_copy_args; - new_copy_args.push_back(src_region_call); // modified src region + new_copy_args.push_back(src_region_call); // modified src region new_copy_args.push_back(call_op->args[1]); // dst region (unchanged) // Copy remaining args for (size_t i = 2; i < call_op->args.size(); i++) { new_copy_args.push_back(call_op->args[i]); } - + // Create the modified copy call - Call new_copy_call = Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); - + Call new_copy_call = + Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + // Wrap it in an Evaluate statement Stmt modified_stmt = Evaluate(new_copy_call); - + // Wrap with LetStmt that defines the symm pointer return LetStmt(symm_src_var, symm_src_ptr_expr, modified_stmt); } } } - + // Default: use parent's visitor return StmtExprMutator::VisitStmt_(op); } @@ -199,7 +222,8 @@ tvm::transform::Pass DeclareSymmBuffer() { f = SymmBufferDeclarer::Apply(std::move(f)); return f; }; - return tir::transform::CreatePrimFuncPass(pass_func, 0, "tl.DeclareSymmBuffer", {}); + return tir::transform::CreatePrimFuncPass(pass_func, 0, + "tl.DeclareSymmBuffer", {}); } TVM_FFI_STATIC_INIT_BLOCK({ diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 1728ec279..2751aeab7 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -39,7 +39,7 @@ class LowerHopperIntrin : public StmtExprMutator { CHECK(0) << call->op; } init_desc_args.push_back(var); - + // Inline let-bound variables in the descriptor arguments for (const auto &arg : call->args) { if (auto arg_var = arg.as()) { @@ -54,7 +54,7 @@ class LowerHopperIntrin : public StmtExprMutator { init_desc_args.push_back(arg); } } - + // add to function attribute Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); @@ -117,12 +117,12 @@ class LowerHopperIntrin : public StmtExprMutator { Stmt VisitStmt_(const LetStmtNode *op) final { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); - + // Check if this variable is related to TMA (used in descriptor creation) if (tma_related_vars_.count(op->var)) { tma_let_bindings_[op->var] = value; } - + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { @@ -185,7 +185,8 @@ class LowerHopperIntrin : public StmtExprMutator { Array init_mbarrier_calls_; std::unordered_map desc_map_; std::unordered_set tma_related_vars_; - std::unordered_map tma_let_bindings_; + std::unordered_map + tma_let_bindings_; LowerHopperIntrin(bool disable_shuffle_elect) : disable_shuffle_elect_(disable_shuffle_elect) {} bool disable_shuffle_elect_; diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index 2ce09495c..a8beca1b5 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -326,12 +326,14 @@ def _visitor(node): s, p = "tl::get_num_ranks()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_remote_base_ptr"): pe_str, _ = node_to_result_map[node.args[0]] - s, p = "tl::get_remote_base_ptr(" + pe_str + ")", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + s, p = "tl::get_remote_base_ptr(" + pe_str + ")", PRECEDENCE.get( + tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_local_base"): s, p = "tl::get_local_base()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_uintptr_t"): ptr_str, _ = node_to_result_map[node.args[0]] - s, p = "tl::get_uintptr_t(" + ptr_str + ")", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + s, p = "tl::get_uintptr_t(" + ptr_str + ")", PRECEDENCE.get( + tvm.tir.Call, ATOMIC_PRECEDENCE) else: # Fallback for unhandled expression types s, p = str(node), 0 diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index f046f82fa..f0061cdd5 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -6,7 +6,8 @@ from tvm import IRModule from tvm.target import Target from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, - is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, tilescale_pythonic_expr) + is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr, + tilescale_pythonic_expr) import re import logging import textwrap diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 52cfa1999..519059e99 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -8,14 +8,15 @@ from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region -def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, - dst: tir.Buffer | tir.BufferLoad, - src_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, - dst_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, - coalesced_width: int | None = None, - disable_tma: bool = False, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, - ): +def copy( + src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + src_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, + dst_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Copy data between memory regions. Args: @@ -90,7 +91,7 @@ def _to_region(data, access_type): eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - + assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank" return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index fade0599f..49fa09dad 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -302,16 +302,12 @@ def make_tensor(ptr: Var, return Tensor.from_ptr(ptr, shape, dtype, strides) -def make_tensor_like( - tensor, - ptr: Var, - shape: Optional[tuple[PrimExpr, ...]] = None, - dtype: Optional[str] = None, - strides: Optional[tuple[PrimExpr, ...]] = None -) -> tir.Buffer: - return Tensor.from_ptr( - ptr if ptr is not None else tensor.data, - shape if shape is not None else tensor.shape, - dtype if dtype is not None else tensor.dtype, - strides if strides is not None else tensor.strides - ) \ No newline at end of file +def make_tensor_like(tensor, + ptr: Var, + shape: Optional[tuple[PrimExpr, ...]] = None, + dtype: Optional[str] = None, + strides: Optional[tuple[PrimExpr, ...]] = None) -> tir.Buffer: + return Tensor.from_ptr(ptr if ptr is not None else tensor.data, + shape if shape is not None else tensor.shape, + dtype if dtype is not None else tensor.dtype, + strides if strides is not None else tensor.strides) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 88cdb121e..3c174ecab 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -488,10 +488,11 @@ def LayoutReducer(): # TileScale passes + def DeclareSymmBuffer(): """ Declare symmetric buffer on peer's global memory. This pass prepares for distributed operators, such as remote copy. """ - return _ffi_api.DeclareSymmBuffer() # type: ignore \ No newline at end of file + return _ffi_api.DeclareSymmBuffer() # type: ignore From 3664eb22052d70a316528931246f137611f5fa1c Mon Sep 17 00:00:00 2001 From: Tong WU <109033598+Rachmanino@users.noreply.github.com> Date: Sun, 2 Nov 2025 18:50:57 +0800 Subject: [PATCH 06/14] Update src/tl_templates/cuda/common.h Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/tl_templates/cuda/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index ad3a3e81e..f57073ff3 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -33,7 +33,7 @@ using int4_t = int4; #define TL_DEVICE __forceinline__ __device__ #define TL_DEVICE_NOINLINE __noinline__ __device__ #define TL_HOST __forceinline__ __host__ -#define TL_HOST_NOINLINE __noinline__ _host__ +#define TL_HOST_NOINLINE __noinline__ __host__ #define TL_HOST_DEVICE __forceinline__ __host__ __device__ #define TL_HOST_DEVICE_NOINLINE __noinline__ __host__ __device__ #define TL_PATCH From e79f865dcafdade4e27f9baa92943e1f1c10cc8d Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 3 Nov 2025 11:17:14 +0800 Subject: [PATCH 07/14] lint --- examples/distributed/primitives/example_tilescale_copy.py | 5 ++--- tilelang/language/copy.py | 8 ++++---- tilelang/language/proxy.py | 8 ++++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index d0588371c..4344111d1 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -95,7 +95,7 @@ def tma_load_tile( T.copy( src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], smem, - src_pe=1 - T.get_rank(), + src_pe=1 - T.get_rank(), # NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently. ) @@ -125,8 +125,7 @@ def tma_store_tile( T.copy( smem, dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], - dst_pe=1 - T.get_rank() - ) + dst_pe=1 - T.get_rank()) return { 'simt_push_buffer': simt_push_buffer, diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 519059e99..fc900f7a5 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Literal, Optional +from typing import Literal from tilelang import language as T from tilelang.utils.language import get_buffer_region_from_load from tvm import ir, tir @@ -11,8 +11,8 @@ def copy( src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, dst: tir.Buffer | tir.BufferLoad, - src_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, - dst_pe: Optional[tir.PrimExpr | tir.IntImm] = -1, + src_pe: tir.PrimExpr | tir.IntImm | None = -1, + dst_pe: tir.PrimExpr | tir.IntImm | None = -1, coalesced_width: int | None = None, disable_tma: bool = False, eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, @@ -27,7 +27,7 @@ def copy( coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None. disable_tma (bool, optional): Whether to disable TMA. Defaults to False. eviction_policy (Optional[Literal["evict_normal", "evict_first", "evict_last"]], optional): Eviction policy. Defaults to None. - + Raises: TypeError: If copy extents cannot be deduced from arguments diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 49fa09dad..5aced421a 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING +from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING from typing_extensions import Self from tvm import tir @@ -304,9 +304,9 @@ def make_tensor(ptr: Var, def make_tensor_like(tensor, ptr: Var, - shape: Optional[tuple[PrimExpr, ...]] = None, - dtype: Optional[str] = None, - strides: Optional[tuple[PrimExpr, ...]] = None) -> tir.Buffer: + shape: tuple[PrimExpr, ...] | None = None, + dtype: str | None = None, + strides: tuple[PrimExpr, ...] | None = None) -> tir.Buffer: return Tensor.from_ptr(ptr if ptr is not None else tensor.data, shape if shape is not None else tensor.shape, dtype if dtype is not None else tensor.dtype, From 7cb7ec69c5043efd0d2e63af6dd1de528230588c Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 3 Nov 2025 11:34:43 +0800 Subject: [PATCH 08/14] fix bot's comments --- examples/distributed/primitives/example_tilescale_copy.py | 5 ++--- src/op/copy.cc | 3 +++ src/tl_templates/cuda/distributed.h | 2 +- tilelang/language/copy.py | 2 -- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index 4344111d1..d3d76a403 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -12,7 +12,7 @@ @tilelang.jit -def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile', rank=None): +def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'): @T.prim_func def simt_push_buffer( @@ -154,8 +154,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): group=group) kernel = get_kernel( - M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel, - rank=local_rank) # only TMA kernels need compile-time aware peer rank + M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel) kernel.initialize(allocator=allocator) if local_rank == 0: print(kernel.get_kernel_source()) diff --git a/src/op/copy.cc b/src/op/copy.cc index 6ca796e28..1eb40a277 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -164,6 +164,9 @@ Copy::Copy(Array args, BufferMap vmap) { node->dst_pe = args[6]; } + ICHECK(!(node->is_remote_push() && node->is_remote_pull())) + << "At least one of src_pe or dst_pe must be local rank"; + if (node->is_remote_push()) { ICHECK(node->dst.scope() == "global") << "Can only copy to peer's global memory, but got " diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index b5c043e05..a0bde18da 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -6,7 +6,7 @@ namespace tl { extern "C" __device__ uint64_t meta_data[1024]; -extern "C" __host__ uint64_t *host_meta_data; +extern "C" uint64_t *host_meta_data; TL_HOST_DEVICE uint64_t get_rank() { #ifdef __CUDA_ARCH__ diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index fc900f7a5..ba27f24d4 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -92,8 +92,6 @@ def _to_region(data, access_type): else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank" - return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, disable_tma, eviction_policy, src_pe, dst_pe) From 08b84c880a80770957f2be2c00704d25ee504e60 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 3 Nov 2025 11:35:04 +0800 Subject: [PATCH 09/14] lint --- examples/distributed/primitives/example_tilescale_copy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py index d3d76a403..5ec7e5b6e 100644 --- a/examples/distributed/primitives/example_tilescale_copy.py +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -153,8 +153,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): num_local_ranks=num_local_ranks, group=group) - kernel = get_kernel( - M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel) + kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel) kernel.initialize(allocator=allocator) if local_rank == 0: print(kernel.get_kernel_source()) From f2906f7c36c519f704af45776a8475f7d48511af Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 3 Nov 2025 14:40:20 +0000 Subject: [PATCH 10/14] bugfix of parse_op --- src/transform/declare_symm_buffer.cc | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index 90174c9dc..23e487136 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -54,6 +54,11 @@ class SymmBufferDeclarer : public StmtExprMutator { SymmBufferDeclarer declarer; + // Initialize buffer map from function's buffer_map + for (const auto &[_, buffer] : f->buffer_map) { + declarer.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + // Extract symm buffer info and replace them // The LetStmt insertion will happen in VisitStmt_ before each copy f.CopyOnWrite()->body = declarer.VisitStmt(f->body); @@ -62,11 +67,28 @@ class SymmBufferDeclarer : public StmtExprMutator { }; private: + Stmt VisitStmt_(const BlockNode *op) final { + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + for (auto match_buffer : op->match_buffers) { + buffer_data_to_buffer_.Set(match_buffer->buffer->data, match_buffer->buffer); + } + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const EvaluateNode *op) final { // Check if this Evaluate contains a Call if (const CallNode *call_op = op->value.as()) { + if (call_op->op.as()) { + // Do not process the call node to the global function. + return StmtExprMutator::VisitStmt_(op); + } + // LOG(INFO) << "Found call"; auto parsed_op = ParseOperator(GetRef(call_op), buffer_data_to_buffer_); + // LOG(INFO) << "Parsed op: " << parsed_op; if (parsed_op.defined() && parsed_op.as()) { // LOG(INFO) << "Found copy"; @@ -139,7 +161,7 @@ class SymmBufferDeclarer : public StmtExprMutator { // Wrap with LetStmt that defines the symm pointer return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt); } else if (parsed_op.as()->is_remote_pull()) { - // LOG(INFO) << "Found remote pull"; + LOG(INFO) << "Found remote pull"; Buffer src = parsed_op.as()->src; Array src_range = parsed_op.as()->src_range; From 10c63945a185ced4cb7bd8f89e9446a5bc5104a3 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Mon, 3 Nov 2025 14:45:28 +0000 Subject: [PATCH 11/14] fix --- src/op/copy.cc | 4 ++-- src/transform/declare_symm_buffer.cc | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index 1eb40a277..aacc62b89 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1977,10 +1977,10 @@ Array TMAIm2ColDesc::EncodeCallArgs() const { // Register the Copy operation with TVM's TIR system // This makes the copy operation available for use in TVM programs // - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, -// eviction_policy, src_pe, dst_pe, is_remote_copy +// eviction_policy, src_pe, dst_pe // - Marked as opaque since it has side effects (memory writes) TIR_REGISTER_TL_OP(Copy, copy) - .set_num_inputs(8) + .set_num_inputs(7) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index 23e487136..3c66a9e8c 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -73,7 +73,8 @@ class SymmBufferDeclarer : public StmtExprMutator { buffer_data_to_buffer_.Set(buffer->data, buffer); } for (auto match_buffer : op->match_buffers) { - buffer_data_to_buffer_.Set(match_buffer->buffer->data, match_buffer->buffer); + buffer_data_to_buffer_.Set(match_buffer->buffer->data, + match_buffer->buffer); } return StmtExprMutator::VisitStmt_(op); } From fba413467af8cfdfc5fda1918edc3d15402dac37 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Tue, 11 Nov 2025 00:26:24 +0800 Subject: [PATCH 12/14] apply bot's suggestions --- src/transform/declare_symm_buffer.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index 3c66a9e8c..dfd405eee 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -32,7 +32,6 @@ static int name_suffix_id = /* Create a PrimExpr to calculate the symmetry pointer given a local ptr and * target PE */ PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { - PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = Call(DataType::Handle(), tl::get_local_base(), {}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr); @@ -162,7 +161,7 @@ class SymmBufferDeclarer : public StmtExprMutator { // Wrap with LetStmt that defines the symm pointer return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt); } else if (parsed_op.as()->is_remote_pull()) { - LOG(INFO) << "Found remote pull"; + // LOG(INFO) << "Found remote pull"; Buffer src = parsed_op.as()->src; Array src_range = parsed_op.as()->src_range; @@ -196,7 +195,7 @@ class SymmBufferDeclarer : public StmtExprMutator { Array src_region_args; src_region_args.push_back(src_load); src_region_args.push_back( - IntImm(DataType::Int(32), call_op->args[1] + IntImm(DataType::Int(32), call_op->args[0] .as() ->args[1] .as() From a510e74623f6a4b0badd84570f66fcaddc0881c1 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Tue, 11 Nov 2025 15:46:22 +0800 Subject: [PATCH 13/14] add test and fix compatibility --- .../primitives/test_tilescale_copy.py | 30 +++++++++++++++++++ src/op/distributed.cc | 10 +++++++ src/op/distributed.h | 20 +++++++++---- src/op/remote_copy.cc | 8 ++--- src/op/sync.cc | 2 +- src/target/codegen_cuda.cc | 7 +++++ src/tl_templates/cuda/distributed.h | 18 +++++++---- src/tl_templates/cuda/sync.h | 2 +- src/transform/declare_symm_buffer.cc | 1 - tilelang/jit/adapter/utils.py | 6 ++++ 10 files changed, 86 insertions(+), 18 deletions(-) create mode 100644 examples/distributed/primitives/test_tilescale_copy.py diff --git a/examples/distributed/primitives/test_tilescale_copy.py b/examples/distributed/primitives/test_tilescale_copy.py new file mode 100644 index 000000000..313b1e097 --- /dev/null +++ b/examples/distributed/primitives/test_tilescale_copy.py @@ -0,0 +1,30 @@ +import argparse +import tilelang.testing +import torch +import torch.multiprocessing + +import example_tilescale_copy + + +@tilelang.testing.requires_cuda +def test_example_tilescale_copy_simt_push_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='simt_push_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_tilescale_copy_tma_load_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='tma_load_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_tilescale_copy_tma_store_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='tma_store_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/src/op/distributed.cc b/src/op/distributed.cc index aa55c065b..54ed41ff9 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -202,6 +202,11 @@ TIR_DEFINE_TL_BUILTIN(get_num_ranks) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_remote_base) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr) .set_num_inputs(1) .set_attr("TCallEffectKind", @@ -212,6 +217,11 @@ TIR_DEFINE_TL_BUILTIN(get_local_base) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_local_base_ptr) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(get_uintptr_t) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/distributed.h b/src/op/distributed.h index 66907ee53..0331d6116 100644 --- a/src/op/distributed.h +++ b/src/op/distributed.h @@ -212,27 +212,37 @@ const Op &FcollectBlock(); const Op &CpengineCpAsync(); /*! - * \brief tvm intrinsics for getting the rank of the current process + * \brief tilescale intrinsics for getting the rank of the current process */ const Op &get_rank(); /*! - * \brief tvm intrinsics for getting the number of processes + * \brief tilescale intrinsics for getting the number of ranks */ const Op &get_num_ranks(); /*! - * \brief tvm intrinsics for getting the remote base pointer + * \brief tilescale intrinsics for getting the remote base address (u64) + */ +const Op &get_remote_base(); + +/*! + * \brief tilescale intrinsics for getting the remote base pointer */ const Op &get_remote_base_ptr(); /*! - * \brief tvm intrinsics for getting the local base pointer + * \brief tilescale intrinsics for getting the local base address (u64) */ const Op &get_local_base(); /*! - * \brief tvm intrinsics for getting the uintptr_t of a pointer + * \brief tilescale intrinsics for getting the local base pointer + */ +const Op &get_local_base_ptr(); + +/*! + * \brief tilescale intrinsics for getting the u64 value of a pointer */ const Op &get_uintptr_t(); } // namespace tl diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index 059d545b9..06125e875 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -100,12 +100,12 @@ Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { PrimExpr dst_addr_expr = MakeRemappedAddress(T, dst_buffer, dst_indices); PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + Call(DataType::Handle(), tl::get_remote_base(), {local_rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {dst_addr_expr}), local_base_ptr); new_args.push_back( - Call(DataType::Handle(), tl::get_remote_base_ptr(), {dst_pe}) + + Call(DataType::Handle(), tl::get_remote_base(), {dst_pe}) + offset_to_base); } else { new_args.push_back(MakeRemappedAddress(T, dst_buffer, dst_indices)); @@ -206,12 +206,12 @@ Stmt GetOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { PrimExpr src_addr_expr = MakeRemappedAddress(T, src_buffer, src_indices); PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + Call(DataType::Handle(), tl::get_remote_base(), {local_rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {src_addr_expr}), local_base_ptr); new_args.push_back( - Call(DataType::Handle(), tl::get_remote_base_ptr(), {src_pe}) + + Call(DataType::Handle(), tl::get_remote_base(), {src_pe}) + offset_to_base); } else { new_args.push_back(MakeRemappedAddress(T, src_buffer, src_indices)); diff --git a/src/op/sync.cc b/src/op/sync.cc index 0c83a7b8d..863fc6766 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -94,7 +94,7 @@ Stmt BarrierAllBlocksSysOpNode::Lower(const LowerArgs &T, PrimExpr rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr num_ranks = Call(DataType::Int(64), tl::get_num_ranks(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {rank}); + Call(DataType::Handle(), tl::get_remote_base(), {rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {bar_addr}), local_base_ptr); diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 660125d16..b39228b09 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1541,6 +1541,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::get_num_ranks())) { this->use_distributed_ = true; os << "tl::get_num_ranks()"; + } else if (op->op.same_as(tl::get_remote_base())) { + this->use_distributed_ = true; + std::string pe_str = this->PrintExpr(op->args[0]); + os << "tl::get_remote_base(" << pe_str << ")"; } else if (op->op.same_as(tl::get_remote_base_ptr())) { this->use_distributed_ = true; std::string pe_str = this->PrintExpr(op->args[0]); @@ -1548,6 +1552,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::get_local_base())) { this->use_distributed_ = true; os << "tl::get_local_base()"; + } else if (op->op.same_as(tl::get_local_base_ptr())) { + this->use_distributed_ = true; + os << "tl::get_local_base_ptr()"; } else if (op->op.same_as(tl::get_uintptr_t())) { os << "tl::get_uintptr_t(" << this->PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::tvm_fill_fragment())) { diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index a0bde18da..47a2c8e5a 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -24,17 +24,21 @@ TL_HOST_DEVICE uint64_t get_num_ranks() { #endif } -TL_HOST_DEVICE void *get_remote_base_ptr(uint64_t rank) { +// NOTE(wt): Be careful about the return types here! +// I could not find a way cast u64 to ptr in tir ? + +TL_HOST_DEVICE uint64_t get_remote_base(uint64_t rank) { #ifdef __CUDA_ARCH__ - return (void *)meta_data[2 + rank]; + return meta_data[2 + rank]; #else - return (void *)host_meta_data[2 + rank]; + return host_meta_data[2 + rank]; #endif } -// NOTE(wt): Be careful about the return types here! -// get_local_base() returns u64 since I could not find a way cast u64 to ptr in -// tir +TL_HOST_DEVICE void *get_remote_base_ptr(uint64_t rank) { + return (void *)get_remote_base(rank); +} + TL_HOST_DEVICE uint64_t get_local_base() { #ifdef __CUDA_ARCH__ return meta_data[2 + get_rank()]; @@ -43,6 +47,8 @@ TL_HOST_DEVICE uint64_t get_local_base() { #endif } +TL_HOST_DEVICE void *get_local_base_ptr() { return (void *)get_local_base(); } + template TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { return reinterpret_cast(ptr); diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index b7b4b8cb9..64e7b9ecd 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -155,7 +155,7 @@ TL_DEVICE void sync_grid(uint32_t *barrier) { TL_DEVICE void barrier_all_blocks_sys(int offset, int rank, int num_ranks) { // Macro to compute the barrier pointer for a given target rank #define BARRIER_PTR(tgt_rank) \ - (reinterpret_cast(get_remote_base_ptr(tgt_rank) + offset)) + (reinterpret_cast(get_remote_base(tgt_rank) + offset)) memory_fence_sys(); __syncthreads(); diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc index dfd405eee..39af91819 100644 --- a/src/transform/declare_symm_buffer.cc +++ b/src/transform/declare_symm_buffer.cc @@ -39,7 +39,6 @@ PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { offset_to_base; return result; } - /*! * \brief Declare the symmetry buffer to prepare for operators * that need buffers on peer's symm heap diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index a8beca1b5..a32615f39 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -324,12 +324,18 @@ def _visitor(node): s, p = "tl::get_rank()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_num_ranks"): s, p = "tl::get_num_ranks()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_remote_base"): + pe_str, _ = node_to_result_map[node.args[0]] + s, p = "tl::get_remote_base(" + pe_str + ")", PRECEDENCE.get( + tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_remote_base_ptr"): pe_str, _ = node_to_result_map[node.args[0]] s, p = "tl::get_remote_base_ptr(" + pe_str + ")", PRECEDENCE.get( tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_local_base"): s, p = "tl::get_local_base()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) + elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_local_base_ptr"): + s, p = "tl::get_local_base_ptr()", PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE) elif isinstance(node, tvm.tir.Call) and node.op == tir.op.Op.get("tl.get_uintptr_t"): ptr_str, _ = node_to_result_map[node.args[0]] s, p = "tl::get_uintptr_t(" + ptr_str + ")", PRECEDENCE.get( From a536bdd1c87447b8d2df0e124dc33aacd4b3375b Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Tue, 11 Nov 2025 19:41:12 +0800 Subject: [PATCH 14/14] fix previous bug --- src/tl_templates/cuda/threadblock_swizzle.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tl_templates/cuda/threadblock_swizzle.h b/src/tl_templates/cuda/threadblock_swizzle.h index 1539b657d..00a230c1a 100644 --- a/src/tl_templates/cuda/threadblock_swizzle.h +++ b/src/tl_templates/cuda/threadblock_swizzle.h @@ -4,7 +4,7 @@ namespace tl { -template TL_DEVICE dim3 rasterization2DRow() { +template TL_DEVICE dim3 rasterization2DRow() { const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int panel_size = panel_width * gridDim.x; @@ -23,7 +23,8 @@ template TL_DEVICE dim3 rasterization2DRow() { return {col_idx, row_idx, blockIdx.z}; } -template TL_DEVICE dim3 rasterization2DColumn() { +template +TL_DEVICE dim3 rasterization2DColumn() { const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int panel_size = panel_width * gridDim.y;