diff --git a/examples/distributed/primitives/example_tilescale_copy.py b/examples/distributed/primitives/example_tilescale_copy.py new file mode 100644 index 000000000..5ec7e5b6e --- /dev/null +++ b/examples/distributed/primitives/example_tilescale_copy.py @@ -0,0 +1,192 @@ +import os +import tilelang +import tilelang.language as T +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing +from tilelang.distributed import init_dist + +tilelang.disable_cache() +os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log + + +@tilelang.jit +def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'): + + @T.prim_func + def simt_push_buffer( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel((1), threads=threads): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + T.copy( + src, + dst, + dst_pe=1 - rank[0], + disable_tma=True # Ensure testing SIMT remote copy + ) + + @T.prim_func + def simt_push_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + disable_tma=True # Ensure testing SIMT remote copy + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + dst_pe=1 - rank[0], + disable_tma=True # Ensure testing SIMT remote copy + ) + + @T.prim_func + def simt_pull_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + rank = T.alloc_local([1], "uint64") + rank[0] = T.get_rank() + + smem = T.alloc_shared((block_M, block_N), "float32") + 
T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + src_pe=1 - rank[0], + disable_tma=True # Ensure testing SIMT remote copy + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + disable_tma=True # Ensure testing SIMT remote copy + ) + + # TMA kernel requires run-time aware peer rank + @T.prim_func + def tma_load_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + # TMA load + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + src_pe=1 - T.get_rank(), + # NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently. + ) + + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + disable_tma=True # Ensure testing SIMT remote copy + ) + + @T.prim_func + def tma_store_tile( + dst: T.Tensor((M, N), "float32"), + src: T.Tensor((M, N), "float32"), + ): + with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by): + + smem = T.alloc_shared((block_M, block_N), "float32") + T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)}) + + T.copy( + src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + smem, + disable_tma=True # Ensure testing SIMT remote copy + ) + + # TMA store + T.copy( + smem, + dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], + dst_pe=1 - T.get_rank()) + + return { + 'simt_push_buffer': simt_push_buffer, + 'simt_push_tile': simt_push_tile, + 'simt_pull_tile': simt_pull_tile, + 'tma_load_tile': tma_load_tile, + 'tma_store_tile': tma_store_tile + }[kernel] + + +def main(local_rank: int, num_local_ranks: int, args: 
argparse.Namespace): + M = args.M + N = args.N + BLOCK_M = 64 + BLOCK_N = 128 + threads = 128 + assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other" + + _, _, group = init_dist(local_rank, num_local_ranks) + allocator = tilelang.get_allocator( + size=2**25, + device="cuda", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group) + + kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel) + kernel.initialize(allocator=allocator) + if local_rank == 0: + print(kernel.get_kernel_source()) + + src = tilelang.tensor((M, N), torch.float32, allocator=allocator).normal_() + dst = tilelang.tensor((M, N), torch.float32, allocator=allocator) + + torch.cuda.synchronize() + torch.distributed.barrier(group) + kernel(dst, src) + torch.cuda.synchronize() + torch.distributed.barrier(group) + + dst_torchs = [torch.empty_like(src) for _ in range(num_local_ranks)] + dist.all_gather(dst_torchs, src, group) + dst_torch = dst_torchs[local_rank ^ 1] + + if torch.allclose(dst_torch, dst, atol=1e-6, rtol=1e-6): + print(f"rank {local_rank} check passed.✅") + else: + print(f"rank {local_rank} check failed.❌") + print(f"dst_torch: {dst_torch}, dst: {dst}") + raise ValueError("Test failed") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--M', type=int, default=1024, help='M dimension') + parser.add_argument('--N', type=int, default=1024, help='N dimension') + parser.add_argument('--kernel', type=str, default='simt_push_tile', help='kernel to use') + args = parser.parse_args() + num_processes = 2 + + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) diff --git a/examples/distributed/primitives/test_tilescale_copy.py b/examples/distributed/primitives/test_tilescale_copy.py new file mode 100644 index 000000000..313b1e097 --- /dev/null +++ b/examples/distributed/primitives/test_tilescale_copy.py 
@@ -0,0 +1,30 @@ +import argparse +import tilelang.testing +import torch +import torch.multiprocessing + +import example_tilescale_copy + + +@tilelang.testing.requires_cuda +def test_example_tilescale_copy_simt_push_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='simt_push_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_tilescale_copy_tma_load_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='tma_load_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +def test_example_tilescale_copy_tma_store_tile(): + args = argparse.Namespace(M=1024, N=1024, kernel='tma_store_tile') + torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/src/op/copy.cc b/src/op/copy.cc index 754dd7336..aacc62b89 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -155,9 +155,45 @@ Copy::Copy(Array args, BufferMap vmap) { if (args.size() >= 5) { node->eviction_policy = args[4].as()->value; } + + // Parse remote copy params + if (args.size() >= 6) { + node->src_pe = args[5]; + } + if (args.size() >= 7) { + node->dst_pe = args[6]; + } + + ICHECK(!(node->is_remote_push() && node->is_remote_pull())) + << "At least one of src_pe or dst_pe must be local rank"; + + if (node->is_remote_push()) { + ICHECK(node->dst.scope() == "global") + << "Can only copy to peer's global memory, but got " + << node->dst.scope(); + } else if (node->is_remote_pull()) { + ICHECK(node->src.scope() == "global") + << "Can only pull from peer's global memory, but got " + << node->src.scope(); + } + data_ = std::move(node); } +bool CopyNode::is_remote_push() const { + return !(dst_pe->IsInstance() && + dst_pe.as()->value 
== -1); +} + +bool CopyNode::is_remote_pull() const { + return !(src_pe->IsInstance() && + src_pe.as()->value == -1); +} + +bool CopyNode::is_remote_copy() const { + return is_remote_push() || is_remote_pull(); +} + /** * @brief Create a shallow clone of this CopyNode as a TileOperator. * @@ -1940,11 +1976,11 @@ Array TMAIm2ColDesc::EncodeCallArgs() const { // Register the Copy operation with TVM's TIR system // This makes the copy operation available for use in TVM programs -// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, -// eviction_policy +// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, +// eviction_policy, src_pe, dst_pe // - Marked as opaque since it has side effects (memory writes) TIR_REGISTER_TL_OP(Copy, copy) - .set_num_inputs(5) + .set_num_inputs(7) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/copy.h b/src/op/copy.h index 00d07f169..a4ea4f53a 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -92,6 +92,15 @@ class CopyNode : public TileOperatorNode { IntImm coalesced_width; // Width (in elements) for coalesced memory access Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + // Params for remote copy + PrimExpr src_pe; // Source PE for remote copy + PrimExpr dst_pe; // Destination PE for remote copy + Buffer symm_buffer; // Symmetric buffer for remote copy + + bool is_remote_copy() const; + bool is_remote_push() const; + bool is_remote_pull() const; + mutable ParallelOp par_op_; // Optional associated parallelization operator enum class EvictionPolicy : uint8_t { diff --git a/src/op/distributed.cc b/src/op/distributed.cc index 84a23afa7..54ed41ff9 100644 --- a/src/op/distributed.cc +++ b/src/op/distributed.cc @@ -202,11 +202,26 @@ TIR_DEFINE_TL_BUILTIN(get_num_ranks) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_remote_base) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + 
Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(get_local_base) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(get_local_base_ptr) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(get_uintptr_t) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/distributed.h b/src/op/distributed.h index 5170cc7ea..0331d6116 100644 --- a/src/op/distributed.h +++ b/src/op/distributed.h @@ -212,22 +212,37 @@ const Op &FcollectBlock(); const Op &CpengineCpAsync(); /*! - * \brief tvm intrinsics for getting the rank of the current process + * \brief tilescale intrinsics for getting the rank of the current process */ const Op &get_rank(); /*! - * \brief tvm intrinsics for getting the number of processes + * \brief tilescale intrinsics for getting the number of ranks */ const Op &get_num_ranks(); /*! - * \brief tvm intrinsics for getting the remote base pointer + * \brief tilescale intrinsics for getting the remote base address (u64) + */ +const Op &get_remote_base(); + +/*! + * \brief tilescale intrinsics for getting the remote base pointer */ const Op &get_remote_base_ptr(); /*! - * \brief tvm intrinsics for getting the uintptr_t of a pointer + * \brief tilescale intrinsics for getting the local base address (u64) + */ +const Op &get_local_base(); + +/*! + * \brief tilescale intrinsics for getting the local base pointer + */ +const Op &get_local_base_ptr(); + +/*! 
+ * \brief tilescale intrinsics for getting the u64 value of a pointer */ const Op &get_uintptr_t(); } // namespace tl diff --git a/src/op/remote_copy.cc b/src/op/remote_copy.cc index 059d545b9..06125e875 100644 --- a/src/op/remote_copy.cc +++ b/src/op/remote_copy.cc @@ -100,12 +100,12 @@ Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { PrimExpr dst_addr_expr = MakeRemappedAddress(T, dst_buffer, dst_indices); PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + Call(DataType::Handle(), tl::get_remote_base(), {local_rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {dst_addr_expr}), local_base_ptr); new_args.push_back( - Call(DataType::Handle(), tl::get_remote_base_ptr(), {dst_pe}) + + Call(DataType::Handle(), tl::get_remote_base(), {dst_pe}) + offset_to_base); } else { new_args.push_back(MakeRemappedAddress(T, dst_buffer, dst_indices)); @@ -206,12 +206,12 @@ Stmt GetOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { PrimExpr src_addr_expr = MakeRemappedAddress(T, src_buffer, src_indices); PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank}); + Call(DataType::Handle(), tl::get_remote_base(), {local_rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {src_addr_expr}), local_base_ptr); new_args.push_back( - Call(DataType::Handle(), tl::get_remote_base_ptr(), {src_pe}) + + Call(DataType::Handle(), tl::get_remote_base(), {src_pe}) + offset_to_base); } else { new_args.push_back(MakeRemappedAddress(T, src_buffer, src_indices)); diff --git a/src/op/sync.cc b/src/op/sync.cc index 0c83a7b8d..863fc6766 100644 --- a/src/op/sync.cc +++ b/src/op/sync.cc @@ -94,7 +94,7 @@ Stmt BarrierAllBlocksSysOpNode::Lower(const LowerArgs &T, PrimExpr rank = 
Call(DataType::Int(64), tl::get_rank(), {}); PrimExpr num_ranks = Call(DataType::Int(64), tl::get_num_ranks(), {}); PrimExpr local_base_ptr = - Call(DataType::Handle(), tl::get_remote_base_ptr(), {rank}); + Call(DataType::Handle(), tl::get_remote_base(), {rank}); PrimExpr offset_to_base = Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {bar_addr}), local_base_ptr); diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index e93b6fc4e..b39228b09 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -297,6 +297,8 @@ std::string CodeGenTileLangCUDA::Finish() { if (use_distributed_) { decl_stream << "uint64_t __constant__ meta_data[1024];\n"; + decl_stream + << "uint64_t* host_meta_data = nullptr;\n"; // An alias of host_table } decl_stream << "#ifdef ENABLE_BF16\n"; decl_stream << "#include \n"; @@ -1539,10 +1541,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::get_num_ranks())) { this->use_distributed_ = true; os << "tl::get_num_ranks()"; + } else if (op->op.same_as(tl::get_remote_base())) { + this->use_distributed_ = true; + std::string pe_str = this->PrintExpr(op->args[0]); + os << "tl::get_remote_base(" << pe_str << ")"; } else if (op->op.same_as(tl::get_remote_base_ptr())) { this->use_distributed_ = true; std::string pe_str = this->PrintExpr(op->args[0]); os << "tl::get_remote_base_ptr(" << pe_str << ")"; + } else if (op->op.same_as(tl::get_local_base())) { + this->use_distributed_ = true; + os << "tl::get_local_base()"; + } else if (op->op.same_as(tl::get_local_base_ptr())) { + this->use_distributed_ = true; + os << "tl::get_local_base_ptr()"; } else if (op->op.same_as(tl::get_uintptr_t())) { os << "tl::get_uintptr_t(" << this->PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(builtin::tvm_fill_fragment())) { diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index dfbc062cf..f57073ff3 100644 --- a/src/tl_templates/cuda/common.h 
+++ b/src/tl_templates/cuda/common.h @@ -32,6 +32,10 @@ using int4_t = int4; #define TL_DEVICE __forceinline__ __device__ #define TL_DEVICE_NOINLINE __noinline__ __device__ +#define TL_HOST __forceinline__ __host__ +#define TL_HOST_NOINLINE __noinline__ __host__ +#define TL_HOST_DEVICE __forceinline__ __host__ __device__ +#define TL_HOST_DEVICE_NOINLINE __noinline__ __host__ __device__ #define TL_PATCH #define TILELANG_CHECK(stmt) \ diff --git a/src/tl_templates/cuda/distributed.h b/src/tl_templates/cuda/distributed.h index 7eca3a708..47a2c8e5a 100644 --- a/src/tl_templates/cuda/distributed.h +++ b/src/tl_templates/cuda/distributed.h @@ -1,20 +1,56 @@ #pragma once #include "common.h" +#include namespace tl { -extern "C" extern __device__ uint64_t meta_data[1024]; +extern "C" __device__ uint64_t meta_data[1024]; +extern "C" uint64_t *host_meta_data; -TL_DEVICE uint64_t get_rank() { return meta_data[0]; } +TL_HOST_DEVICE uint64_t get_rank() { +#ifdef __CUDA_ARCH__ + return meta_data[0]; +#else + return host_meta_data[0]; +#endif +} + +TL_HOST_DEVICE uint64_t get_num_ranks() { +#ifdef __CUDA_ARCH__ + return meta_data[1]; +#else + return host_meta_data[1]; +#endif +} -TL_DEVICE uint64_t get_num_ranks() { return meta_data[1]; } +// NOTE(wt): Be careful about the return types here! +// I could not find a way to cast u64 to ptr in TIR. 
-TL_DEVICE uint64_t get_remote_base_ptr(uint64_t rank) { +TL_HOST_DEVICE uint64_t get_remote_base(uint64_t rank) { +#ifdef __CUDA_ARCH__ return meta_data[2 + rank]; +#else + return host_meta_data[2 + rank]; +#endif +} + +TL_HOST_DEVICE void *get_remote_base_ptr(uint64_t rank) { + return (void *)get_remote_base(rank); +} + +TL_HOST_DEVICE uint64_t get_local_base() { +#ifdef __CUDA_ARCH__ + return meta_data[2 + get_rank()]; +#else + return host_meta_data[2 + get_rank()]; +#endif } -template TL_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { +TL_HOST_DEVICE void *get_local_base_ptr() { return (void *)get_local_base(); } + +template +TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) { return reinterpret_cast(ptr); } diff --git a/src/tl_templates/cuda/sync.h b/src/tl_templates/cuda/sync.h index b7b4b8cb9..64e7b9ecd 100644 --- a/src/tl_templates/cuda/sync.h +++ b/src/tl_templates/cuda/sync.h @@ -155,7 +155,7 @@ TL_DEVICE void sync_grid(uint32_t *barrier) { TL_DEVICE void barrier_all_blocks_sys(int offset, int rank, int num_ranks) { // Macro to compute the barrier pointer for a given target rank #define BARRIER_PTR(tgt_rank) \ - (reinterpret_cast(get_remote_base_ptr(tgt_rank) + offset)) + (reinterpret_cast(get_remote_base(tgt_rank) + offset)) memory_fence_sys(); __syncthreads(); diff --git a/src/tl_templates/cuda/threadblock_swizzle.h b/src/tl_templates/cuda/threadblock_swizzle.h index 1539b657d..00a230c1a 100644 --- a/src/tl_templates/cuda/threadblock_swizzle.h +++ b/src/tl_templates/cuda/threadblock_swizzle.h @@ -4,7 +4,7 @@ namespace tl { -template TL_DEVICE dim3 rasterization2DRow() { +template TL_DEVICE dim3 rasterization2DRow() { const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int panel_size = panel_width * gridDim.x; @@ -23,7 +23,8 @@ template TL_DEVICE dim3 rasterization2DRow() { return {col_idx, row_idx, blockIdx.z}; } -template TL_DEVICE dim3 rasterization2DColumn() { 
+template +TL_DEVICE dim3 rasterization2DColumn() { const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x; const unsigned int grid_size = gridDim.x * gridDim.y; const unsigned int panel_size = panel_width * gridDim.y; diff --git a/src/transform/declare_symm_buffer.cc b/src/transform/declare_symm_buffer.cc new file mode 100644 index 000000000..39af91819 --- /dev/null +++ b/src/transform/declare_symm_buffer.cc @@ -0,0 +1,256 @@ +// TileScale pass + +/*! + * \file declare_symm_buffer.cc + * \brief Declare the symmetry buffer to prepare for operators that need buffers + * on peer's symm heap + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/copy.h" +#include "../op/distributed.h" + +namespace tvm { +namespace tl { + +using namespace tir; +using tvm::transform::PassContext; + +static int name_suffix_id = + 0; // Avoid name collision for symm buffers, start from 0 + +/* Create a PrimExpr to calculate the symmetry pointer given a local ptr and + * target PE */ +PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) { + PrimExpr local_base_ptr = Call(DataType::Handle(), tl::get_local_base(), {}); + PrimExpr offset_to_base = + Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr); + PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) + + offset_to_base; + return result; +} +/*! 
+ * \brief Declare the symmetry buffer to prepare for operators + * that need buffers on peer's symm heap + */ +class SymmBufferDeclarer : public StmtExprMutator { +public: + static PrimFunc Apply(PrimFunc f) { + if (!f->body.defined()) { + return f; + } + + SymmBufferDeclarer declarer; + + // Initialize buffer map from function's buffer_map + for (const auto &[_, buffer] : f->buffer_map) { + declarer.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + // Extract symm buffer info and replace them + // The LetStmt insertion will happen in VisitStmt_ before each copy + f.CopyOnWrite()->body = declarer.VisitStmt(f->body); + + return f; + }; + +private: + Stmt VisitStmt_(const BlockNode *op) final { + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + for (auto match_buffer : op->match_buffers) { + buffer_data_to_buffer_.Set(match_buffer->buffer->data, + match_buffer->buffer); + } + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const EvaluateNode *op) final { + // Check if this Evaluate contains a Call + if (const CallNode *call_op = op->value.as()) { + if (call_op->op.as()) { + // Do not process the call node to the global function. + return StmtExprMutator::VisitStmt_(op); + } + // LOG(INFO) << "Found call"; + auto parsed_op = + ParseOperator(GetRef(call_op), buffer_data_to_buffer_); + // LOG(INFO) << "Parsed op: " << parsed_op; + if (parsed_op.defined() && parsed_op.as()) { + // LOG(INFO) << "Found copy"; + + if (parsed_op.as()->is_remote_push()) { + // LOG(INFO) << "Found remote push"; + + Buffer dst = parsed_op.as()->dst; + Array dst_range = parsed_op.as()->dst_range; + + // 1. Calculate symm dst ptr + PrimExpr symm_dst_ptr_expr = + CalculateSymmPtr(dst->data, parsed_op.as()->dst_pe); + // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr; + + // 2. 
Create a let binding + String storage_scope = + dst->data->type_annotation.as()->storage_scope; + Var symm_dst_var = + Var(dst->name + "_symm_" + std::to_string(name_suffix_id++), + PointerType(PrimType(dst->dtype), storage_scope)); + + // 3. Create modified dst buffer with symm var + dst.CopyOnWrite()->data = symm_dst_var; + + // 4. Rebuild the destination region call with the modified buffer + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, + // extent_1, ...] + Array dst_region_mins; + Array dst_region_extents; + for (const Range &r : dst_range) { + dst_region_mins.push_back(r->min); + dst_region_extents.push_back(r->extent); + } + BufferLoad dst_load(dst, dst_region_mins); + + Array dst_region_args; + dst_region_args.push_back(dst_load); + dst_region_args.push_back( + IntImm(DataType::Int(32), call_op->args[1] + .as() + ->args[1] + .as() + ->value)); // access_mask + for (const PrimExpr &extent : dst_region_extents) { + dst_region_args.push_back(extent); + } + + // Create new Call for the destination region + Call dst_region_call = + Call(call_op->args[1].as()->dtype, + call_op->args[1].as()->op, dst_region_args, + call_op->args[1].as()->span); + + // 5. 
Rebuild the Copy call with modified args + Array new_copy_args; + new_copy_args.push_back(call_op->args[0]); // src region (unchanged) + new_copy_args.push_back(dst_region_call); // modified dst region + // Copy remaining args + for (size_t i = 2; i < call_op->args.size(); i++) { + new_copy_args.push_back(call_op->args[i]); + } + + // Create the modified copy call + Call new_copy_call = + Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + + // Wrap it in an Evaluate statement + Stmt modified_stmt = Evaluate(new_copy_call); + + // Wrap with LetStmt that defines the symm pointer + return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt); + } else if (parsed_op.as()->is_remote_pull()) { + // LOG(INFO) << "Found remote pull"; + + Buffer src = parsed_op.as()->src; + Array src_range = parsed_op.as()->src_range; + + // 1. Calculate symm src ptr + PrimExpr symm_src_ptr_expr = + CalculateSymmPtr(src->data, parsed_op.as()->src_pe); + // LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr; + + // 2. Create a let binding + String storage_scope = + src->data->type_annotation.as()->storage_scope; + Var symm_src_var = + Var(src->name + "_symm_" + std::to_string(name_suffix_id++), + PointerType(PrimType(src->dtype), storage_scope)); + + // 3. Create modified src buffer with symm var + src.CopyOnWrite()->data = symm_src_var; + + // 4. Rebuild the source region call with the modified buffer + // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0, + // extent_1, ...] 
+ Array src_region_mins; + Array src_region_extents; + for (const Range &r : src_range) { + src_region_mins.push_back(r->min); + src_region_extents.push_back(r->extent); + } + BufferLoad src_load(src, src_region_mins); + + Array src_region_args; + src_region_args.push_back(src_load); + src_region_args.push_back( + IntImm(DataType::Int(32), call_op->args[0] + .as() + ->args[1] + .as() + ->value)); // access_mask + for (const PrimExpr &extent : src_region_extents) { + src_region_args.push_back(extent); + } + + // Create new Call for the source region + Call src_region_call = + Call(call_op->args[0].as()->dtype, + call_op->args[0].as()->op, src_region_args, + call_op->args[0].as()->span); + + // 5. Rebuild the Copy call with modified args + Array new_copy_args; + new_copy_args.push_back(src_region_call); // modified src region + new_copy_args.push_back(call_op->args[1]); // dst region (unchanged) + // Copy remaining args + for (size_t i = 2; i < call_op->args.size(); i++) { + new_copy_args.push_back(call_op->args[i]); + } + + // Create the modified copy call + Call new_copy_call = + Call(call_op->dtype, call_op->op, new_copy_args, call_op->span); + + // Wrap it in an Evaluate statement + Stmt modified_stmt = Evaluate(new_copy_call); + + // Wrap with LetStmt that defines the symm pointer + return LetStmt(symm_src_var, symm_src_ptr_expr, modified_stmt); + } + } + } + + // Default: use parent's visitor + return StmtExprMutator::VisitStmt_(op); + } + + Map buffer_data_to_buffer_; +}; + +tvm::transform::Pass DeclareSymmBuffer() { + auto pass_func = [](PrimFunc f, const IRModule &, const PassContext &) { + f = SymmBufferDeclarer::Apply(std::move(f)); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, + "tl.DeclareSymmBuffer", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.DeclareSymmBuffer", DeclareSymmBuffer); +}); + +} // namespace tl +} // namespace tvm diff --git 
a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index b514627d7..2751aeab7 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -39,8 +39,22 @@ class LowerHopperIntrin : public StmtExprMutator { CHECK(0) << call->op; } init_desc_args.push_back(var); - init_desc_args.insert(init_desc_args.end(), call->args.begin(), - call->args.end()); + + // Inline let-bound variables in the descriptor arguments + for (const auto &arg : call->args) { + if (auto arg_var = arg.as()) { + auto it = substituter.tma_let_bindings_.find(arg_var.value()); + if (it != substituter.tma_let_bindings_.end()) { + // Replace variable with its let-bound expression + init_desc_args.push_back(it->second); + } else { + init_desc_args.push_back(arg); + } + } else { + init_desc_args.push_back(arg); + } + } + // add to function attribute Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args); @@ -49,6 +63,7 @@ class LowerHopperIntrin : public StmtExprMutator { init_desc_arg_map.Set(var, init_desc_args); } f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map); + return f; } @@ -99,9 +114,43 @@ class LowerHopperIntrin : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const LetStmtNode *op) final { + PrimExpr value = this->VisitExpr(op->value); + Stmt body = this->VisitStmt(op->body); + + // Check if this variable is related to TMA (used in descriptor creation) + if (tma_related_vars_.count(op->var)) { + tma_let_bindings_[op->var] = value; + } + + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(op->var, value, body); + } + } + PrimExpr VisitExpr_(const CallNode *call) final { - if (call->op.same_as(create_tma_descriptor()) || - call->op.same_as(create_tma_im2col_descriptor())) { + if (call->op.same_as(create_tma_descriptor())) { + Var var; + auto iter = desc_map_.find(GetRef(call)); + if (iter != 
desc_map_.end()) { + var = iter->second; + } else { + String name = call->args[2].as().value()->name_hint; + var = Var(name + "_desc", + PointerType(PrimType(cuTensorMapType()), "grid_constant")); + desc_map_[GetRef(call)] = var; + prefetch_calls_.push_back( + Evaluate(Call(DataType::Handle(), builtin::call_extern(), + {StringImm("tl::prefetch_tma_descriptor"), var}))); + // Mark the base pointer variable as TMA-related + if (auto base_var = call->args[2].as()) { + tma_related_vars_.insert(base_var.value()); + } + } + return var; + } else if (call->op.same_as(create_tma_im2col_descriptor())) { Var var; auto iter = desc_map_.find(GetRef(call)); if (iter != desc_map_.end()) { @@ -135,6 +184,9 @@ class LowerHopperIntrin : public StmtExprMutator { Array prefetch_calls_; Array init_mbarrier_calls_; std::unordered_map desc_map_; + std::unordered_set tma_related_vars_; + std::unordered_map + tma_let_bindings_; LowerHopperIntrin(bool disable_shuffle_elect) : disable_shuffle_elect_(disable_shuffle_elect) {} bool disable_shuffle_elect_; diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 6c34eae08..f20755dcd 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -104,6 +104,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LayoutReducer()(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) + # Declare symmetric buffer on peer's global memory + mod = tilelang.transform.DeclareSymmBuffer()(mod) # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index efc965e1b..a32615f39 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -216,3 +216,137 @@ def _visitor(node): tvm.tir.stmt_functor.post_order_visit(expr, _visitor) return next(iter(node_to_result_map[expr]), "") + 
+
+def tilescale_pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str:
+    """
+    Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
+
+    Compared to pythonic_expr, this function is used for TileScale and additionally renders the
+    known TileScale intrinsic calls (tl.get_rank, tl.get_remote_base, ...) as tl::... calls.
+
+    Args:
+        expr: The TVM PrimExpr to convert. Any other object is returned as str(expr).
+        dtype_map: Optional mapping from a TVM dtype to the type name emitted in casts.
+            If None, the TVM dtype itself is emitted.
+
+    Returns:
+        A string representation of the expression.
+    """
+    if not isinstance(expr, tvm.tir.PrimExpr):
+        return str(expr)
+
+    # 1. Define operator precedence (higher value means higher precedence)
+    # Based on Python's operator precedence
+    PRECEDENCE = {
+        tvm.tir.Call: 20,  # Includes min, max
+        tvm.tir.Cast: 20,  # Treated like a function call
+        tvm.tir.Mul: 13,
+        tvm.tir.FloorDiv: 13,
+        tvm.tir.Div: 13,  # For tvm.tir.Div if it appears
+        tvm.tir.FloorMod: 13,
+        tvm.tir.Add: 12,
+        tvm.tir.Sub: 12,
+        tvm.tir.LT: 10,
+        tvm.tir.LE: 10,
+        tvm.tir.GT: 10,
+        tvm.tir.GE: 10,
+        tvm.tir.EQ: 10,
+        tvm.tir.NE: 10,
+        tvm.tir.And: 5,
+        tvm.tir.Or: 4,
+        # Atoms (Var, IntImm) have the highest precedence implicitly
+    }
+    # By default, atomic expressions (variables, constants) have the highest precedence
+    ATOMIC_PRECEDENCE = 100
+    # min/max and the TileScale intrinsics are all rendered as function calls
+    CALL_PRECEDENCE = PRECEDENCE[tvm.tir.Call]
+
+    # Token emitted for every supported binary operator. Built once instead of per node.
+    # tvm.tir.Div is included because PRECEDENCE already declares it; without an entry here
+    # it would fall through to the str(node) fallback below.
+    BINARY_OPS = {
+        tvm.tir.Mul: "*",
+        tvm.tir.FloorDiv: "/",
+        tvm.tir.Div: "/",
+        tvm.tir.Add: "+",
+        tvm.tir.Sub: "-",
+        tvm.tir.FloorMod: "%",
+        tvm.tir.LT: "<",
+        tvm.tir.LE: "<=",
+        tvm.tir.GT: ">",
+        tvm.tir.GE: ">=",
+        tvm.tir.EQ: "==",
+        tvm.tir.NE: "!=",
+        tvm.tir.And: "and",
+        tvm.tir.Or: "or",
+    }
+    BINARY_TYPES = tuple(BINARY_OPS)
+
+    # Known calls in TileScale: (TIR op name, emitted function, whether args[0] is forwarded)
+    TILESCALE_CALLS = (
+        ("tl.get_rank", "tl::get_rank", False),
+        ("tl.get_num_ranks", "tl::get_num_ranks", False),
+        ("tl.get_remote_base", "tl::get_remote_base", True),
+        ("tl.get_remote_base_ptr", "tl::get_remote_base_ptr", True),
+        ("tl.get_local_base", "tl::get_local_base", False),
+        ("tl.get_local_base_ptr", "tl::get_local_base_ptr", False),
+        ("tl.get_uintptr_t", "tl::get_uintptr_t", True),
+    )
+
+    node_to_result_map = {}  # Stores (string, precedence) for each node
+
+    def _tilescale_call(node):
+        """Render a known TileScale call, or return None for any other call."""
+        for op_name, func_name, has_arg in TILESCALE_CALLS:
+            if node.op == tir.op.Op.get(op_name):
+                # Post-order traversal guarantees the argument was rendered already
+                arg_str = node_to_result_map[node.args[0]][0] if has_arg else ""
+                return f"{func_name}({arg_str})"
+        return None
+
+    def _visitor(node):
+        # 2. Visitor stores a (str, precedence) tuple for every node
+        if node in node_to_result_map:
+            return
+
+        call_str = _tilescale_call(node) if isinstance(node, tvm.tir.Call) else None
+
+        if isinstance(node, tvm.tir.Var):
+            s, p = node.name, ATOMIC_PRECEDENCE
+        elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)):
+            s, p = str(node.value), ATOMIC_PRECEDENCE
+        elif isinstance(node, tvm.tir.Cast):
+            # C-style cast has high precedence
+            value_str, _ = node_to_result_map[node.value]
+            if dtype_map is None:
+                s = f"({node.dtype}){value_str}"
+            else:
+                s = f"({dtype_map[node.dtype]}){value_str}"
+            p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
+        elif isinstance(node, BINARY_TYPES):
+            op_str = f" {BINARY_OPS[type(node)]} "
+            my_precedence = PRECEDENCE[type(node)]
+
+            a_str, a_precedence = node_to_result_map[node.a]
+            b_str, b_precedence = node_to_result_map[node.b]
+
+            # 3. Add parentheses intelligently
+            # Add parentheses if the left operand's precedence is lower than the current operator
+            if a_precedence < my_precedence:
+                a_str = f"({a_str})"
+            # Add parentheses if the right operand's precedence is lower than or equal to the current operator
+            # 'Equal' is to handle non-associative operations, e.g., a - (b - c)
+            if b_precedence <= my_precedence:
+                b_str = f"({b_str})"
+
+            s = f"{a_str}{op_str}{b_str}"
+            p = my_precedence
+        elif isinstance(node, (tvm.tir.Min, tvm.tir.Max)):
+            op_name = "min" if isinstance(node, tvm.tir.Min) else "max"
+            a_str, _ = node_to_result_map[node.a]
+            b_str, _ = node_to_result_map[node.b]
+            s = f"{op_name}({a_str}, {b_str})"
+            # Function calls have high precedence
+            p = CALL_PRECEDENCE
+        elif call_str is not None:
+            # Parsed as one of the known calls in TileScale
+            s, p = call_str, CALL_PRECEDENCE
+        else:
+            # Fallback for unhandled expression types
+            s, p = str(node), 0
+
+        node_to_result_map[node] = (s, p)
+
+    # Perform post-order traversal
+    tvm.tir.stmt_functor.post_order_visit(expr, _visitor)
+
+    # Only the rendered string of the root expression is returned
+    return node_to_result_map[expr][0]
diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py
index c88cbc6cc..f0061cdd5 100644
--- a/tilelang/jit/adapter/wrapper.py
+++ b/tilelang/jit/adapter/wrapper.py
@@ -6,7 +6,8 @@
 from tvm import IRModule
 from tvm.target import Target
 from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
-                    is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr)
+                    is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr,
+                    tilescale_pythonic_expr)
 import re
 import logging
 import textwrap
@@ -61,6 +62,7 @@
         if (error_buf) std::snprintf(error_buf, 256, "cudaMemcpyToSymbol failed: %s", cudaGetErrorString(err));
         return static_cast(err);
     }}
