Skip to content
192 changes: 192 additions & 0 deletions examples/distributed/primitives/example_tilescale_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
import tilelang
import tilelang.language as T
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing
from tilelang.distributed import init_dist

# Disable tilelang's kernel cache so every run recompiles from the current source.
tilelang.disable_cache()
os.environ['NCCL_DEBUG'] = 'WARN'  # silence NCCL log


@tilelang.jit
def get_kernel(M: int, N: int, block_M: int, block_N: int, threads: int,
               kernel: str = 'simt_push_tile'):
    """Build a float32 (M, N) copy kernel that exercises tilescale remote copies.

    Every variant assumes exactly two ranks: the peer PE is always computed as
    ``1 - rank``.  ``kernel`` selects which variant is returned:

    - ``'simt_push_buffer'``: one block pushes the whole buffer to the peer.
    - ``'simt_push_tile'``:   per-tile push via shared memory (SIMT path).
    - ``'simt_pull_tile'``:   per-tile pull from the peer via shared memory.
    - ``'tma_load_tile'``:    TMA load from the peer, local SIMT store.
    - ``'tma_store_tile'``:   local SIMT load, TMA store to the peer.

    Raises:
        KeyError: if ``kernel`` is not one of the five names above.
    """

    @T.prim_func
    def simt_push_buffer(
        dst: T.Tensor((M, N), "float32"),
        src: T.Tensor((M, N), "float32"),
    ):
        # Single-block kernel: push the entire src buffer into the peer's dst.
        with T.Kernel((1), threads=threads):
            rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()

            T.copy(
                src,
                dst,
                dst_pe=1 - rank[0],
                disable_tma=True  # Ensure testing SIMT remote copy
            )

    @T.prim_func
    def simt_push_tile(
        dst: T.Tensor((M, N), "float32"),
        src: T.Tensor((M, N), "float32"),
    ):
        # One block per (block_M, block_N) tile: stage the tile in shared
        # memory, then push it into the peer rank's dst.
        with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):
            rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()

            smem = T.alloc_shared((block_M, block_N), "float32")
            T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

            # Local global->shared staging copy.
            T.copy(
                src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                smem,
                disable_tma=True  # Force the SIMT path (local copy, no TMA)
            )

            # Remote shared->global push into the peer's dst tile.
            T.copy(
                smem,
                dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                dst_pe=1 - rank[0],
                disable_tma=True  # Ensure testing SIMT remote copy
            )

    @T.prim_func
    def simt_pull_tile(
        dst: T.Tensor((M, N), "float32"),
        src: T.Tensor((M, N), "float32"),
    ):
        # One block per tile: pull the tile from the peer rank's src into
        # shared memory, then store it to the local dst.
        with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):
            rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()

            smem = T.alloc_shared((block_M, block_N), "float32")
            T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

            # Remote global->shared pull from the peer's src tile.
            T.copy(
                src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                smem,
                src_pe=1 - rank[0],
                disable_tma=True  # Ensure testing SIMT remote copy
            )

            # Local shared->global store.
            T.copy(
                smem,
                dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                disable_tma=True  # Force the SIMT path (local copy, no TMA)
            )

    # TMA kernel requires run-time aware peer rank
    @T.prim_func
    def tma_load_tile(
        dst: T.Tensor((M, N), "float32"),
        src: T.Tensor((M, N), "float32"),
    ):
        # One block per tile: TMA-load the tile from the peer's src, then
        # store it locally via the SIMT path.
        with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):

            smem = T.alloc_shared((block_M, block_N), "float32")
            T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

            # TMA load
            T.copy(
                src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                smem,
                src_pe=1 - T.get_rank(),
                # NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently.
            )

            # Local shared->global store.
            T.copy(
                smem,
                dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                disable_tma=True  # Force the SIMT path (local copy, no TMA)
            )

    @T.prim_func
    def tma_store_tile(
        dst: T.Tensor((M, N), "float32"),
        src: T.Tensor((M, N), "float32"),
    ):
        # One block per tile: SIMT-load the tile locally, then TMA-store it
        # into the peer rank's dst.
        with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):

            smem = T.alloc_shared((block_M, block_N), "float32")
            T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

            # Local global->shared staging copy.
            T.copy(
                src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                smem,
                disable_tma=True  # Force the SIMT path (local copy, no TMA)
            )

            # TMA store
            T.copy(
                smem,
                dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
                dst_pe=1 - T.get_rank())

    # Dispatch on the requested variant name.
    return {
        'simt_push_buffer': simt_push_buffer,
        'simt_push_tile': simt_push_tile,
        'simt_pull_tile': simt_pull_tile,
        'tma_load_tile': tma_load_tile,
        'tma_store_tile': tma_store_tile
    }[kernel]


