diff --git a/benchmark/distributed/README.md b/benchmark/distributed/README.md deleted file mode 100644 index ac1cea257..000000000 --- a/benchmark/distributed/README.md +++ /dev/null @@ -1 +0,0 @@ -To compare with [TileLink](https://arxiv.org/abs/2503.20313), please install [Triton-distributed](https://github.com/ByteDance-Seed/Triton-distributed). \ No newline at end of file diff --git a/benchmark/distributed/benchmark_ag_gemm.py b/benchmark/distributed/benchmark_ag_gemm.py deleted file mode 100644 index a4b0bd785..000000000 --- a/benchmark/distributed/benchmark_ag_gemm.py +++ /dev/null @@ -1,264 +0,0 @@ -'''Bugfix first: -Triton-distributed/python/triton_dist/kernels/nvidia/allgather_gemm.py:566 -```python -M = M_per_rank * ctx.num_ranks -``` -should be: -```python -M = M_per_rank * num_ranks -``` -''' - -#TODO: further tune the performance - -import argparse -import torch -import torch.distributed as dist -import pynvshmem -import tilelang -import tilelang.language as T -from tilelang.carver.arch import driver -from tilelang.distributed import init_distributed, dtype_map, perf_fn -from triton_dist.kernels.nvidia.allgather_gemm import ag_gemm, create_ag_gemm_context -from functools import partial - -tilelang.disable_cache() - - -@tilelang.jit( - out_idx=-1, - pass_configs={"tl.disable_rdc": True} - #FIXME: https://github.com/tile-ai/tilelang/issues/659 -) -def matmut_transpose(rank, - num_ranks, - M, - N_per_rank, - K, - block_M, - block_N, - block_K, - dtype="float16", - threads=256, - persistent=False) -> tilelang.JITKernel: - accum_dtype = "float32" - signal_dtype = "uint64" # NVSHMEM requires uint64 for signal - - assert M % block_M == 0 and N_per_rank % block_N == 0 and K % block_K == 0 - M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N_per_rank, - block_N), T.ceildiv(K, block_K) - M_blocks_per_rank = M_blocks // num_ranks - - sm_num = driver.get_num_sms() # Get # of SMs for persistent kernel - - @T.prim_func - def nonpersistent_kernel( - A: 
T.Tensor((M, K), dtype), # type: ignore - B: T.Tensor((N_per_rank, K), dtype), # type: ignore - signal: T.Tensor((num_ranks), signal_dtype), # type: ignore - C: T.Tensor((M, N_per_rank), dtype), # type: ignore - ): - with T.Kernel(N_blocks, M_blocks, threads=threads) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - - # thread-block swizzle for allgather - T.use_swizzle(10, order="column", offset=rank * M_blocks_per_rank) - - T.clear(C_local) - - src_rank = by // M_blocks_per_rank - T.signal_wait_until(T.address_of(signal[src_rank]), T.CmpType.EQ, 1) - for k in T.Pipelined(K_stages, num_stages=3): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - @T.prim_func - def persistent_kernel( - A: T.Tensor((M, K), dtype), # type: ignore - B: T.Tensor((N_per_rank, K), dtype), # type: ignore - signal: T.Tensor((num_ranks), signal_dtype), # type: ignore - C: T.Tensor((M, N_per_rank), dtype), # type: ignore - ): - with T.Kernel(sm_num, threads=threads) as (block_id): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - - for bx, by in T.Persistent([M_blocks, N_blocks], sm_num, block_id): - T.clear(C_local) - - src_rank = bx // M_blocks_per_rank - T.signal_wait_until(T.address_of(signal[src_rank]), T.CmpType.EQ, 1) - - for k in T.Pipelined(K_stages, num_stages=3): - T.copy(A[bx * block_M, k * block_K], A_shared) - T.copy(B[by * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - - 
T.copy(C_local, C_shared) - T.copy(C_shared, C[bx * block_M, by * block_N]) - - return persistent_kernel if persistent else nonpersistent_kernel - - -def overlapped_ag_gemm( - A: torch.Tensor, - B: torch.Tensor, - rank: int, - num_ranks: int, - persistent: bool = False, -) -> torch.Tensor: - """ - Overlapped AllGather-GEMM. - Args: - A: local input of shape (M_per_rank, K) - B: local weight of shape (N_per_rank, K) - rank: current rank - num_ranks: total number of ranks - persistent: whether to use persistent GEMM consumers - Returns: - Output of shape (M, N_per_rank) - """ - - M_per_rank, K = A.shape - N_per_rank, _ = B.shape - assert A.shape[1] == B.shape[1], "A and B must have the same inner dimension" - M = M_per_rank * num_ranks - - # Prepare kernel and buffers - consumer = matmut_transpose( - rank=rank, - num_ranks=num_ranks, - M=M, - N_per_rank=N_per_rank, - K=K, - block_M=128, - block_N=256, - block_K=64, - dtype=dtype, - threads=threads, - persistent=persistent) - if RANK == 0 and args.print_source: - print('We currently use cp-engine for producer, print consumer kernel code only...') - print(consumer.get_kernel_source()) - - ag_buffer = pynvshmem.nvshmem_create_tensor_list_intra_node( - shape=[M, K], - dtype=A.dtype, - ) - signal_buffer = torch.zeros([num_ranks], dtype=torch.uint64, device="cuda") - - # We place copy-based AllGather and GEMM on two streams to implement inter-op comm-comp overlapping - ag_stream = torch.cuda.current_stream() - gemm_stream = torch.cuda.Stream(priority=-1) - current_stream = torch.cuda.current_stream() - ag_stream.wait_stream(current_stream) - gemm_stream.wait_stream(current_stream) - - with torch.cuda.stream(ag_stream): - ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A) - pynvshmem.write64_on_stream(signal_buffer[rank], 1, ag_stream) - pynvshmem.nvshmemx_barrier_all_on_stream( - ag_stream.cuda_stream) # Ensure visible to all ranks - rank_orders = [(rank + i) % num_ranks for i in range(1, num_ranks)] - 
for src_rank in rank_orders: - dst = ag_buffer[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - dst.copy_(src) - pynvshmem.write64_on_stream(signal_buffer[src_rank], 1, ag_stream) - - with torch.cuda.stream(gemm_stream): - out = consumer(ag_buffer[rank], B, signal_buffer) - - current_stream.wait_stream(ag_stream) - current_stream.wait_stream(gemm_stream) - return out - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=8192) - parser.add_argument("--N", type=int, default=49152) - parser.add_argument("--K", type=int, default=12288) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) - parser.add_argument("--threads", type=int, default=256, help="number of threads in a block") - parser.add_argument( - "--persistent", action='store_true', default=False, help="use persistent GEMM consumers") - parser.add_argument("--print_source", action="store_true", help="print kernel source code") - parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") - parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") - return parser.parse_args() - - -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' - - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) - assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node AG-GEMM" - - args = parse_args() - M, N, K, dtype, threads, warmup, repeat = args.M, args.N, args.K, args.dtype, args.threads, args.warmup, args.repeat - PE_num = WORLD_SIZE - assert M % PE_num == 0 and N % PE_num == 0, "M and N must be divisible by PE_num" - M_per_rank, N_per_rank = M // PE_num, N // PE_num - torch_dtype = dtype_map[dtype] - - ## Inputs: A (M_per_rank, K), B (N_per_rank, K) - ## Output: 
ag(A) @ B.T (M, N_per_rank) - - A = torch.randn([M_per_rank, K], dtype=torch_dtype, device="cuda") - B = torch.randn([N_per_rank, K], dtype=torch_dtype, device="cuda") - - # Benchmark Torch (non-overlapped baseline) - def torch_ag_gemm(): - ag_buffer = torch.empty([M, K], dtype=torch_dtype, device="cuda") - dist.all_gather_into_tensor(ag_buffer, A, TP_GROUP) - return ag_buffer @ B.T - - dist.barrier(TP_GROUP) - torch_out, torch_t = perf_fn(torch_ag_gemm, warmup, repeat) - print(f"rank {RANK} torch AG-GEMM avg time: {torch_t} ms") - - # Benchmark Triton-dist (overlapped) - ag_intranode_stream = torch.cuda.Stream(priority=-1) - - ctx = create_ag_gemm_context( - A, B, RANK, PE_num, max_M=M, for_correctness=False, ag_intranode_stream=ag_intranode_stream) - - def triton_ag_gemm(persistent, autotune): - return ag_gemm( - A, B, ctx=ctx, rank=RANK, num_ranks=PE_num, persistent=persistent, autotune=autotune) - - dist.barrier(TP_GROUP) - triton_ag_gemm = partial(triton_ag_gemm, persistent=False, autotune=False) - tt_out, tt_t = perf_fn(triton_ag_gemm, warmup, repeat) - print(f"rank {RANK} triton AG-GEMM avg time: {tt_t} ms") - - # Benchmark Tilelang-dist (overlapped) - if args.persistent: - print("Use persistent GEMM consumers...") - else: - print("Use non-persistent GEMM consumers...") - - def tilelang_ag_gemm(): - return overlapped_ag_gemm(A, B, rank=RANK, num_ranks=PE_num, persistent=args.persistent) - - dist.barrier(TP_GROUP) - tl_out, tl_t = perf_fn(tilelang_ag_gemm, warmup, repeat) - print(f"rank {RANK} tilelang AG-GEMM avg time: {tl_t} ms") - - # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tl_out - torch_out).abs().max()}' - print(f"rank {RANK} check passed.✅") - - dist.destroy_process_group() diff --git a/benchmark/distributed/benchmark_all_gather.py b/benchmark/distributed/benchmark_all_gather.py deleted file mode 100644 index 24d3445b2..000000000 --- a/benchmark/distributed/benchmark_all_gather.py +++ 
/dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations - -import argparse -import torch -import torch.distributed as dist -import pynvshmem -import tilelang -import tilelang.language as T -from tilelang.distributed import init_distributed, dtype_map, perf_fn - -tilelang.disable_cache() - - -# Copied from Triton-distributed/tutorials/02-intra-node-allgather.py -# This is the default AllGather impl. in Triton-dist given full-mesh NVLink -def cp_engine_producer_all_gather_full_mesh_pull( - rank, - num_ranks, - local_tensor: torch.Tensor, - remote_tensor_buffers: list[torch.Tensor], - ag_stream: torch.cuda.Stream, - barrier_buffers: list[torch.Tensor], -): - M_per_rank, _ = local_tensor.shape - - rank_orders = [(rank + i) % num_ranks for i in range(num_ranks)] - - with torch.cuda.stream(ag_stream): - for src_rank in rank_orders: - if src_rank == rank: - continue - # peer: src_rank, offset src_rank[src_rank] -> rank[src_rank] - dst = remote_tensor_buffers[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :] - src = remote_tensor_buffers[src_rank][src_rank * M_per_rank:(src_rank + 1) * - M_per_rank, :] - dst.copy_(src) - pynvshmem.write64_on_stream( - barrier_buffers[rank][src_rank], - 1, - stream=ag_stream, - ) - - -def allgather(PE_num, M, N, dtype="float16", threads=128): - M_per_rank = M // PE_num - block_M = 4 - - @T.prim_func - def a2a_pull( - A: T.Tensor((M_per_rank, N), dtype), # type: ignore - B: T.Tensor((M, N), dtype), # type: ignore - ): - with T.Kernel(M_per_rank // block_M, PE_num - 1, threads=threads) as (bx, by): - mype = T.get_pe() - npes = T.get_pe_num() - peer = (mype + by + 1) % npes - - T.getmem_nbi_block( - T.address_of(B[peer * M_per_rank + bx * block_M, 0]), - T.address_of(A[bx * block_M, 0]), block_M * N * dtype_map[dtype].itemsize, peer) - # We don't need a barrier for the pull mode - - return a2a_pull - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--M", type=int, - default=8192) # Follow 
Triton-setting, we benchmark on (M, N) = (8192, 12288) - parser.add_argument("--N", type=int, default=12288) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) - parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") - parser.add_argument("--print_source", action="store_true", help="print kernel source code") - parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") - parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") - return parser.parse_args() - - -if __name__ == '__main__': - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) - assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node communication" - - args = parse_args() - M, N, dtype, threads, warmup, repeat = args.M, args.N, args.dtype, args.threads, args.warmup, args.repeat - PE_num = WORLD_SIZE - assert M % PE_num == 0, "M must be divisible by PE_num" - M_per_rank = M // PE_num - torch_dtype = dtype_map[dtype] - nelems = M * PE_num - - func = allgather(PE_num, M, N, dtype=dtype, threads=threads) - kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True}) - - # Get CUDA Source - if RANK == 0 and args.print_source: - print(kernel.get_kernel_source()) - - local_data = torch.randn([M_per_rank, N], dtype=torch_dtype).cuda() - - # Benchmark Torch - def torch_ag(): - out = torch.empty((M, N), dtype=torch_dtype).cuda() - dist.all_gather_into_tensor(out, local_data, group=TP_GROUP) - return out - - dist.barrier(TP_GROUP) - torch_out, torch_t = perf_fn(torch_ag, warmup, repeat) - print(f"rank {RANK} torch all_gather avg time: {torch_t} ms") - - # Benchmark Triton-dist - def triton_ag(): - ag_buffer_ptrs = pynvshmem.nvshmem_create_tensor_list_intra_node( - [M, N], torch_dtype) # buffer for dist-triton allgather - signal = pynvshmem.nvshmem_create_tensor_list_intra_node( - ([PE_num]), 
torch.uint64) # each rank corresponds to one barrier - ag_buffer_ptrs[RANK][ - RANK * M_per_rank:(RANK + 1) * M_per_rank, - ].copy_(local_data) - signal[RANK].zero_() - pynvshmem.nvshmemx_barrier_all_on_stream(torch.cuda.current_stream().cuda_stream) - cp_engine_producer_all_gather_full_mesh_pull( - RANK, PE_num, local_data, ag_buffer_ptrs, torch.cuda.current_stream(), signal - ) # Here we use current stream for allgather, we can pass any other stream for comm-comp fusion. - return ag_buffer_ptrs[RANK] - - dist.barrier(TP_GROUP) - tt_out, tt_t = perf_fn(triton_ag, warmup, repeat) - print(f"rank {RANK} triton all_gather avg time: {tt_t} ms") - - # Benchmark Tilelang-dist - def tilelang_ag(): - ag_buffer = pynvshmem.nvshmem_create_tensor([M_per_rank, N], torch_dtype) - ag_buffer.copy_(local_data) - out = pynvshmem.nvshmem_create_tensor([M, N], torch_dtype) - out[RANK * M_per_rank:(RANK + 1) * M_per_rank, :].copy_(local_data) - kernel(ag_buffer, out) - - return out - - dist.barrier(TP_GROUP) - tl_out, tl_t = perf_fn(tilelang_ag, warmup, repeat) - print(f"rank {RANK} tilelang all_gather avg time: {tl_t} ms") - # Tested on 4A100 with full-mesh NVLink, comparable with Triton-dist and ~20x faster than Torch - - # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=0, rtol=0), f'max error: {(tl_out - torch_out).abs().max()}' - print(f"rank {RANK} check passed.✅") - - dist.destroy_process_group() diff --git a/benchmark/distributed/benchmark_all_to_all.py b/benchmark/distributed/benchmark_all_to_all.py deleted file mode 100644 index 6aae8b203..000000000 --- a/benchmark/distributed/benchmark_all_to_all.py +++ /dev/null @@ -1,321 +0,0 @@ -from __future__ import annotations - -import torch -import tilelang -import tilelang.language as T -from tilelang.distributed import init_distributed, dtype_map -import argparse -import random -from triton_dist.kernels.nvidia import fast_all_to_all, all_to_all_post_process -from benchmark.distributed.utils import 
create_all_to_all_context, AllToAllContext - -tilelang.disable_cache() - - -def all_to_all(max_m, hidden, num_tot_experts, WORLD_SIZE, threads=128, dtype="float16"): - - scale_dtype = "float" - EXPERTS_PER_RANK = num_tot_experts // WORLD_SIZE - - @T.prim_func - def main( - send_buf: T.Tensor((max_m, hidden), dtype), # type: ignore - recv_buf: T.Tensor((WORLD_SIZE * max_m * 2, hidden), dtype), # type: ignore - scale_send_buf: T.Tensor((max_m), scale_dtype), # type: ignore - scale_recv_buf: T.Tensor((WORLD_SIZE * max_m * 2), scale_dtype), # type: ignore - split_send_buf: T.Tensor((num_tot_experts), "int32"), # type: ignore - split_recv_buf: T.Tensor((num_tot_experts * 2), "int32"), # type: ignore - signal_buf: T.Tensor((WORLD_SIZE * 2), "uint64"), # type: ignore - ): - with T.Kernel(WORLD_SIZE, threads=threads) as (bx): - peer = bx - tx = T.thread_binding(threads, thread="threadIdx.x") - - mype = T.alloc_local([1], "int32") - npes = T.alloc_local([1], "int32") - m_start = T.alloc_local([1], "int32") - m_end = T.alloc_local([1], "int32") - mype[0] = T.get_pe() - npes[0] = T.get_pe_num() - m_start[0] = split_send_buf[peer * EXPERTS_PER_RANK] - m_end[0] = split_send_buf[(peer + 1) * EXPERTS_PER_RANK] - - # T.putmem_nbi_block( - # T.address_of(recv_buf[0, 0]), T.address_of(send_buf[m_start[0], 0]), - # (m_end[0] - m_start[0]) * hidden * 2, peer) - - T.fence() - - if tx == 0: - T.signal_op( - T.address_of(signal_buf[mype[0]]), - 99, - T.Amo.SIGNAL_SET, - peer, - ) - T.signal_wait_until( - T.address_of(signal_buf[peer]), - T.CmpType.EQ, - 99, - ) - - return main - - -class TilelangAllToAll: - - def __init__(self, ctx: AllToAllContext): - self.ctx = ctx - self.func = all_to_all( - ctx.max_m, ctx.hidden, ctx.num_tot_experts, ctx.WORLD_SIZE, threads=128) - self.kernel = tilelang.compile(self.func, pass_configs={"tl.disable_tma_lower": True}) - if self.ctx.rank == 0: - print(self.kernel.get_kernel_source()) - - def __call__(self, send_tensor: torch.Tensor, send_split_cumsum: 
torch.Tensor, - send_scale: torch.Tensor | None): - """ - low-latency all-to-all communication - """ - with_scale = send_scale is not None - - act_pos = self.ctx.call_count % 2 - - split_buf_st = act_pos * self.ctx.num_tot_experts - split_buf_ed = split_buf_st + self.ctx.num_tot_experts - - data_buf_st = act_pos * self.ctx.WORLD_SIZE * self.ctx.max_m - data_buf_ed = data_buf_st + self.ctx.WORLD_SIZE * self.ctx.max_m - - scale_buf_st = act_pos * self.ctx.WORLD_SIZE * self.ctx.max_m - scale_buf_ed = scale_buf_st + self.ctx.WORLD_SIZE * self.ctx.max_m - - num_tokens = send_tensor.shape[0] - assert num_tokens <= self.ctx.max_m - self.ctx.send_buf[:num_tokens, :] = send_tensor - if with_scale: - self.ctx.scale_send_buf[:num_tokens] = send_scale - - self.kernel( - self.ctx.send_buf, - self.ctx.recv_buf, - self.ctx.scale_send_buf, - self.ctx.scale_recv_buf, - self.ctx.split_send_buf, - self.ctx.split_recv_buf, - self.ctx.signal_buf, - ) - - self.ctx.call_count = (self.ctx.call_count + 1) % self.ctx.MOD_VALUE - out_lis: list[torch.Tensor] = [] - out_lis.append(self.ctx.split_recv_buf[split_buf_st:split_buf_ed]) - out_lis.append(self.ctx.recv_buf[data_buf_st:data_buf_ed, :]) - if with_scale: - out_lis.append(self.ctx.scale_recv_buf[scale_buf_st:scale_buf_ed]) - else: - out_lis.append(None) - return out_lis - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-M", type=int, default=8) - parser.add_argument("-N", type=int, default=3584) - parser.add_argument("-G", type=int, default=128) - parser.add_argument("--topk", type=int, default=8) - parser.add_argument("--bench_iters", default=1, type=int, help="perf iterations") - parser.add_argument("--rounds", default=1, type=int, help="random data round") - parser.add_argument("--sm_margin", default=16, type=int, help="sm margin") - parser.add_argument("--dtype", default="float16", help="data type") - parser.add_argument("--profile", action="store_true") - parser.add_argument("--with_scale", 
action="store_true") - parser.add_argument("--print_source", action="store_true") - parser.add_argument("--threads", type=int, default=128) - return parser.parse_args() - - -def generate_random_exp_indices(token_num, total_num_experts, topk): - exp_indices = [] - exp_list = list(range(total_num_experts)) - for _ in range(token_num): - top_selected = random.sample(exp_list, topk) - exp_indices.append(top_selected) - return torch.Tensor(exp_indices).int() - - -def splits_to_cumsum(splits: torch.Tensor): - out = torch.empty(splits.shape[0] + 1, dtype=splits.dtype, device=splits.device) - out[0] = 0 - _ = torch.cumsum(splits, 0, out=out[1:]) - return out - - -import torch.distributed -import triton -import triton.language as tl - - -def calc_gather_index( - scatter_index: torch.Tensor, - row_start: int, - row_end: int, - BLOCK_SIZE: int = 1024, -): - - @triton.jit - def _kernel( - scatter_index: torch.Tensor, - gather_index: torch.Tensor, - topk_index: torch.Tensor, - ntokens: int, - topk: int, - row_start: int, - row_end: int, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offset < ntokens * topk - scatter_idx = tl.load(scatter_index + offset, mask=mask, other=-1) - token_idx = offset // topk - topk_idx = offset % topk - token_idx_mask = (scatter_idx >= row_start) & (scatter_idx < row_end) - tl.store(gather_index + scatter_idx - row_start, token_idx, mask=token_idx_mask) - tl.store(topk_index + scatter_idx - row_start, topk_idx, mask=token_idx_mask) - - ntokens, topk = scatter_index.shape - gather_index = torch.zeros(row_end - row_start, dtype=torch.int32, device=scatter_index.device) - topk_index = torch.zeros(row_end - row_start, dtype=torch.int32, device=scatter_index.device) - grid = lambda META: (triton.cdiv(ntokens * topk, META["BLOCK_SIZE"]),) # noqa: E731 - _kernel[grid]( - scatter_index, - gather_index, - topk_index, - ntokens, - topk, - row_start, - row_end, - 
BLOCK_SIZE=BLOCK_SIZE, - num_warps=BLOCK_SIZE // 32, - ) - return gather_index, topk_index - - -def calc_scatter_index_stable(choosed_experts: torch.Tensor): - return (choosed_experts.flatten().argsort(stable=True).argsort().int().view( - choosed_experts.shape)) - - -def main(): - WORLD_SIZE, RANK, LOCAL_RANK = init_distributed() - - args = parse_args() - token_num = args.M // 2 - experts_per_rank = args.G // WORLD_SIZE - - all_to_all_ctx = create_all_to_all_context( - # set max_m to 2 * M * topk to avoid bug in combine for now - # TODO: Check this - args.M * args.topk * 2, - args.N, - RANK, - args.G, - WORLD_SIZE, - experts_per_rank, - dtype_map[args.dtype], - torch.float, - ) - - def perf_triton(input: torch.Tensor, scale_tensor: torch.Tensor, exp_indices: torch.Tensor): - - # prepare the indexes - splits_gpu_cur_rank = torch.bincount(exp_indices.view(-1), minlength=args.G).to(torch.int32) - split_cumsum = splits_to_cumsum(splits_gpu_cur_rank) - - # calculate the scatter idx - scatter_idx_cur_rank = calc_scatter_index_stable(exp_indices) - # calculate the gather idx accordingly - gather_idx_cur_rank, _ = calc_gather_index(scatter_idx_cur_rank, 0, token_num * args.topk) - # use torch native scatter forward(will not be included in the e2e time measurement) - scattered_input = torch.empty( - input.size(0) * args.topk, input.size(1), dtype=input.dtype, device=input.device) - scattered_scale_tensor = torch.empty( - (scale_tensor.size(0) * args.topk), - dtype=scale_tensor.dtype, - device=scale_tensor.device, - ) - scattered_input.copy_(torch.index_select(input, dim=0, index=gather_idx_cur_rank)) - scattered_scale_tensor.copy_( - torch.index_select(scale_tensor, dim=0, index=gather_idx_cur_rank)) - - def fwd(): - return fast_all_to_all(all_to_all_ctx, scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) - - torch.cuda._sleep(1000000000) - # warmup - for _ in range(20): - fwd() - - st = torch.cuda.Event(enable_timing=True) - ed = 
torch.cuda.Event(enable_timing=True) - # bench - st.record() - for _ in range(args.bench_iters): - _ = fwd() - ed.record() - torch.cuda.synchronize() - avg_time = st.elapsed_time(ed) / args.bench_iters - - # 1. dispatch - dispatch_splits, dispatch_token, dispatch_scale = fast_all_to_all( - all_to_all_ctx, scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) - dispatch_token, dispatch_scale = all_to_all_post_process( - all_to_all_ctx, dispatch_splits, dispatch_token, - dispatch_scale if args.with_scale else None) - - # 2. compute: moe_compute(dispatch_token, dispatch_scale, moe_weight, ...) - # ... - - # 3. combine - combine_splits, combine_token, combine_scale = fast_all_to_all( - all_to_all_ctx, dispatch_token, splits_to_cumsum(dispatch_splits), dispatch_scale) - combine_token, combine_scale = all_to_all_post_process( - all_to_all_ctx, combine_splits, combine_token, - combine_scale if args.with_scale else None) - - # 3.1. reduce: [num_tokens_local_rank * topk] => [num_tokens_local_rank] - combine_reduced_out = torch.zeros_like(input) - combine_reduced_out.index_add_(0, gather_idx_cur_rank, combine_token) - - # check the output of `dispatch => => combine` - torch.testing.assert_close(combine_reduced_out, input * args.topk, rtol=1e-2, atol=1e-2) - - tilelang_all_to_all = TilelangAllToAll(all_to_all_ctx) - tilelang_all_to_all(scattered_input, split_cumsum, - scattered_scale_tensor if args.with_scale else None) - - # torch.testing.assert_close(tilelang_out[1], dispatch_token, rtol=1e-2, atol=1e-2) - # torch.testing.assert_close(tilelang_scale, dispatch_scale, rtol=1e-2, atol=1e-2) - - return dispatch_token, dispatch_scale, avg_time - - # random simulate token received from dataloader - print(f"Rank-{RANK}: Received {token_num} tokens") - - exp_indices = generate_random_exp_indices(token_num, args.G, args.topk) - assert exp_indices.size(0) == token_num and exp_indices.size(1) == args.topk - exp_indices = exp_indices.to("cuda") - input = 
( - torch.rand(token_num, args.N, dtype=torch.float32).to(dtype_map[args.dtype]).to("cuda")) - scale_tensor = torch.rand(token_num, dtype=torch.float32).to("cuda") - - torch.cuda.synchronize() - triton_out, triton_scale, triton_time = perf_triton(input, scale_tensor, exp_indices) - torch.cuda.synchronize() - torch.distributed.barrier() - - -if __name__ == "__main__": - main() diff --git a/benchmark/distributed/benchmark_gemm_rs.py b/benchmark/distributed/benchmark_gemm_rs.py deleted file mode 100644 index 5be4431c3..000000000 --- a/benchmark/distributed/benchmark_gemm_rs.py +++ /dev/null @@ -1,191 +0,0 @@ -# Currently we only implement in Tilelang -#TODO: add Triton-dist v3.4 impl -#TODO: further tune the performance - -import argparse -import torch -import torch.distributed as dist -import pynvshmem -import tilelang -import tilelang.language as T -# from tilelang.carver.arch import driver -from tilelang.distributed import init_distributed, dtype_map, perf_fn - -tilelang.disable_cache() - - -@tilelang.jit(pass_configs={"tl.disable_rdc": True} - #FIXME: https://github.com/tile-ai/tilelang/issues/659 - ) -def fused_gemm_scatter(rank, - num_ranks, - M, - N, - K_per_rank, - block_M, - block_N, - block_K, - dtype="float16", - threads=128, - persistent=False) -> tilelang.JITKernel: - accum_dtype = "float32" - - assert M % block_M == 0 and N % block_N == 0 and K_per_rank % block_K == 0 - M_blocks, N_blocks, K_stages = T.ceildiv(M, block_M), T.ceildiv(N, block_N), T.ceildiv( - K_per_rank, block_K) - M_blocks_per_rank = M_blocks // num_ranks - - # sm_num = driver.get_num_sms() # Get # of SMs for persistent kernel - - @T.prim_func - def nonpersistent_kernel( - A: T.Tensor((M, K_per_rank), dtype), # type: ignore - B: T.Tensor((N, K_per_rank), dtype), # type: ignore - C: T.Tensor((M_blocks, N_blocks, block_M, block_N), dtype), # type: ignore - ): - with T.Kernel(N_blocks, M_blocks, threads=threads) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared 
= T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - - # thread-block swizzle for allgather - T.use_swizzle(M_blocks, order="column") - - T.clear(C_local) - - for k in T.Pipelined(K_stages, num_stages=3): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_local, transpose_B=True) - - T.copy(C_local, C_shared) - T.copy(C_shared, C[by, bx, :, :]) - peer = by // M_blocks_per_rank - T.putmem_nbi_block( - T.address_of(C[by, bx, 0, 0]), T.address_of(C[by, bx, 0, 0]), - block_M * block_N * dtype_map[dtype].itemsize, peer) - - assert not persistent - return nonpersistent_kernel - - -# https://github.com/bytedance/flux/blob/main/docs/design.md -def overlapped_gemm_rs( - input: torch.Tensor, - weight: torch.Tensor, - rank: int, - num_ranks: int, - persistent: bool = False, - block_M: int = 128, - block_N: int = 128, - block_K: int = 128, -) -> torch.Tensor: - """Overlapped GEMM with Reduce-Scatter using Tilelang. - Args: - input (torch.Tensor): Input tensor of shape (M, K_per_rank). - weight (torch.Tensor): Weight tensor of shape (N, K_per_rank). - rank (int): Current rank. - num_ranks (int): Total number of ranks. - persistent (bool): Whether to use persistent GEMM producers. - Returns: - torch.Tensor: Output tensor of shape (M_per_rank, N). 
- """ - - M, K_per_rank = input.shape - N, _ = weight.shape - assert weight.shape[1] == K_per_rank, "Weight tensor's second dimension must match K_per_rank" - M_per_rank = M // num_ranks - M_blocks, N_blocks = M // block_M, N // block_N - - # Prepare kernels and buffers - fused_gemm_scatter_kernel = fused_gemm_scatter( - rank=rank, - num_ranks=num_ranks, - M=M, - N=N, - K_per_rank=K_per_rank, - block_M=block_M, - block_N=block_N, - block_K=block_K, - dtype=dtype, - threads=threads, - persistent=persistent) - - gemm_output = pynvshmem.nvshmem_create_tensor_list_intra_node( - [M_blocks, N_blocks, block_M, block_N], dtype=input.dtype) - output = torch.empty((M_per_rank, N), dtype=input.dtype, device="cuda") - fused_gemm_scatter_kernel(input, weight, gemm_output[rank]) - dist.barrier(TP_GROUP) - output = gemm_output[rank].transpose(1, 2).view((num_ranks, M_per_rank, N)).sum(0) - return output - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=16384) - parser.add_argument("--N", type=int, default=12288) - parser.add_argument("--K", type=int, default=49152) - parser.add_argument( - "--dtype", type=str, default="float16", choices=["float16", "float32", "bfloat16"]) - parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") - parser.add_argument( - "--persistent", action='store_true', default=False, help="use persistent GEMM producers") - parser.add_argument("--print_source", action="store_true", help="print kernel source code") - parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") - parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") - return parser.parse_args() - - -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' - - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) - assert WORLD_SIZE <= 8, "This 
benchmark is designed for intra-node GEMM-RS" - - args = parse_args() - M, N, K, dtype, threads, warmup, repeat = args.M, args.N, args.K, args.dtype, args.threads, args.warmup, args.repeat - PE_num = WORLD_SIZE - assert M % PE_num == 0 and K % PE_num == 0, "M and K must be divisible by PE_num" - M_per_rank, K_per_rank = M // PE_num, K // PE_num - torch_dtype = dtype_map[dtype] - - ## Inputs: input (M, K_per_rank), weight (N, K_per_rank) - ## Output: rs(input@weight.T) (M_per_rank, N) - - input = torch.randn([M, K_per_rank], dtype=torch_dtype, device="cuda") - weight = torch.randn([N, K_per_rank], dtype=torch_dtype, device="cuda") - - # Benchmark Torch (non-overlapped baseline) - def torch_gemm_rs(): - local_output = input @ weight.T - rs_output = torch.empty((M // PE_num, N), dtype=torch_dtype, device="cuda") - dist.reduce_scatter_tensor(rs_output, local_output, group=TP_GROUP) - return rs_output - - dist.barrier(TP_GROUP) - torch_out, torch_t = perf_fn(torch_gemm_rs, warmup, repeat) - print(f"rank {RANK} torch GEMM-RS avg time: {torch_t} ms") - - # TODO(wt) Add Triton-dist baseline (overlapped) - - # Benchmark Tilelang-dist (overlapped) - if args.persistent: - print("Use persistent GEMM producers...") - else: - print("Use non-persistent GEMM producers...") - - def tilelang_gemm_rs(): - return overlapped_gemm_rs( - input, weight, rank=RANK, num_ranks=PE_num, persistent=args.persistent) - - dist.barrier(TP_GROUP) - tl_out, tl_t = perf_fn(tilelang_gemm_rs, warmup, repeat) - print(f"rank {RANK} tilelang GEMM avg time: {tl_t} ms") - - # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tl_out - torch_out).abs().max()}' - print(f"rank {RANK} check passed.✅") - - dist.destroy_process_group() diff --git a/benchmark/distributed/benchmark_reduce_scatter.py b/benchmark/distributed/benchmark_reduce_scatter.py deleted file mode 100644 index c6431f79a..000000000 --- a/benchmark/distributed/benchmark_reduce_scatter.py +++ 
/dev/null @@ -1,149 +0,0 @@ -# This benchmark requires GPU arch sm_90 or above. - -import argparse -import torch -import torch.distributed as dist -from triton_dist.kernels.nvidia.reduce_scatter import reduce_scatter_ring_push_1d_intra_node_ce -import pynvshmem -import tilelang -import tilelang.language as T -from tilelang.distributed import init_distributed, dtype_map, perf_fn - -tilelang.disable_cache() - -#TODO: Bench on 4/8 H100 -#TODO: split N? -'''init_nvshmem_by_torch_process_group(_TP_GROUP) -Note: Minor numerical differences exist between Triton/TileLang and Torch (~1e-2) -due to the order reductions are handled in different implementations. -(No error when #PE = 2) -''' - - -def reducescatter(PE_num, M, N, dtype="float16", threads=128): - M_per_rank = M // PE_num - block_M = 1 - acc_dtype = "float32" - - @T.prim_func - def pull_reduce( - A: T.Tensor((M, N), dtype), # type: ignore - B: T.Tensor((M_per_rank, N), dtype), # type: ignore - ): - with T.Kernel(M_per_rank // block_M, threads=threads) as (bx): - mype = T.get_pe() - - A_shared = T.alloc_shared((PE_num, block_M, N), dtype) - A_local = T.alloc_fragment((PE_num, block_M, N), dtype) - A_local_sum = T.alloc_fragment((block_M, N), acc_dtype) - - for i in T.serial(PE_num - 1): - peer = (mype + i + 1) % PE_num - T.getmem_nbi_block( - T.address_of(A_shared[peer, 0, 0]), - T.address_of(A[mype * M_per_rank + bx * block_M, 0]), - block_M * N * dtype_map[dtype].itemsize, peer) - base = mype * M_per_rank + bx * block_M - T.copy(A[base:base + block_M, :], A_shared[mype, :, :]) - - T.fence() # Ensure reduce happens after all IO - - T.copy(A_shared, A_local) - T.reduce_sum(A_local, A_local_sum, dim=0) - T.copy(A_local_sum, B[bx * block_M:bx * block_M + block_M, :]) - - return pull_reduce - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--M", type=int, default=8192) - parser.add_argument("--N", type=int, default=16384) - parser.add_argument( - "--dtype", type=str, default="float16", 
choices=["float16", "float32", "bfloat16"]) - parser.add_argument("--threads", type=int, default=128, help="number of threads in a block") - parser.add_argument("--print_source", action="store_true", help="print kernel source code") - parser.add_argument("--warmup", type=int, default=5, help="number of warmup iterations") - parser.add_argument("--repeat", type=int, default=10, help="number of repeat iterations") - return parser.parse_args() - - -if __name__ == '__main__': - assert torch.cuda.get_device_capability()[0] >= 9, '❗This benchmark requires sm_90 or higher' - - WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP = init_distributed(return_tp_group=True) - assert WORLD_SIZE <= 8, "This benchmark is designed for intra-node RS" - - args = parse_args() - M, N, dtype, threads, warmup, repeat = args.M, args.N, args.dtype, args.threads, args.warmup, args.repeat - PE_num = WORLD_SIZE - assert M % PE_num == 0, "M must be divisible by PE_num" - M_per_rank = M // PE_num - torch_dtype = dtype_map[dtype] - nelems = M * PE_num - - func = reducescatter(PE_num, M, N, dtype=dtype, threads=threads) - kernel = tilelang.compile(func, pass_configs={"tl.disable_tma_lower": True}, target='cuda') - - # Get CUDA Source - if RANK == 0 and args.print_source: - print(kernel.get_kernel_source()) - - local_data = torch.randn([M, N], dtype=torch_dtype).cuda() - - ## Input: [M, N] per rank - ## Output: [M_per_rank, N] per rank - - # Benchmark Torch - def torch_rs(): - out = torch.empty((M_per_rank, N), dtype=torch_dtype).cuda() - dist.reduce_scatter_tensor(out, local_data, group=TP_GROUP) - return out - - dist.barrier(TP_GROUP) - torch_out, torch_t = perf_fn(torch_rs, warmup, repeat) - print(f"rank {RANK} torch reduce_scatter avg time: {torch_t} ms") - - # Benchmark Triton-dist - def triton_rs(): - # Currently only support 'ce' implementation - input_buffer = pynvshmem.nvshmem_create_tensor([M, N], torch_dtype) - input_buffer.copy_(local_data) - input_flag = torch.ones((PE_num,), device="cuda", 
dtype=torch.int32) - symm_reduce_buffers = pynvshmem.nvshmem_create_tensor_list_intra_node([M, N], torch_dtype) - symm_reduce_flags = pynvshmem.nvshmem_create_tensor_list_intra_node((PE_num,), torch.int32) - symm_reduce_flags[RANK].zero_() - pynvshmem.nvshmemx_barrier_all_on_stream(torch.cuda.current_stream().cuda_stream) - output = reduce_scatter_ring_push_1d_intra_node_ce( - RANK, - WORLD_SIZE, - input_buffer, - input_flag, - symm_reduce_buffers, - symm_reduce_flags, - ) - pynvshmem.nvshmemx_barrier_all_on_stream(torch.cuda.current_stream().cuda_stream) - return output - - dist.barrier(TP_GROUP) - tt_out, tt_t = perf_fn(triton_rs, warmup, repeat) - print(f"rank {RANK} triton reduce_scatter avg time: {tt_t} ms") - - # Benchmark Tilelang-dist - def tilelang_rs(): - rs_buffer = pynvshmem.nvshmem_create_tensor([M, N], torch_dtype) - rs_buffer.copy_(local_data) - out = pynvshmem.nvshmem_create_tensor([M_per_rank, N], torch_dtype) - kernel(rs_buffer, out) - return out - - dist.barrier(TP_GROUP) - tl_out, tl_t = perf_fn(tilelang_rs, warmup, repeat) - print(f"rank {RANK} tilelang reduce_scatter avg time: {tl_t} ms") - - # Check correctness - assert torch.allclose( - tl_out, torch_out, atol=1e-2, rtol=1e-2), f'max error: {(tt_out - torch_out).abs().max()}' - print(f"rank {RANK} check passed.✅") - - dist.destroy_process_group() diff --git a/benchmark/distributed/ipc_impls/README.md b/benchmark/distributed/ipc_impls/README.md deleted file mode 100644 index d89d00956..000000000 --- a/benchmark/distributed/ipc_impls/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Benchmarks for IPC communication - -This benchmark aims to measure and compare the bandwidth of different implementations of IPC communication: -We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck. 
- -## NVSHMEM-based push/pull -```bash -GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py -``` - -## Unrolled-copy implemented in TileScale (*ours*) -```bash -export TILELANG_USE_DISTRIBUTED=1 -python benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py -``` - -## Results on Hopper connected by NVLink -| Size (Bytes) | NVSHMEM Push BW (GB/s) | NVSHMEM Pull BW (GB/s) | TileScale Push BW (GB/s) | TileScale Pull BW (GB/s) | -|---------------:|----------------------:|-----------------------:|-------------------------:|--------------------------:| -| 2,048 | 0.1680 | 0.1755 | 0.0632 | 0.0628 | -| 4,096 | 0.3415 | 0.4082 | 0.1316 | 0.1284 | -| 8,192 | 0.6836 | 0.8497 | 0.2601 | 0.2628 | -| 16,384 | 1.4119 | 1.6178 | 0.5241 | 0.5232 | -| 32,768 | 2.4592 | 1.8878 | 1.0178 | 1.1283 | -| 65,536 | 4.9380 | 2.0408 | 2.0380 | 1.9723 | -| 131,072 | 8.7134 | 2.1465 | 3.9668 | 2.1001 | -| 262,144 | 9.0743 | 2.1935 | 8.0200 | 2.1920 | -| 524,288 | 10.0191 | 2.2156 | 10.7943 | 2.2509 | -| 1,048,576 | 10.4359 | 2.2352 | 11.4781 | 2.2648 | -| 2,097,152 | 10.5573 | 2.2456 | 11.7712 | 2.2796 | -| 4,194,304 | 10.6560 | 2.2474 | 11.9145 | 2.2845 | - -> **Note:** All data presented above are unidirectional bandwidth. - diff --git a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py b/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py deleted file mode 100644 index 5ab6265ae..000000000 --- a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py +++ /dev/null @@ -1,109 +0,0 @@ -# This benchmark aims to measure the bandwidth of NVHSMEM-based communication. -# We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck. 
- -# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py - -import os -import tilelang -import tilelang.language as T -import argparse -import torch -import torch.distributed as dist -from tilelang.distributed import init_distributed, perf_fn -import pynvshmem - -os.environ['NCCL_DEBUG'] = 'WARN' - - -def nvshmem_kernel_push(size, threads): - - @T.prim_func - def nvshmem_push( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore - ): - with T.Kernel(1, threads=threads): - T.putmem_block( - T.address_of(dst), - T.address_of(src), - size * 4, - T.get_pe() ^ 1, - ) - T.fence_sys() - - return nvshmem_push - - -def nvshmem_kernel_pull(size, threads): - - @T.prim_func - def nvshmem_pull( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore - ): - with T.Kernel(1, threads=threads): - T.getmem_block( - T.address_of(dst), - T.address_of(src), - size * 4, - T.get_pe() ^ 1, - ) - T.fence_sys() - - return nvshmem_pull - - -def benchmark_nvshmem_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, - args: argparse.Namespace): - assert num_ranks == 2, "this benchmark only supports 2 ranks" - assert args.threads % 32 == 0, "threads must be divisible by 32" - - kernel = tilelang.compile(nvshmem_kernel_push(size, args.threads)) - src = pynvshmem.nvshmem_create_tensor([size], torch.float32) - dst = pynvshmem.nvshmem_create_tensor([size], torch.float32) - - def push_fn(): - kernel(dst, src) - - dist.barrier(group) - torch.cuda.synchronize() - _, t_push = perf_fn(push_fn, args.warmup, args.repeat) # 1st returned value is output - bw_push = (size * 4 * 1e-9) / (t_push * 1e-3) - - dist.barrier(group) - - # Reuse allocator and tensors - kernel = tilelang.compile(nvshmem_kernel_pull(size, args.threads)) - - def pull_fn(): - kernel(dst, src) - - dist.barrier(group) - torch.cuda.synchronize() - _, t_pull = perf_fn(pull_fn, args.warmup, 
args.repeat) # 1st returned value is output - bw_pull = (size * 4 * 1e-9) / (t_pull * 1e-3) - - dist.barrier(group) - - return bw_push, bw_pull - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") - parser.add_argument( - "--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") - parser.add_argument("--threads", type=int, default=128, help="Threads per block (default: 128)") - args = parser.parse_args() - - num_ranks, rank, _, group = init_distributed(return_tp_group=True) - for log_size in range(9, 21): - size = 2**log_size - push_bw, pull_bw = benchmark_nvshmem_bw(rank, num_ranks, group, size, args) - if rank == 0: - print( - f"size={size*4} bytes, nvshmem push bw: {push_bw:.4f} GB/s, nvshmem pull bw: {pull_bw:.4f} GB/s" - ) - - dist.destroy_process_group() diff --git a/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py b/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py deleted file mode 100644 index c7d3f2556..000000000 --- a/benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import tilelang -import tilelang.language as T -import argparse -import torch -import torch.distributed as dist -import torch.multiprocessing -from tilelang.distributed import init_dist, perf_fn - -tilelang.disable_cache() -os.environ['NCCL_DEBUG'] = 'WARN' - - -def ipc_kernel_push(size, threads, unroll_factor): - - @T.prim_func - def ipc_push( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore - ): - with T.Kernel(1, threads=threads): - rank = T.alloc_local([1], "uint64") - rank[0] = T.get_rank() - warp_idx = T.get_thread_binding(0) // 32 - warp_copy_size = T.ceildiv(size, threads // 32) - warp_start = warp_copy_size * warp_idx - T.put_warp( - src=T.address_of(src[warp_start]), - dst=T.address_of(dst[warp_start]), - 
size=warp_copy_size, - dst_pe=rank[0] ^ 1, - unroll_factor=unroll_factor) - T.fence_sys() - - return ipc_push - - -def ipc_kernel_pull(size, threads, unroll_factor): - - @T.prim_func - def ipc_pull( - dst: T.Tensor((size), "float32"), # type: ignore - src: T.Tensor((size), "float32"), # type: ignore - ): - with T.Kernel(1, threads=threads): - rank = T.alloc_local([1], "uint64") - rank[0] = T.get_rank() - warp_idx = T.get_thread_binding(0) // 32 - warp_copy_size = T.ceildiv(size, threads // 32) - warp_start = warp_copy_size * warp_idx - T.get_warp( - src=T.address_of(src[warp_start]), - dst=T.address_of(dst[warp_start]), - size=warp_copy_size, - src_pe=rank[0] ^ 1, - unroll_factor=unroll_factor) - T.fence_sys() - - return ipc_pull - - -def benchmark_ipc_bw(rank: int, num_ranks: int, group: dist.ProcessGroup, size: int, - args: argparse.Namespace, allocator): - assert num_ranks == 2, "this benchmark only supports 2 ranks" - assert args.threads % 32 == 0, "threads must be divisible by 32" - - kernel = tilelang.compile(ipc_kernel_push(size, args.threads, args.unroll_factor)) - kernel.initialize(allocator=allocator) - src = tilelang.tensor((size,), torch.float32, allocator=allocator).random_() - dst = tilelang.tensor((size,), torch.float32, allocator=allocator) - - def push_fn(): - kernel(dst, src) - - dist.barrier(group) - torch.cuda.synchronize() - _, t_push = perf_fn(push_fn, args.warmup, args.repeat) # 1st returned value is output - bw_push = (size * 4 * 1e-9) / (t_push * 1e-3) - - dist.barrier(group) - - # Reuse allocator and tensors - kernel = tilelang.compile(ipc_kernel_pull(size, args.threads, args.unroll_factor)) - kernel.initialize(allocator=allocator) - - def pull_fn(): - kernel(dst, src) - - dist.barrier(group) - torch.cuda.synchronize() - _, t_pull = perf_fn(pull_fn, args.warmup, args.repeat) # 1st returned value is output - bw_pull = (size * 4 * 1e-9) / (t_pull * 1e-3) - - dist.barrier(group) - - return bw_push, bw_pull - - -def main(local_rank: int, 
num_local_ranks: int, args: argparse.Namespace): - rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - - allocator = tilelang.get_allocator( - size=2**30, - device="cuda", - is_distributed=True, - local_rank=rank, - num_local_ranks=num_ranks, - group=group) - - for log_size in range(9, 21): - size = 2**log_size - push_bw, pull_bw = benchmark_ipc_bw(rank, num_ranks, group, size, args, allocator) - if rank == 0: - print( - f"size={size*4} bytes, ipc push bw: {push_bw:.4f} GB/s, ipc pull bw: {pull_bw:.4f} GB/s" - ) - - dist.destroy_process_group() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--warmup", type=int, default=10, help="number of warmup iterations (default: 10)") - parser.add_argument( - "--repeat", type=int, default=50, help="number of repeat iterations (default: 50)") - parser.add_argument("--threads", type=int, default=128, help="Threads per block (default: 128)") - parser.add_argument("--unroll-factor", type=int, default=4, help="Unroll factor (default: 4)") - args = parser.parse_args() - nprocs = 2 - - torch.multiprocessing.spawn(main, args=(nprocs, args), nprocs=nprocs) diff --git a/benchmark/distributed/utils.py b/benchmark/distributed/utils.py deleted file mode 100644 index fba164121..000000000 --- a/benchmark/distributed/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import pynvshmem - -DTYPE_MAP = { - "bfloat16": torch.bfloat16, - "float16": torch.float16, - "float8_e4m3fn": torch.float8_e4m3fn, - "float8_e5m2": torch.float8_e5m2, - "s8": torch.int8, - "s32": torch.int32, - "float32": torch.float32, -} - - -class AllToAllContext: - - def __init__( - self, - max_m: int, - hidden: int, - rank: int, - num_tot_experts: int, - WORLD_SIZE: int, - experts_per_rank: int, - dtype=torch.bfloat16, - scale_dtype=torch.float, - ): - """ - max_m: max number of tokens per rank - """ - self.send_buf = pynvshmem.nvshmem_create_tensor([max_m, hidden], dtype) - self.recv_buf = 
pynvshmem.nvshmem_create_tensor([WORLD_SIZE * max_m * 2, hidden], dtype) - self.scale_send_buf = pynvshmem.nvshmem_create_tensor([max_m], scale_dtype) - self.scale_recv_buf = pynvshmem.nvshmem_create_tensor([WORLD_SIZE * max_m * 2], scale_dtype) - self.split_send_buf = pynvshmem.nvshmem_create_tensor([num_tot_experts], torch.int32) - self.split_recv_buf = pynvshmem.nvshmem_create_tensor([num_tot_experts * 2], torch.int32) - self.signal_buf = pynvshmem.nvshmem_create_tensor([WORLD_SIZE * 2], torch.uint64) - - self.max_m = max_m - self.hidden = hidden - self.dtype = dtype - self.scale_dtype = scale_dtype - self.ele_size = self.dtype.itemsize - self.scale_ele_size = self.scale_dtype.itemsize - - self.num_tot_experts = num_tot_experts - self.experts_per_rank = experts_per_rank - - self.WORLD_SIZE = WORLD_SIZE - self.rank = rank - - # start from 1, because the initial values of signal buffer is 0 - self.call_count = 1 - self.MOD_VALUE = 1000000 - - -def create_all_to_all_context( - max_m: int, - hidden: int, - rank: int, - num_tot_experts: int, - WORLD_SIZE: int, - experts_per_rank: int, - dtype=torch.bfloat16, - scale_dtype=torch.float, -): - return AllToAllContext( - max_m, - hidden, - rank, - num_tot_experts, - WORLD_SIZE, - experts_per_rank, - dtype, - scale_dtype, - ) diff --git a/docs/get_started/Installation.md b/docs/get_started/Installation.md index f441d1a83..62c0a48b5 100644 --- a/docs/get_started/Installation.md +++ b/docs/get_started/Installation.md @@ -58,14 +58,14 @@ TILELANG_USE_DISTRIBUTED=1 python examples/distributed/example_allgather_gemm_ov ## To use NVSHMEM APIs -Before running the examples using NVSHMEM APIs (e.g., [example_allgather.py](../../examples/distributed/example_allgather.py)), you need to build NVSHMEM library for device-side code generation. +Before running the examples using NVSHMEM APIs (e.g., [example_allgather.py](../../examples/distributed/example_allgather.py)), you need to install NVSHMEM dependencies. 
+You can set the environment variable `NVSHMEM_SRC` to force using NVSHMEM built from source. ```bash -pip install mpich # building NVSHMEM needs MPI -export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src -cd tilelang/distributed -source build_nvshmem.sh +export NVSHMEM_SRC="your_custom_nvshmem_dir" ``` + +You can also skip this and use the default pre-installed NVSHMEM wheel. You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM. ```bash diff --git a/examples/distributed/README.md b/examples/distributed/README.md deleted file mode 100644 index e73ae0fac..000000000 --- a/examples/distributed/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Distributed Examples - -This directory contains examples demonstrating distributed computing capabilities using TileLang. - -For example, -``` -./tilelang/distributed/launch.sh examples/distributed/example_allgather.py -``` - -## Prerequisites - -Before running the examples, you need to build NVSHMEM library for device-side code generation. - -```bash -export NVSHMEM_SRC="your_custom_nvshmem_dir" # default to 3rdparty/nvshmem_src -cd tilelang/distributed -source build_nvshmem.sh -``` -You also need to install the `pynvshmem` package, which provides wrapped host-side Python API for NVSHMEM.
- -```bash -cd ./pynvshmem -python setup.py install -export LD_LIBRARY_PATH="$NVSHMEM_SRC/build/src/lib:$LD_LIBRARY_PATH" -``` - -Then you can test python import: -```bash -python -c "import pynvshmem" -``` diff --git a/examples/distributed/example_all_to_all.py b/examples/distributed/nvshmem_legacy/example_all_to_all.py similarity index 100% rename from examples/distributed/example_all_to_all.py rename to examples/distributed/nvshmem_legacy/example_all_to_all.py diff --git a/examples/distributed/example_allgather.py b/examples/distributed/nvshmem_legacy/example_allgather.py similarity index 100% rename from examples/distributed/example_allgather.py rename to examples/distributed/nvshmem_legacy/example_allgather.py diff --git a/examples/distributed/example_allgather_gemm.py b/examples/distributed/nvshmem_legacy/example_allgather_gemm.py similarity index 100% rename from examples/distributed/example_allgather_gemm.py rename to examples/distributed/nvshmem_legacy/example_allgather_gemm.py diff --git a/examples/distributed/example_cannon.py b/examples/distributed/nvshmem_legacy/example_cannon.py similarity index 100% rename from examples/distributed/example_cannon.py rename to examples/distributed/nvshmem_legacy/example_cannon.py diff --git a/examples/distributed/example_nvshmem.py b/examples/distributed/nvshmem_legacy/example_nvshmem.py similarity index 100% rename from examples/distributed/example_nvshmem.py rename to examples/distributed/nvshmem_legacy/example_nvshmem.py diff --git a/examples/distributed/example_overlapping_allgather.py b/examples/distributed/nvshmem_legacy/example_overlapping_allgather.py similarity index 100% rename from examples/distributed/example_overlapping_allgather.py rename to examples/distributed/nvshmem_legacy/example_overlapping_allgather.py diff --git a/examples/distributed/example_post_attn_all2all_transpose.py b/examples/distributed/nvshmem_legacy/example_post_attn_all2all_transpose.py similarity index 100% rename from 
examples/distributed/example_post_attn_all2all_transpose.py rename to examples/distributed/nvshmem_legacy/example_post_attn_all2all_transpose.py diff --git a/examples/distributed/example_pre_attn_all2all.py b/examples/distributed/nvshmem_legacy/example_pre_attn_all2all.py similarity index 100% rename from examples/distributed/example_pre_attn_all2all.py rename to examples/distributed/nvshmem_legacy/example_pre_attn_all2all.py diff --git a/examples/distributed/example_pre_attn_all2all_transpose.py b/examples/distributed/nvshmem_legacy/example_pre_attn_all2all_transpose.py similarity index 100% rename from examples/distributed/example_pre_attn_all2all_transpose.py rename to examples/distributed/nvshmem_legacy/example_pre_attn_all2all_transpose.py diff --git a/examples/distributed/example_simple_shift.py b/examples/distributed/nvshmem_legacy/example_simple_shift.py similarity index 100% rename from examples/distributed/example_simple_shift.py rename to examples/distributed/nvshmem_legacy/example_simple_shift.py diff --git a/examples/distributed/example_summa.py b/examples/distributed/nvshmem_legacy/example_summa.py similarity index 100% rename from examples/distributed/example_summa.py rename to examples/distributed/nvshmem_legacy/example_summa.py diff --git a/examples/distributed/gemm_rs_utils.py b/examples/distributed/nvshmem_legacy/gemm_rs_utils.py similarity index 100% rename from examples/distributed/gemm_rs_utils.py rename to examples/distributed/nvshmem_legacy/gemm_rs_utils.py diff --git a/testing/python/language/test_tilelang_language_ldst_options.py b/testing/python/distributed/test_tilelang_language_ldst_options.py similarity index 100% rename from testing/python/language/test_tilelang_language_ldst_options.py rename to testing/python/distributed/test_tilelang_language_ldst_options.py diff --git a/tilelang/distributed/build_nvshmem.sh b/tilelang/distributed/build_nvshmem.sh deleted file mode 100644 index 8f4d44d1d..000000000 --- 
a/tilelang/distributed/build_nvshmem.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/bin/bash - -if [ -z "${NVSHMEM_SRC}" ]; then - export NVSHMEM_SRC="$(realpath ../../3rdparty/nvshmem_src)" - echo "NVSHMEM_SRC not set, defaulting to ${NVSHMEM_SRC}" -else - NVSHMEM_SRC="$(realpath ${NVSHMEM_SRC})" - echo "Using NVSHMEM_SRC=${NVSHMEM_SRC}" -fi - -if [ -d "${NVSHMEM_SRC}" ]; then - if [ "$(ls -A ${NVSHMEM_SRC})" ]; then - echo "NVSHMEM_SRC directory (${NVSHMEM_SRC}) is not empty, cleaning it..." - rm -rf "${NVSHMEM_SRC}/"* - rm -rf "${NVSHMEM_SRC}/".* 2>/dev/null || true - fi -else - mkdir -p "${NVSHMEM_SRC}" -fi - -wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz -tar zxvf nvshmem_src_3.2.5-1.txz -rm -rf nvshmem_src_3.2.5-1.txz - -mkdir -p "${NVSHMEM_SRC}" -mv nvshmem_src/* "${NVSHMEM_SRC}/" -mv nvshmem_src/.* "${NVSHMEM_SRC}/" 2>/dev/null || true -rmdir nvshmem_src - - -export NVSHMEM_PATH="${NVSHMEM_SRC}" - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -PROJECT_ROOT=$(realpath "${SCRIPT_DIR}") -echo "SCRIPT_DIR: ${SCRIPT_DIR}" -echo "PROJECT_ROOT: ${PROJECT_ROOT}" -echo "NVSHMEM will be installed to: ${NVSHMEM_SRC}" - -ARCH="" -JOBS="" - -# Iterate over the command-line arguments -while [[ $# -gt 0 ]]; do - key="$1" - - case $key in - --arch) - # Process the arch argument - ARCH="$2" - shift # Skip the argument value - shift # Skip the argument key - ;; - --jobs) - # Process the jobs argument - JOBS="$2" - shift # Skip the argument value - shift # Skip the argument key - ;; - *) - # Unknown argument - echo "Unknown argument: $1" - shift # Skip the argument - ;; - esac -done - -if [[ -n "${ARCH}" ]]; then - export CMAKE_CUDA_ARCHITECTURES="${ARCH}" - CUDAARCH_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${ARCH}" -fi - -if [[ -z "${JOBS}" ]]; then - JOBS=$(nproc --ignore 2) -fi - -export NVSHMEM_IBGDA_SUPPORT=0 -export NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY=0 -export NVSHMEM_IBDEVX_SUPPORT=0 -export NVSHMEM_IBRC_SUPPORT=1 -export 
NVSHMEM_LIBFABRIC_SUPPORT=0 -export NVSHMEM_MPI_SUPPORT=1 -export NVSHMEM_USE_GDRCOPY=0 -export NVSHMEM_TORCH_SUPPORT=1 -export NVSHMEM_ENABLE_ALL_DEVICE_INLINING=1 - -pushd "${NVSHMEM_SRC}" -mkdir -p build -cd build -CMAKE=${CMAKE:-cmake} - -if [ ! -f CMakeCache.txt ]; then - ${CMAKE} .. \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ - ${CUDAARCH_ARGS} \ - -DNVSHMEM_BUILD_TESTS=OFF \ - -DNVSHMEM_BUILD_EXAMPLES=OFF \ - -DNVSHMEM_BUILD_PACKAGES=OFF -fi - -make VERBOSE=1 -j"${JOBS}" -popd - -echo "NVSHMEM installed successfully to ${NVSHMEM_SRC}" \ No newline at end of file diff --git a/tilelang/distributed/pynvshmem/CMakeLists.txt b/tilelang/distributed/pynvshmem/CMakeLists.txt index 6310ce52e..dfd323a4e 100644 --- a/tilelang/distributed/pynvshmem/CMakeLists.txt +++ b/tilelang/distributed/pynvshmem/CMakeLists.txt @@ -44,7 +44,17 @@ if(TORCH_CXX_FLAGS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") endif() -find_package(NVSHMEM REQUIRED) +# find NVSHMEM from tilelang.env +execute_process( + COMMAND ${PYTHON_EXECUTABLE} "-c" + "from tilelang.env import env; print(env.NVSHMEM_INCLUDE_DIR); print(env.NVSHMEM_LIB_PATH)" + OUTPUT_VARIABLE NVSHMEM_PATHS) +string(REGEX REPLACE "\n" ";" NVSHMEM_PATHS "${NVSHMEM_PATHS}") +list(GET NVSHMEM_PATHS 0 NVSHMEM_INCLUDE_DIR) +list(GET NVSHMEM_PATHS 1 NVSHMEM_LIB_DIR) +set(NVSHMEM_INCLUDE_DIRS ${NVSHMEM_INCLUDE_DIR}) +find_library(NVSHMEM_HOST_LIBRARY nvshmem_host PATHS ${NVSHMEM_LIB_DIR} NO_DEFAULT_PATH REQUIRED) +find_library(NVSHMEM_DEVICE_LIBRARY nvshmem_device PATHS ${NVSHMEM_LIB_DIR} NO_DEFAULT_PATH REQUIRED) if(NOT CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES native) @@ -57,9 +67,10 @@ message(STATUS "CUDA include directories: ${CUDA_INCLUDE_DIRS}") set_target_properties(pynvshmem PROPERTIES CXX_STANDARD 17 CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(pynvshmem PRIVATE nvshmem::nvshmem_host - nvshmem::nvshmem_device +target_link_libraries(pynvshmem PRIVATE ${NVSHMEM_HOST_LIBRARY} + 
${NVSHMEM_DEVICE_LIBRARY} torch ${TORCH_PYTHON_LIBRARY}) target_include_directories(pynvshmem PRIVATE ${NVSHMEM_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}) +target_link_directories(pynvshmem PRIVATE ${NVSHMEM_LIB_DIR}) target_compile_options(pynvshmem PRIVATE $<$:-rdc=true>) diff --git a/tilelang/distributed/pynvshmem/setup.py b/tilelang/distributed/pynvshmem/setup.py index fe45c762e..3d4237418 100644 --- a/tilelang/distributed/pynvshmem/setup.py +++ b/tilelang/distributed/pynvshmem/setup.py @@ -9,6 +9,8 @@ import setuptools from torch.utils.cpp_extension import BuildExtension +from tilelang.env import env + # Project directory root root_path: Path = Path(__file__).resolve().parent PACKAGE_NAME = "pynvshmem" @@ -60,11 +62,57 @@ def wrapper(*kargs, **kwargs): return wrapper +def ensure_nvshmem_symlinks(): + """ + Ensure symbolic links exist for NVSHMEM libraries. + + The nvidia-nvshmem-cu12 wheel provides versioned libraries (e.g., libnvshmem_host.so.3), + but the linker expects unversioned names (e.g., libnvshmem_host.so). + This function creates the necessary symlinks automatically during build. 
+ """ + if env.NVSHMEM_LIB_PATH is None: + return + + lib_path = Path(env.NVSHMEM_LIB_PATH) + if not lib_path.exists(): + return + + # Map of expected symlink name to the versioned library file pattern + symlink_map = { + "libnvshmem_host.so": "libnvshmem_host.so.*", + "libnvshmem_device.a": "libnvshmem_device.a", # This one might already be correct + } + + for symlink_name, pattern in symlink_map.items(): + symlink_path = lib_path / symlink_name + + # Skip if symlink already exists and is valid + if symlink_path.exists() or symlink_path.is_symlink(): + continue + + # Find the versioned library file + versioned_libs = list(lib_path.glob(pattern)) + if not versioned_libs: + continue + + # Use the first match (or latest if multiple) + target = versioned_libs[0].name + + try: + # Create the symlink + symlink_path.symlink_to(target) + print(f"Created symlink: {symlink_path} -> {target}") + except Exception as e: + print(f"Warning: Could not create symlink {symlink_path}: {e}") + + @pathlib_wrapper def nvshmem_deps(): - nvshmem_home = Path(os.environ.get("NVSHMEM_SRC", root_path / "../../../3rdparty/nvshmem_src")) - include_dirs = [nvshmem_home / "build/src/include"] - library_dirs = [nvshmem_home / "build/src/lib"] + # Ensure symlinks exist before returning dependencies + ensure_nvshmem_symlinks() + + include_dirs = [env.NVSHMEM_INCLUDE_DIR] + library_dirs = [env.NVSHMEM_LIB_PATH] libraries = ["nvshmem_host", "nvshmem_device"] return include_dirs, library_dirs, libraries diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 4bd77c8c8..3cc8978de 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -67,9 +67,9 @@ def tilelang_callback_cuda_compile(code, target): else: cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include")) if env.USE_DISTRIBUTED: - if os.environ.get("NVSHMEM_SRC", None) is not None: - nvshmem_include_path = os.environ["NVSHMEM_SRC"] + "/build/src/include" - nvshmem_lib_path = 
os.environ["NVSHMEM_SRC"] + "/build/src/lib" + if env.NVSHMEM_SRC: + nvshmem_include_path = env.NVSHMEM_INCLUDE_DIR + nvshmem_lib_path = env.NVSHMEM_LIB_PATH else: raise ValueError("NVSHMEM_SRC is not set") target_arch = nvcc.get_target_arch(nvcc.get_target_compute_version(target)) @@ -87,7 +87,7 @@ def tilelang_callback_cuda_compile(code, target): "-I" + cutlass_path, ] if env.USE_DISTRIBUTED: - if os.environ.get("NVSHMEM_SRC", None) is not None: + if env.NVSHMEM_SRC: options += [ "-I" + nvshmem_include_path, "-L" + nvshmem_lib_path, diff --git a/tilelang/env.py b/tilelang/env.py index 6db3ee33b..e40cc84b0 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -18,6 +18,7 @@ TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") ", which may lead to compilation bugs when utilize tilelang backend." TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") +NVSHMEM_USE_DEFAULT_WHEEL_MESSAGE = ("NVSHMEM_SRC is not set, using the default NVSHMEM wheel: %s") TL_ROOT = os.path.dirname(os.path.abspath(__file__)) TL_LIBS = [TL_ROOT, os.path.join(TL_ROOT, 'lib')] @@ -253,12 +254,21 @@ class Environment: USE_NVSHMEM = EnvVar("TILELANG_USE_NVSHMEM", "0").get().lower() in ("1", "true", "on") if USE_DISTRIBUTED: if EnvVar("NVSHMEM_SRC", None).get() is not None: + # built from source NVSHMEM_SRC = EnvVar("NVSHMEM_SRC", None).get() + NVSHMEM_INCLUDE_DIR: str = NVSHMEM_SRC + "/build/src/include" + NVSHMEM_LIB_PATH: str = NVSHMEM_SRC + "/build/src/lib" else: - NVSHMEM_SRC = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "nvshmem_src") - NVSHMEM_INCLUDE_DIR: str = NVSHMEM_SRC + "/build/src/include" - NVSHMEM_LIB_PATH: str = NVSHMEM_SRC + "/build/src/lib" + # installed from wheel + from importlib.metadata import distribution + + root = pathlib.Path(distribution("nvidia-nvshmem-cu12").locate_file("")) + nvshmem_wheel_path = root / "nvidia" / "nvshmem" + 
logger.warning(NVSHMEM_USE_DEFAULT_WHEEL_MESSAGE % nvshmem_wheel_path) + assert nvshmem_wheel_path.exists(), f"NVSHMEM wheel path does not exist" + NVSHMEM_SRC = str(nvshmem_wheel_path) + NVSHMEM_INCLUDE_DIR = str(nvshmem_wheel_path / "include") + NVSHMEM_LIB_PATH = str(nvshmem_wheel_path / "lib") else: NVSHMEM_INCLUDE_DIR = None NVSHMEM_LIB_PATH = None