+    host_meta_data = (uint64_t*)host_table;
     return 0;
 }}
 """
@@ -283,6 +285,9 @@ def __init__(self,
     def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
         return pythonic_expr(expr, self._TYPE_MAP)
 
+    def _tilescale_pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
+        return tilescale_pythonic_expr(expr, self._TYPE_MAP)
+
     def is_tma_descriptor_arg(self, arg_name: str) -> bool:
         return arg_name in self.prim_func.buffer_map
 
@@ -325,7 +330,12 @@ def func_call_args(s,
         # Extract the function call arguments matching the function definition
         def maybe_desc(name: str, matches: list[str], i: int):
             match = matches[i]
-            if not (match == name + "_desc" or match.startswith(name + "_desc_")):
+            if not (match == name + "_desc" \
+                or match.startswith(name + "_desc_")
+                or (match.startswith(name + "_symm_") and
match.endswith("_desc")) + # The last cases belongs to TMA copy from symm buffer + # Check naming in src/transform/declare_symm_buffer.cc + ): return False desc_decls = [] if desc_name_map is not None: @@ -453,6 +463,8 @@ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args + globalAddress = self._tilescale_pythonic_expr(globalAddress) + is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col") dtype = self._pythonic_expr(dtype) tensor_rank = int(self._pythonic_expr(tensor_rank)) diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 27f5432fc..023271a7d 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -16,6 +16,7 @@ from .proxy import ( ptr, # noqa: F401 make_tensor, # noqa: F401 + make_tensor_like, # noqa: F401 Buffer, # noqa: F401 Tensor, # noqa: F401 StridedTensor, # noqa: F401 diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 84444b8c6..ba27f24d4 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -8,17 +8,25 @@ from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region -def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, - dst: tir.Buffer | tir.BufferLoad, - coalesced_width: int | None = None, - disable_tma: bool = False, - eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): +def copy( + src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + src_pe: tir.PrimExpr | tir.IntImm | None = -1, + dst_pe: tir.PrimExpr | tir.IntImm | None = -1, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, +): """Copy data between memory regions. 
Args:
         src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region
         dst (Union[tir.Buffer, tir.BufferLoad]): Destination memory region