def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    """Per-rank entry point: run the selected copy kernel and verify the result.

    Each rank fills its own ``src`` with random data, runs the kernel (which
    exchanges data with the peer rank), then checks its ``dst`` against the
    peer's ``src`` gathered via ``all_gather``.

    Args:
        local_rank: rank of this process on the node (0 or 1).
        num_local_ranks: total spawned ranks; must be 2 for this example.
        args: parsed CLI namespace with ``M``, ``N``, and ``kernel``.

    Raises:
        ValueError: if the copied result does not match the peer's source.
    """
    M = args.M
    N = args.N
    BLOCK_M = 64
    BLOCK_N = 128
    threads = 128
    assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other"

    _, _, group = init_dist(local_rank, num_local_ranks)
    # Distributed (symmetric) allocator backing src/dst: 2**25 bytes = 32 MiB.
    allocator = tilelang.get_allocator(
        size=2**25,
        device="cuda",
        is_distributed=True,
        local_rank=local_rank,
        num_local_ranks=num_local_ranks,
        group=group)

    kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel)
    kernel.initialize(allocator=allocator)
    if local_rank == 0:
        print(kernel.get_kernel_source())

    src = tilelang.tensor((M, N), torch.float32, allocator=allocator).normal_()
    dst = tilelang.tensor((M, N), torch.float32, allocator=allocator)

    torch.cuda.synchronize()
    # Use the `dist` alias consistently (it is torch.distributed).
    dist.barrier(group)  # both peers must have allocated before any remote access
    kernel(dst, src)
    torch.cuda.synchronize()
    dist.barrier(group)  # remote writes must land before validation

    # Reference: gather every rank's src; ours must equal the peer's (rank ^ 1).
    dst_torchs = [torch.empty_like(src) for _ in range(num_local_ranks)]
    dist.all_gather(dst_torchs, src, group)
    dst_torch = dst_torchs[local_rank ^ 1]

    if torch.allclose(dst_torch, dst, atol=1e-6, rtol=1e-6):
        print(f"rank {local_rank} check passed.✅")
    else:
        print(f"rank {local_rank} check failed.❌")
        print(f"dst_torch: {dst_torch}, dst: {dst}")
        raise ValueError("Test failed")

    dist.destroy_process_group()


if __name__ == "__main__":
    # Parse problem size and kernel choice, then launch one process per rank.
    # This example is fixed at two peers copying to each other.
    cli = argparse.ArgumentParser()
    for flag, desc in (('--M', 'M dimension'), ('--N', 'N dimension')):
        cli.add_argument(flag, type=int, default=1024, help=desc)
    cli.add_argument('--kernel', type=str, default='simt_push_tile', help='kernel to use')
    parsed = cli.parse_args()

    world_size = 2
    torch.multiprocessing.spawn(main, args=(world_size, parsed), nprocs=world_size)
30 changes: 30 additions & 0 deletions examples/distributed/primitives/test_tilescale_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import argparse
import tilelang.testing
import torch
import torch.multiprocessing

import example_tilescale_copy


@tilelang.testing.requires_cuda
def test_example_tilescale_copy_simt_push_tile():
    """Run the SIMT push-tile copy example across two spawned ranks."""
    args = argparse.Namespace(M=1024, N=1024, kernel='simt_push_tile')
    torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)  # TMA needs Hopper (sm90+)
def test_example_tilescale_copy_tma_load_tile():
    """Run the TMA remote-load copy example across two spawned ranks."""
    args = argparse.Namespace(M=1024, N=1024, kernel='tma_load_tile')
    torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)  # TMA needs Hopper (sm90+)
def test_example_tilescale_copy_tma_store_tile():
    """Run the TMA remote-store copy example across two spawned ranks."""
    args = argparse.Namespace(M=1024, N=1024, kernel='tma_store_tile')
    torch.multiprocessing.spawn(example_tilescale_copy.main, args=(2, args), nprocs=2)


