From f60641148d7f35b7054a7ce80f0fe9feb1eccf6b Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Fri, 13 Feb 2026 22:47:00 +0800 Subject: [PATCH 1/3] move intranode ipc-based examples to unique folder and add test --- .../example_allgather_gemm_overlapped.py | 0 .../example_gemm_rs_overlapped.py | 0 .../example_sp_ag_attention_intra_node.py | 0 .../{ => intranode}/gemm_rs_utils.py | 0 .../{ => intranode}/reduce_scatter.py | 0 .../sp_ag_attention_intra_node.py | 0 .../distributed/intranode/test_intranode.py | 30 +++++++++++++++++++ 7 files changed, 30 insertions(+) rename examples/distributed/{ => intranode}/example_allgather_gemm_overlapped.py (100%) rename examples/distributed/{ => intranode}/example_gemm_rs_overlapped.py (100%) rename examples/distributed/{ => intranode}/example_sp_ag_attention_intra_node.py (100%) rename examples/distributed/{ => intranode}/gemm_rs_utils.py (100%) rename examples/distributed/{ => intranode}/reduce_scatter.py (100%) rename examples/distributed/{ => intranode}/sp_ag_attention_intra_node.py (100%) create mode 100644 examples/distributed/intranode/test_intranode.py diff --git a/examples/distributed/example_allgather_gemm_overlapped.py b/examples/distributed/intranode/example_allgather_gemm_overlapped.py similarity index 100% rename from examples/distributed/example_allgather_gemm_overlapped.py rename to examples/distributed/intranode/example_allgather_gemm_overlapped.py diff --git a/examples/distributed/example_gemm_rs_overlapped.py b/examples/distributed/intranode/example_gemm_rs_overlapped.py similarity index 100% rename from examples/distributed/example_gemm_rs_overlapped.py rename to examples/distributed/intranode/example_gemm_rs_overlapped.py diff --git a/examples/distributed/example_sp_ag_attention_intra_node.py b/examples/distributed/intranode/example_sp_ag_attention_intra_node.py similarity index 100% rename from examples/distributed/example_sp_ag_attention_intra_node.py rename to examples/distributed/intranode/example_sp_ag_attention_intra_node.py diff --git a/examples/distributed/gemm_rs_utils.py b/examples/distributed/intranode/gemm_rs_utils.py similarity index 100% rename from examples/distributed/gemm_rs_utils.py rename to examples/distributed/intranode/gemm_rs_utils.py diff --git a/examples/distributed/reduce_scatter.py b/examples/distributed/intranode/reduce_scatter.py similarity index 100% rename from examples/distributed/reduce_scatter.py rename to examples/distributed/intranode/reduce_scatter.py diff --git a/examples/distributed/sp_ag_attention_intra_node.py b/examples/distributed/intranode/sp_ag_attention_intra_node.py similarity index 100% rename from examples/distributed/sp_ag_attention_intra_node.py rename to examples/distributed/intranode/sp_ag_attention_intra_node.py diff --git a/examples/distributed/intranode/test_intranode.py b/examples/distributed/intranode/test_intranode.py new file mode 100644 index 000000000..6a86c54d2 --- /dev/null +++ b/examples/distributed/intranode/test_intranode.py @@ -0,0 +1,30 @@ +import torch +import tilelang +import tilelang.language as T +import tilelang.testing + +import example_allgather_gemm_overlapped +import example_gemm_rs_overlapped +import example_sp_ag_attention_intra_node + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_allgather_gemm_overlapped(): + torch.multiprocessing.spawn(example_allgather_gemm_overlapped.main, args=(2, None), nprocs=2) + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_gemm_rs_overlapped(): + torch.multiprocessing.spawn(example_gemm_rs_overlapped.main, args=(2, None), nprocs=2) + + + +@tilelang.testing.requires_distributed +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) +def test_example_sp_ag_attention_intra_node(): + torch.multiprocessing.spawn(example_sp_ag_attention_intra_node.main, args=(2, None), nprocs=2) From 059ce9743a3a23225595f86f2f3feecfb6f90587 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Fri, 13 Feb 2026 22:48:28 +0800 Subject: [PATCH 2/3] remove legacy allgather gemm example --- .../distributed/example_allgather_gemm.py | 113 ------------------ 1 file changed, 113 deletions(-) delete mode 100644 examples/distributed/example_allgather_gemm.py diff --git a/examples/distributed/example_allgather_gemm.py b/examples/distributed/example_allgather_gemm.py deleted file mode 100644 index 702f1264a..000000000 --- a/examples/distributed/example_allgather_gemm.py +++ /dev/null @@ -1,113 +0,0 @@ -import torch -import pynvshmem -import os -import tilelang -import tilelang.language as T -from tilelang.profiler import TensorSupplyType -from tilelang.distributed import init_distributed - - -def allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K, dtype="float16"): - accum_dtype = "float" - - @T.prim_func - def main( - A: T.Buffer((M, K), dtype), - A_ag: T.Buffer((M * PE_num, K), dtype), - B: T.Buffer((K, N), dtype), - signal: T.Buffer((PE_num,), "uint64"), - C: T.Buffer((M * PE_num, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - mype = T.alloc_local([1], "int32") - npes = T.alloc_local([1], "int32") - peer = T.alloc_local([1], "int32") - - mype[0] = T.get_pe() - npes[0] = T.get_pe_num() - - T.copy(A[by * block_M, bx * block_K], A_shared) - T.copy(A_shared, A_ag[mype[0] * M, bx * block_K]) - for k in T.serial(PE_num - 1): - peer[0] = (mype[0] + 1 + k) % npes[0] - T.putmem_signal_nbi_block( - T.address_of(A_ag[mype[0] * M, 0]), - T.address_of(A[0, 0]), - block_M * block_K * 2, - T.address_of(signal[k]), - k + 1, - 9, - peer[0], - ) - for k in T.serial(PE_num - 1): - T.signal_wait_until(T.address_of(signal[k]), 0, k + 1) - - for bk in T.serial(PE_num): - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): - T.copy(A_ag[bk * M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - T.copy(C_local, C[bk * M, bx * block_N]) - - return main - - -tilelang.disable_cache() -M, N, K, block_M, block_N, block_K = 64, 64, 64, 64, 64, 64 -dtype = torch.float16 - -RANK = int(os.environ.get("RANK", 0)) -WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) -PE_num = WORLD_SIZE -func = allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K) -kernel = tilelang.compile(func, out_idx=-1, pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) - -# Get CUDA Source -if RANK == 0: - print(kernel.get_kernel_source()) - -profiler = kernel.get_profiler(tensor_supply_type=TensorSupplyType.Randn) - -A_tensor = torch.arange(M * PE_num * K, dtype=dtype).cuda() * 0.001 -A_tensor = A_tensor.reshape(M * PE_num, K) -B_tensor = torch.arange(K * N, dtype=dtype).cuda() * 0.001 -B_tensor = B_tensor.reshape(K, N) - -print("A_tensor:", A_tensor) -print("B_tensor:", B_tensor) - - -def ref_program(A, B): - return A @ B - - -C_ref = ref_program(A_tensor, B_tensor) -print("C_ref:", C_ref) - -# profiler.init_distributed() -A_local = pynvshmem.nvshmem_create_tensor([M, K], dtype) -A_local[:].copy_(A_tensor[M * RANK : M * (RANK + 1), :]) - -A_ag_local = pynvshmem.nvshmem_create_tensor([M * PE_num, K], dtype) -A_ag_local.fill_(0) - -B_local = pynvshmem.nvshmem_create_tensor([K, N], dtype) -B_local[:].copy_(B_tensor) - -signal_local = pynvshmem.nvshmem_create_tensor([PE_num], torch.uint64) -signal_local.fill_(0) - -out = kernel(A_local, A_ag_local, B_local, signal_local) -print("out:", out) - -ref_cpu = C_ref.cpu() -for i in range(PE_num): - if i == RANK: - out_cpu = out.cpu() - assert torch.allclose(out_cpu, ref_cpu, atol=1e-2, rtol=1e-2) - print(f"rank {i} check passed.") From 711c8b5d25130cd986358e46bd0ee1de6e72797e Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Fri, 13 Feb 2026 22:53:20 +0800 Subject: [PATCH 3/3] lint --- examples/distributed/intranode/test_intranode.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/distributed/intranode/test_intranode.py b/examples/distributed/intranode/test_intranode.py index 6a86c54d2..6192ed3ac 100644 --- a/examples/distributed/intranode/test_intranode.py +++ b/examples/distributed/intranode/test_intranode.py @@ -1,6 +1,5 @@ import torch import tilelang -import tilelang.language as T import tilelang.testing import example_allgather_gemm_overlapped @@ -22,7 +21,6 @@ def test_example_gemm_rs_overlapped(): torch.multiprocessing.spawn(example_gemm_rs_overlapped.main, args=(2, None), nprocs=2) - @tilelang.testing.requires_distributed @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_eq(9, 0)