+        src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1, which means copy from local.
+        dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1, which means copy to local.
         coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None.
+        disable_tma (bool, optional): Whether to disable TMA. Defaults to False.
+        eviction_policy (Optional[Literal["evict_normal", "evict_first", "evict_last"]], optional): Eviction policy. Defaults to None.
 
     Raises:
         TypeError: If copy extents cannot be deduced from arguments
@@ -83,8 +91,9 @@ def _to_region(data, access_type):
         eviction_policy = 0
     else:
         eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
+
     return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width,
-                           disable_tma, eviction_policy)
+                           disable_tma, eviction_policy, src_pe, dst_pe)
 
 
 def c2d_im2col(img: tir.Buffer,
diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py
index 539c1d94c..5aced421a 100644
--- a/tilelang/language/proxy.py
+++ b/tilelang/language/proxy.py
@@ -300,3 +300,29 @@ def make_tensor(ptr: Var,
                 dtype: str = "float32",
                 strides: tuple[PrimExpr, ...] = None) -> tir.Buffer:
     return Tensor.from_ptr(ptr, shape, dtype, strides)
+
+
+def make_tensor_like(tensor,
+                     ptr: Var | None = None,
+                     shape: tuple[PrimExpr, ...] | None = None,
+                     dtype: str | None = None,
+                     strides: tuple[PrimExpr, ...] | None = None) -> tir.Buffer:
+    """Create a tensor that mirrors ``tensor`` except for the explicitly given fields.
+
+    Every argument left as None falls back to the matching attribute of ``tensor``
+    (``data``, ``shape``, ``dtype``, ``strides``).
+
+    Args:
+        tensor: Template buffer supplying the defaults.
+        ptr: Data pointer for the new tensor; defaults to ``tensor.data``.
+        shape: Shape override; defaults to ``tensor.shape``.
+        dtype: Element type override; defaults to ``tensor.dtype``.
+        strides: Strides override; defaults to ``tensor.strides``.
+
+    Returns:
+        The buffer built by ``Tensor.from_ptr``.
+    """
+    return Tensor.from_ptr(ptr if ptr is not None else tensor.data,
+                           shape if shape is not None else tensor.shape,
+                           dtype if dtype is not None else tensor.dtype,
+                           strides if strides is not None else tensor.strides)
diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py
index 808c97dc6..3c174ecab 100644
--- a/tilelang/transform/__init__.py
+++ b/tilelang/transform/__init__.py
@@ -484,3 +484,15 @@ def LayoutReducer():
         The transform pass object produced by the FFI backend.
     """
     return _ffi_api.LayoutReducer()  # type: ignore
+
+
+# TileScale passes
+
+
+def DeclareSymmBuffer():
+    """
+    Declare symmetric buffer on peer's global memory.
+
+    This pass prepares for distributed operators, such as remote copy.
+    """
+    return _ffi_api.DeclareSymmBuffer()  # type: ignore