if __name__ == "__main__":
    # Run all tests in this module via tilelang's test runner.
    tilelang.testing.main()
42 changes: 39 additions & 3 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,45 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() >= 5) {
node->eviction_policy = args[4].as<IntImmNode>()->value;
}

// Parse remote copy params
if (args.size() >= 6) {
node->src_pe = args[5];
}
if (args.size() >= 7) {
node->dst_pe = args[6];
}

ICHECK(!(node->is_remote_push() && node->is_remote_pull()))
<< "At least one of src_pe or dst_pe must be local rank";

if (node->is_remote_push()) {
ICHECK(node->dst.scope() == "global")
<< "Can only copy to peer's global memory, but got "
<< node->dst.scope();
} else if (node->is_remote_pull()) {
ICHECK(node->src.scope() == "global")
<< "Can only pull from peer's global memory, but got "
<< node->src.scope();
}

data_ = std::move(node);
}

// Whether this copy writes to a remote PE (push semantics).
// dst_pe may be a null handle: legacy 5-argument tl.copy calls never set it,
// so dereferencing it unconditionally would crash on deserialized old
// modules.  An undefined handle or the compile-time sentinel -1 both mean
// "local destination"; any other expression (including a runtime rank) is
// treated as remote.
bool CopyNode::is_remote_push() const {
  if (!dst_pe.defined()) {
    return false;
  }
  if (const auto *imm = dst_pe.as<IntImmNode>()) {
    return imm->value != -1;
  }
  return true;
}

// Whether this copy reads from a remote PE (pull semantics).
// Mirrors is_remote_push(): src_pe is a null handle for legacy 5-argument
// tl.copy calls, and a compile-time -1 is the "local source" sentinel —
// neither counts as remote.  Any other expression is treated as remote.
bool CopyNode::is_remote_pull() const {
  if (!src_pe.defined()) {
    return false;
  }
  if (const auto *imm = src_pe.as<IntImmNode>()) {
    return imm->value != -1;
  }
  return true;
}

// Whether either endpoint of this copy is remote (push or pull).
bool CopyNode::is_remote_copy() const {
  return is_remote_push() || is_remote_pull();
}
Comment on lines +159 to +195
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard remote PE defaults before remote checks.

is_remote_push() / is_remote_pull() run immediately after the optional assignments. Legacy modules compiled before this PR still emit the 5-argument tl.copy; when we deserialize them here, args.size() is only 5, so dst_pe/src_pe stay PrimExpr() and the subsequent dst_pe->IsInstance dereferences a null handle. That’s a hard crash/regression for any pre-existing artifact. Please seed both fields with -1 before the if blocks and make the helpers resilient to an undefined PrimExpr.

Apply this diff to fix the issue:

Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<CopyNode> node = make_object<CopyNode>();
+  node->src_pe = Integer(-1);
+  node->dst_pe = Integer(-1);
   Array<Range> rgs[2];
@@
-  if (args.size() >= 6) {
-    node->src_pe = args[5];
-  }
-  if (args.size() >= 7) {
-    node->dst_pe = args[6];
-  }
+  if (args.size() >= 6) {
+    node->src_pe = args[5];
+  }
+  if (args.size() >= 7) {
+    node->dst_pe = args[6];
+  }
@@
 bool CopyNode::is_remote_push() const {
-  return !(dst_pe->IsInstance<IntImmNode>() &&
-           dst_pe.as<IntImmNode>()->value == -1);
+  if (!dst_pe.defined()) {
+    return false;
+  }
+  if (const auto *imm = dst_pe.as<IntImmNode>()) {
+    return imm->value != -1;
+  }
+  return true;
 }
@@
 bool CopyNode::is_remote_pull() const {
-  return !(src_pe->IsInstance<IntImmNode>() &&
-           src_pe.as<IntImmNode>()->value == -1);
+  if (!src_pe.defined()) {
+    return false;
+  }
+  if (const auto *imm = src_pe.as<IntImmNode>()) {
+    return imm->value != -1;
+  }
+  return true;
 }

Committable suggestion skipped: line range outside the PR's diff.


/**
* @brief Create a shallow clone of this CopyNode as a TileOperator.
*
Expand Down Expand Up @@ -1940,11 +1976,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {

// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
//   eviction_policy, src_pe, dst_pe
//   (the trailing src_pe/dst_pe are optional remote-copy PEs; -1 means local)
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(7)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

Expand Down
9 changes: 9 additions & 0 deletions src/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class CopyNode : public TileOperatorNode {
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration

// Params for remote copy
PrimExpr src_pe; // Source PE for remote copy
PrimExpr dst_pe; // Destination PE for remote copy
Buffer symm_buffer; // Symmetric buffer for remote copy

bool is_remote_copy() const;
bool is_remote_push() const;
bool is_remote_pull() const;

mutable ParallelOp par_op_; // Optional associated parallelization operator

enum class EvictionPolicy : uint8_t {
Expand Down
15 changes: 15 additions & 0 deletions src/op/distributed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,26 @@ TIR_DEFINE_TL_BUILTIN(get_num_ranks)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Intrinsic: remote symmetric base address of a given rank as a u64 value.
// Takes the target rank as its single argument; remote-copy lowering adds
// the local offset-from-base to this to form the peer address.
TIR_DEFINE_TL_BUILTIN(get_remote_base)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Intrinsic: remote symmetric base of a given rank as a pointer
// (pointer-typed counterpart of get_remote_base).
TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Intrinsic: this rank's local symmetric base address as a u64 value.
TIR_DEFINE_TL_BUILTIN(get_local_base)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Intrinsic: this rank's local symmetric base as a pointer.
TIR_DEFINE_TL_BUILTIN(get_local_base_ptr)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_uintptr_t)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
23 changes: 19 additions & 4 deletions src/op/distributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,37 @@ const Op &FcollectBlock();
const Op &CpengineCpAsync();

/*!
* \brief tvm intrinsics for getting the rank of the current process
* \brief tilescale intrinsics for getting the rank of the current process
*/
const Op &get_rank();

/*!
* \brief tvm intrinsics for getting the number of processes
* \brief tilescale intrinsics for getting the number of ranks
*/
const Op &get_num_ranks();

/*!
* \brief tvm intrinsics for getting the remote base pointer
* \brief tilescale intrinsics for getting the remote base address (u64)
*/
const Op &get_remote_base();

/*!
* \brief tilescale intrinsics for getting the remote base pointer
*/
const Op &get_remote_base_ptr();

/*!
* \brief tvm intrinsics for getting the uintptr_t of a pointer
* \brief tilescale intrinsics for getting the local base address (u64)
*/
const Op &get_local_base();

/*!
* \brief tilescale intrinsics for getting the local base pointer
*/
const Op &get_local_base_ptr();

/*!
* \brief tilescale intrinsics for getting the u64 value of a pointer
*/
const Op &get_uintptr_t();
} // namespace tl
Expand Down
8 changes: 4 additions & 4 deletions src/op/remote_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
PrimExpr dst_addr_expr = MakeRemappedAddress(T, dst_buffer, dst_indices);
PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {});
PrimExpr local_base_ptr =
Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank});
Call(DataType::Handle(), tl::get_remote_base(), {local_rank});
PrimExpr offset_to_base =
Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {dst_addr_expr}),
local_base_ptr);
new_args.push_back(
Call(DataType::Handle(), tl::get_remote_base_ptr(), {dst_pe}) +
Call(DataType::Handle(), tl::get_remote_base(), {dst_pe}) +
offset_to_base);
} else {
new_args.push_back(MakeRemappedAddress(T, dst_buffer, dst_indices));
Expand Down Expand Up @@ -206,12 +206,12 @@ Stmt GetOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
PrimExpr src_addr_expr = MakeRemappedAddress(T, src_buffer, src_indices);
PrimExpr local_rank = Call(DataType::Int(64), tl::get_rank(), {});
PrimExpr local_base_ptr =
Call(DataType::Handle(), tl::get_remote_base_ptr(), {local_rank});
Call(DataType::Handle(), tl::get_remote_base(), {local_rank});
PrimExpr offset_to_base =
Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {src_addr_expr}),
local_base_ptr);
new_args.push_back(
Call(DataType::Handle(), tl::get_remote_base_ptr(), {src_pe}) +
Call(DataType::Handle(), tl::get_remote_base(), {src_pe}) +
offset_to_base);
} else {
new_args.push_back(MakeRemappedAddress(T, src_buffer, src_indices));
Expand Down
Loading
Loading