diff --git a/examples/distributed/deepseek_deepep/deepep_utils.py b/examples/distributed/deepseek_deepep/deepep_utils.py index 288640295..8a612504f 100644 --- a/examples/distributed/deepseek_deepep/deepep_utils.py +++ b/examples/distributed/deepseek_deepep/deepep_utils.py @@ -125,7 +125,7 @@ def unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): # Check: DeepEP/tests/test_intranode.py:test_main def gen_inputs(num_tokens: int, hidden: int, num_topk: int, num_experts: int, num_ranks: int): - """Generate random inputs for testing purpose. + """Generate random inputs for intranode testing purpose. Args: num_tokens: the number of tokens. hidden: the hidden dimension. @@ -156,6 +156,77 @@ def gen_inputs(num_tokens: int, hidden: int, num_topk: int, num_experts: int, nu return x, topk_idx, topk_weights, rank_idx + +def gen_internode_inputs(num_tokens: int, hidden: int, num_topk_groups: int, num_topk: int, num_experts: int, num_ranks: int, num_nodes: int): + """ + Generate random inputs with group restriction (native in DeepSeek MoE) for internode testing purpose. + Args: + num_tokens: the number of tokens. + hidden: the hidden dimension. + num_topk_groups: the number of top-k groups. + num_topk: the number of top-k experts to select for each token. + num_experts: the number of experts. + num_ranks: the number of total ranks. + num_nodes: the number of nodes. + + Returns: + x: `[num_tokens, hidden]` with `torch.bfloat16`, the input to MoE layer. + topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token, + `-1` means no selections. + topk_weights: `[num_tokens, num_topk]` with `torch.float32`, the weights corresponding to + each selected expert for each token. + rank_idx: `[num_tokens, num_topk]` with `torch.int32`, the rank indices corresponding to + each selected expert, `-1` means no selections. 
+ rdma_rank_idx: `[num_tokens, num_topk]` with `torch.int32`, the RDMA rank indices corresponding to + each selected expert, `-1` means no selections. + """ + assert num_topk <= num_experts, "num_topk must be less than or equal to num_experts" + assert num_experts % num_ranks == 0, "num_experts must be divisible by num_ranks" + + x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + num_local_ranks = num_ranks // num_nodes + rdma_rank_idx = rank_idx // num_local_ranks + rdma_rank_idx.masked_fill_(rank_idx == -1, -1) + inplace_unique(rdma_rank_idx, num_nodes) + return x, topk_idx, topk_weights, rank_idx, rdma_rank_idx + + +def inplace_unique(x: torch.Tensor, num_slots: int): + """ + Keep at most `num_slots` different values in each row of `x`, + and fill `x` with -1 in other positions. 
+ """ + assert x.dim() == 2 + mask = x < 0 + x_padded = x.masked_fill(mask, num_slots) + bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device) + bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded)) + bin_count = bin_count[:, :num_slots] + sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True) + sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) + sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values + x[:, :].fill_(-1) + valid_len = min(num_slots, x.size(1)) + x[:, :valid_len] = sorted_bin_idx[:, :valid_len] + + +def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int): + num_tokens, num_experts = scores.shape + scores = scores.view(num_tokens, num_groups, -1) + mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) + mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) + return (scores * mask).view(num_tokens, num_experts) + + def inplace_unique(x: torch.Tensor, num_slots: int): """ Keep at most `num_slots` different values in each row of `x`, diff --git a/examples/distributed/deepseek_deepep/internode/get_dispatch_layout.py b/examples/distributed/deepseek_deepep/internode/get_dispatch_layout.py new file mode 100644 index 000000000..a92b85f96 --- /dev/null +++ b/examples/distributed/deepseek_deepep/internode/get_dispatch_layout.py @@ -0,0 +1,270 @@ +# For internode +# This op is non-distributed +### python get_dispatch_layout.py + +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(__file__))) # add parent folder to path + +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +from typing import Tuple +from argparse import ArgumentParser +from utils import gen_internode_inputs, NUM_MAX_NVL_PEERS # noqa: F403 + + +# TODO(wt): Add async functionality +def get_dispatch_layout( + topk_idx: torch.Tensor, num_experts: int, + num_ranks, 
num_rdma_ranks) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]: + """Calculate the layout required for later communication. + + Arguments: + topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token, + `-1` means no selections. + num_experts: the number of experts. + num_ranks: the number of ranks. + num_rdma_ranks: the number of RDMA ranks, i.e. the number of nodes. + + Returns: + num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank. + num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA + rank (with the same GPU index), return `None` for intranode settings. + num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. + is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. + """ + + # Check inputs + assert topk_idx.dtype == torch.int64, "topk_idx must be of dtype torch.int64" + assert topk_idx.ndim == 2, "topk_idx must be a 2D tensor" + assert topk_idx.is_contiguous(), "topk_idx must be a contiguous tensor" + + # Allocate tensors + # TODO(wt): Wait on previous events and allocate on comm stream when adding async functionality + num_tokens, num_topk = topk_idx.shape + num_tokens_per_rank = torch.empty(num_ranks, dtype=torch.int32, device='cuda') + num_tokens_per_rdma_rank = torch.empty(num_rdma_ranks, dtype=torch.int32, device='cuda') + num_tokens_per_expert = torch.empty(num_experts, dtype=torch.int32, device='cuda') + is_token_in_rank = torch.empty((num_tokens, num_ranks), dtype=torch.bool, device='cuda') + + # Launch the kernel + kernel = get_dispatch_layout_kernel(num_topk, num_experts, num_ranks, num_rdma_ranks) + kernel( + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + ) + + # TODO(wt): Wait streams when adding async functionality + + 
return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank + + +@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +def get_dispatch_layout_kernel( + num_topk: int, + num_experts: int, + num_ranks: int, + num_rdma_ranks: int, +) -> tilelang.JITKernel: + threads = 256 + experts_per_sm = 4 + ranks_per_sm = 8 + rdma_ranks_per_sm = ranks_per_sm // NUM_MAX_NVL_PEERS + num_sms = T.ceildiv(num_experts, experts_per_sm) + T.ceildiv(num_ranks, ranks_per_sm) + experts_per_rank = num_experts // num_ranks + + num_tokens = T.dynamic('num_tokens') + + @T.prim_func + def get_dispatch_layout_main( + topk_idx: T.Tensor([num_tokens, num_topk], "int64"), # type: ignore + num_tokens_per_rank: T.Tensor([num_ranks], "int32"), # type: ignore + num_tokens_per_rdma_rank: T.Tensor([num_rdma_ranks], "int32"), # type: ignore + num_tokens_per_expert: T.Tensor([num_experts], "int32"), # type: ignore + is_token_in_rank: T.Tensor([num_tokens, num_ranks], "bool"), # type: ignore + ): + with T.Kernel(num_sms, threads=threads) as bx: + tx = T.get_thread_binding() + + # Calculate expert statistics + tokens_per_expert_per_thread = T.alloc_shared([threads, experts_per_sm], "int32") + T.clear(tokens_per_expert_per_thread) + + expert_begin_idx = T.alloc_var("int32") + expert_begin_idx = bx * experts_per_sm + expert_end_idx = T.alloc_var("int32") + expert_end_idx = T.min(expert_begin_idx + experts_per_sm, num_experts) + + if expert_begin_idx < expert_end_idx: + for i in T.serial(tx, num_tokens, threads): + for j in T.serial(num_topk): + expert_idx = T.alloc_var("int32") + expert_idx = topk_idx[i, j] + if expert_begin_idx <= expert_idx and expert_idx < expert_end_idx: + tokens_per_expert_per_thread[tx, + expert_idx - expert_begin_idx] += 1 + + if expert_begin_idx + tx < expert_end_idx: + sum = T.alloc_var("int32") + sum = 0 + for i in T.serial(threads): + sum += tokens_per_expert_per_thread[i, tx] + 
num_tokens_per_expert[expert_begin_idx + tx] = sum + + # Calculate rank statistics + sm_begin = T.alloc_var("int32") + sm_begin = T.ceildiv(num_experts, experts_per_sm) + rank_begin_idx = T.alloc_var("int32") + rank_begin_idx = (bx - sm_begin) * ranks_per_sm + rank_end_idx = T.alloc_var("int32") + rank_end_idx = T.min(rank_begin_idx + ranks_per_sm, num_ranks) + rdma_rank_begin_idx = rank_begin_idx // NUM_MAX_NVL_PEERS + rdma_rank_end_idx = rank_end_idx // NUM_MAX_NVL_PEERS + + if rank_begin_idx >= 0 and rank_begin_idx < rank_end_idx: + tokens_per_rank_per_thread = T.alloc_shared([threads, ranks_per_sm], "int32") + tokens_per_rdma_rank_per_thread = T.alloc_shared([threads, rdma_ranks_per_sm], "int32") + T.clear(tokens_per_rank_per_thread) + T.clear(tokens_per_rdma_rank_per_thread) + + expert_begin = T.alloc_var("int32") + expert_begin = rank_begin_idx * experts_per_rank + expert_end = T.alloc_var("int32") + expert_end = rank_end_idx * experts_per_rank + + for i in T.serial(tx, num_tokens, threads): + is_in_rank = T.alloc_local([ranks_per_sm], "int32") + is_in_rdma_rank = T.alloc_local([rdma_ranks_per_sm], "int32") + T.clear(is_in_rank) + T.clear(is_in_rdma_rank) + + for j in T.serial(num_topk): + expert_idx = T.alloc_var("int32") + rank_idx = T.alloc_var("int32") + expert_idx = topk_idx[i, j] + if expert_begin <= expert_idx and expert_idx < expert_end: + rank_idx = expert_idx // experts_per_rank - rank_begin_idx + + is_in_rank[rank_idx] += 1 + is_in_rdma_rank[rank_idx // NUM_MAX_NVL_PEERS] += 1 + + for j in T.serial(rank_begin_idx, rank_end_idx): + if is_in_rank[j - rank_begin_idx] > 0: + is_token_in_rank[i, j] = True + tokens_per_rank_per_thread[tx, j - rank_begin_idx] += 1 + else: + is_token_in_rank[i, j] = False + + for j in T.serial(rdma_rank_begin_idx, rdma_rank_end_idx): + if is_in_rdma_rank[j - rdma_rank_begin_idx] > 0: + tokens_per_rdma_rank_per_thread[tx, j - rdma_rank_begin_idx] += 1 + + if rank_begin_idx + tx < rank_end_idx: + sum = T.alloc_var("int32") + 
sum = 0 + for i in T.serial(threads): + sum += tokens_per_rank_per_thread[i, tx] + num_tokens_per_rank[rank_begin_idx + tx] = sum + + if rdma_rank_begin_idx + tx < rdma_rank_end_idx: + sum = T.alloc_var("int32") + sum = 0 + for i in T.serial(threads): + sum += tokens_per_rdma_rank_per_thread[i, tx] + num_tokens_per_rdma_rank[rdma_rank_begin_idx + tx] = sum + + return get_dispatch_layout_main + + +def test_get_dispatch_layout( + num_tokens: int, + num_topk_groups: int, + num_topk: int, + num_experts: int, + num_local_ranks: int, + num_nodes: int, +): + try: + import deep_ep_cpp # noqa: F403 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Please install DeepEP to run this test.") + + num_ranks = num_local_ranks * num_nodes + num_rdma_ranks = num_ranks // NUM_MAX_NVL_PEERS + + # Validate correctness + x, topk_idx, topk_weights, rank_idx, rdma_rank_idx = gen_internode_inputs(num_tokens, 1, num_topk_groups, num_topk, num_experts, num_local_ranks * num_nodes, num_nodes) + buffer = deep_ep_cpp.Buffer( + 0, # rank + num_ranks, + 0, # num_nvl_bytes = 0 to bypass internode sync + 0, # num_rdma_bytes = 0 to bypass internode sync + False, # low_latency_mode + False, # explicit_destroy + False, # enable_shrink + False, # use fabric + ) + buffer.sync([i for i in range(num_ranks)], [None] * num_ranks, None) # fake a buffer to run on single rank + + ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = buffer.get_dispatch_layout(topk_idx, num_experts, None, False, False) + + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank = get_dispatch_layout(topk_idx, num_experts, num_ranks, num_rdma_ranks) + + assert torch.equal(num_tokens_per_expert, ref_num_tokens_per_expert), \ + f"num_tokens_per_expert mismatch, max err: {(num_tokens_per_expert - ref_num_tokens_per_expert).abs().max()}" + + assert torch.equal(is_token_in_rank, ref_is_token_in_rank), \ + "is_token_in_rank mismatch" + 
+ assert torch.equal(num_tokens_per_rank, ref_num_tokens_per_rank), \ + f"num_tokens_per_rank mismatch, max err: {(num_tokens_per_rank - ref_num_tokens_per_rank).abs().max()}" + + assert torch.equal(num_tokens_per_rdma_rank, ref_num_tokens_per_rdma_rank), \ + f"num_tokens_per_rdma_rank mismatch, max err: {(num_tokens_per_rdma_rank - ref_num_tokens_per_rdma_rank).abs().max()}" + + print("All checks passed for TileScale internode get_dispatch_layout.✅") + + # Benchmark + t1 = do_bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts, None, False, False), + _n_warmup=1, + _n_repeat=1, + ) + t2 = do_bench(lambda: get_dispatch_layout(topk_idx, num_experts, num_ranks, num_rdma_ranks), + _n_warmup=1, + _n_repeat=1, + ) + print(f"DeepEP: {t1:.3f} ms") + print(f"TileScale: {t2:.3f} ms") + print(f"Speedup: {t1 / t2:.2f}x") + + +def parse_args(): + parser = ArgumentParser(description="Test get_dispatch_layout") + parser.add_argument("--num_tokens", type=int, default=4096, help="Number of tokens") + parser.add_argument('--num-topk-groups', type=int, default=None, help='Number of top-k groups (default: `min(num_nodes, 4)`)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument("--num_experts", type=int, default=256, help="Number of experts") + parser.add_argument("--num_local_ranks", type=int, default=8, help="Number of local ranks") + parser.add_argument("--num_nodes", type=int, default=8, help="Number of nodes") + + args = parser.parse_args() + if args.num_topk_groups is None: + args.num_topk_groups = min(args.num_nodes, 4) + + return args + + +if __name__ == "__main__": + args = parse_args() + + test_get_dispatch_layout( + num_tokens=args.num_tokens, + num_topk_groups=args.num_topk_groups, + num_topk=args.num_topk, + num_experts=args.num_experts, + num_local_ranks=args.num_local_ranks, + num_nodes=args.num_nodes) diff --git a/examples/distributed/deepseek_deepep/intranode/dispatch.py 
b/examples/distributed/deepseek_deepep/intranode/dispatch.py index 83912a089..dda152b51 100644 --- a/examples/distributed/deepseek_deepep/intranode/dispatch.py +++ b/examples/distributed/deepseek_deepep/intranode/dispatch.py @@ -853,7 +853,7 @@ def intranode_dispatch( "bfloat16", ) kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream) - with tvm_ffi.use_torch_stream(torch.cuda.stream(comm_stream)): + with torch.cuda.stream(comm_stream): kernel( rank, recv_x, diff --git a/src/op/builtin.cc b/src/op/builtin.cc index e82870af7..c5464dc74 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -395,5 +395,26 @@ TIR_DEFINE_TL_BUILTIN(warp_any).set_num_inputs(2).set_attr( TIR_DEFINE_TL_BUILTIN(warp_all).set_num_inputs(2).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); +TIR_DEFINE_TL_BUILTIN(ibgda_get_qps_per_rdma_rank) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + +TIR_DEFINE_TL_BUILTIN(ibgda_quiet) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ibgda_put_nbi_warp) + .set_num_inputs(8) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ibgda_amo_nonfetch_add) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + + } // namespace tl } // namespace tvm diff --git a/src/op/builtin.h b/src/op/builtin.h index 99da3d755..fa8c03626 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -637,7 +637,25 @@ TVM_DLL const Op &warp_any(); */ TVM_DLL const Op &warp_all(); -// Note: ld and st are TileOperators defined in remote_copy.h, not builtins +/*! + * \brief tilelang intrinsic for getting the number of QPs per RDMA rank. + */ +TVM_DLL const Op &ibgda_get_qps_per_rdma_rank(); + +/*! + * \brief tilelang intrinsic for quieting a QP. + */ +TVM_DLL const Op &ibgda_quiet(); + +/*! + * \brief tilelang intrinsic for putting a NBI warp. 
+ */ +TVM_DLL const Op &ibgda_put_nbi_warp(); + +/*! + * \brief tilelang intrinsic for AMO non-fetch add. + */ +TVM_DLL const Op &ibgda_amo_nonfetch_add(); } // namespace tl } // namespace tvm diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index ce605ac7d..110217639 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -299,6 +299,10 @@ std::string CodeGenTileLangCUDA::Finish() { decl_stream << "#include \n"; } + if (use_ibgda_) { + decl_stream << "#include \n"; + } + if (need_cooperative_groups_) { decl_stream << "#include \n"; } @@ -2890,6 +2894,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->use_distributed_ = true; std::string ptr_str = this->PrintExpr(op->args[0]); os << "tl::get_uintptr_t(" << ptr_str << ")"; + os << "__all_sync(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::ibgda_get_qps_per_rdma_rank())) { + this->use_ibgda_ = true; + os << "ibgda_get_qps_per_rdma_rank()"; + } else if (op->op.same_as(tl::ibgda_quiet())) { + this->use_ibgda_ = true; + os << "ibgda_quiet(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::ibgda_put_nbi_warp())) { + this->use_ibgda_ = true; + int always_do_post_send = Downcast(op->args[7])->value; + std::string always_do_post_send_str = always_do_post_send ? "true" : "false"; + os << "ibgda_put_nbi_warp<" << always_do_post_send_str << ">(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ", " << PrintExpr(op->args[5]) << ", " << PrintExpr(op->args[6]) << ")"; + } else if (op->op.same_as(tl::ibgda_amo_nonfetch_add())) { + this->use_ibgda_ = true; + int is_local_copy = Downcast(op->args[4])->value; + std::string is_local_copy_str = is_local_copy ? 
"true" : "false"; + os << "ibgda_amo_nonfetch_add<" << is_local_copy_str << ">(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ")"; } else { // Note: tl.put, tl.get, tl.wait are TileOperators handled through // remote_copy.cc They are lowered to call_extern with diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 6c5f89e07..8ce149fd9 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -145,6 +145,8 @@ class CodeGenTileLangCUDA final : public CodeGenC { bool use_distributed_{use_distributed()}; // whether need nvshmem.h bool use_nvshmem_{false}; + // whether need ibgda.h + bool use_ibgda_{false}; // Op attribute map OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); diff --git a/src/tl_templates/cuda/ibgda_device.cuh b/src/tl_templates/cuda/ibgda_device.cuh new file mode 100644 index 000000000..d1acec33a --- /dev/null +++ b/src/tl_templates/cuda/ibgda_device.cuh @@ -0,0 +1,542 @@ +// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem) +// Copyright (c) NVIDIA Corporation. +// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019). 
+// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html +// +// Modified from original source: +// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh +// - DeepEP/csrc/kernels/ibgda_device.cuh + +// NOTE(wt): For now, we rely on NVSHMEM to implement IBGDA + +#pragma once + +#include "common.h" +#include "ldst.h" +#include "sync.h" + +#ifndef __CUDACC_RTC__ +#include +#endif + +#include +#include +#include +#include +#include + +#include + +namespace tl { + +namespace detail { + +template +TL_DEVICE void store_relaxed_na(T* ptr, T value) { + ::tl::st(ptr, value); +} + +template +TL_DEVICE void store_release_na(T* ptr, T value) { + ::tl::st(ptr, value); +} + +template +TL_DEVICE T load_relaxed_na(const T* ptr) { + T value; + ::tl::ld(ptr, value); + return value; +} + +TL_DEVICE void store_relaxed_na_vec4(int4* ptr, const int4& value) { + asm volatile("st.relaxed.gpu.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); +} + +} // namespace detail + +static_assert(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); + +TL_DEVICE uint64_t HtoBE64(uint64_t x) { + uint64_t ret; + asm("{\n\t" + ".reg .b32 ign;\n\t" + ".reg .b32 lo;\n\t" + ".reg .b32 hi;\n\t" + ".reg .b32 new_lo;\n\t" + ".reg .b32 new_hi;\n\t" + "mov.b64 {lo,hi}, %1;\n\t" + "prmt.b32 new_hi, lo, ign, 0x0123;\n\t" + "prmt.b32 new_lo, hi, ign, 0x0123;\n\t" + "mov.b64 %0, {new_lo,new_hi};\n\t" + "}" + : "=l"(ret) + : "l"(x)); + return ret; +} + +TL_DEVICE uint32_t HtoBE32(uint32_t x) { + uint32_t ret; + asm("{\n\t" + ".reg .b32 ign;\n\t" + "prmt.b32 %0, %1, ign, 0x0123;\n\t" + "}" + : "=r"(ret) + : "r"(x)); + return ret; +} + +TL_DEVICE uint16_t HtoBE16(uint16_t x) { + // TODO: simplify PTX using 16-bit instructions + auto a = static_cast(x); + uint32_t d; + asm volatile( + "{\n\t" + ".reg .b32 mask;\n\t" + ".reg .b32 ign;\n\t" + "mov.b32 mask, 0x4401;\n\t" + "mov.b32 ign, 0x0;\n\t" + "prmt.b32 %0, 
%1, ign, mask;\n\t" + "}" + : "=r"(d) + : "r"(a)); + return static_cast(d); +} + +typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t; + +typedef struct { + uint32_t add_data; + uint32_t field_boundary; + uint64_t reserved; +} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t; + +TL_DEVICE nvshmemi_ibgda_device_state_t* ibgda_get_state() { + return &nvshmemi_ibgda_device_state_d; +} + +TL_DEVICE nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) { + auto state = ibgda_get_state(); + const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe; + return &state->globalmem + .rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)]; +} + +TL_DEVICE int ibgda_get_qps_per_rdma_rank() { + return ibgda_get_state()->num_rc_per_pe * ibgda_get_state()->num_devices_initialized; +} + +TL_DEVICE void ibgda_lock_acquire(int* lock) { + while (atomicCAS(lock, 0, 1) == 1) + ; + + // Prevent reordering before the lock is acquired + memory_fence_cta(); +} + +TL_DEVICE void ibgda_lock_release(int* lock) { + memory_fence_cta(); + + // Prevent reordering before lock is released + detail::store_relaxed_na(lock, 0); +} + +TL_DEVICE void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t* qp, uint32_t dbrec_head) { + // `DBREC` contains the index of the next empty `WQEBB` + __be32 dbrec_val; + __be32* dbrec_ptr = qp->tx_wq.dbrec; + + // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))` + asm("{\n\t" + ".reg .b32 dbrec_head_16b;\n\t" + ".reg .b32 ign;\n\t" + "and.b32 dbrec_head_16b, %1, 0xffff;\n\t" + "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t" + "}" + : "=r"(dbrec_val) + : "r"(dbrec_head)); + detail::store_release_na(dbrec_ptr, dbrec_val); +} + +TL_DEVICE void ibgda_ring_db(nvshmemi_ibgda_device_qp_t* qp, uint16_t prod_idx) { + auto bf_ptr = reinterpret_cast(qp->tx_wq.bf); + ibgda_ctrl_seg_t ctrl_seg = {.opmod_idx_opcode = HtoBE32(prod_idx << 8), .qpn_ds = 
HtoBE32(qp->qpn << 8)}; + + static_assert(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), ""); + detail::store_release_na(bf_ptr, *(reinterpret_cast(&ctrl_seg))); +} + +TL_DEVICE void ibgda_post_send(nvshmemi_ibgda_device_qp_t* qp, uint64_t new_prod_idx) { + nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars; + uint64_t old_prod_idx; + + // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence + ibgda_lock_acquire(&mvars->post_send_lock); + + old_prod_idx = atomicMax(reinterpret_cast(&mvars->tx_wq.prod_idx), new_prod_idx); + if (new_prod_idx > old_prod_idx) { + ibgda_update_dbr(qp, new_prod_idx); + ibgda_ring_db(qp, new_prod_idx); + } + ibgda_lock_release(&mvars->post_send_lock); +} + +template +TL_DEVICE void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t* qp, + uint64_t base_wqe_idx, + uint32_t num_wqes, + int message_idx = 0) { + auto state = ibgda_get_state(); + nvshmemi_ibgda_device_qp_management_t* mvars = &qp->mvars; + uint64_t new_wqe_idx = base_wqe_idx + num_wqes; + + // WQE writes must be finished first + __threadfence(); + + unsigned long long int* ready_idx = + (unsigned long long int*)(state->use_async_postsend ? 
qp->tx_wq.prod_idx : &mvars->tx_wq.ready_head); + + // Wait for prior WQE slots to be filled first + while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx) + ; + + // Always post, not in batch + if (!state->use_async_postsend) { + constexpr int kNumRequestInBatch = 4; + if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0) + ibgda_post_send(qp, new_wqe_idx); + } +} + +TL_DEVICE void ibgda_write_rdma_write_inl_wqe( + nvshmemi_ibgda_device_qp_t* qp, const uint32_t* val, uint64_t raddr, __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) { + ibgda_ctrl_seg_t ctrl_seg; + struct mlx5_wqe_raddr_seg raddr_seg; + struct mlx5_wqe_inl_data_seg inl_seg; + + auto* ctrl_seg_ptr = reinterpret_cast(out_wqes[0]); + auto* raddr_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); + auto* inl_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); + auto* wqe_data_ptr = reinterpret_cast(reinterpret_cast(inl_seg_ptr) + sizeof(*inl_seg_ptr)); + + raddr_seg.raddr = HtoBE64(raddr); + raddr_seg.rkey = rkey; + raddr_seg.reserved = 0; + + inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG); + + // `imm == std::numeric_limits::max()` means no imm writes + ctrl_seg = {0}; + ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); + ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + ctrl_seg.opmod_idx_opcode = + HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits::max() ? 
MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE)); + if (imm != std::numeric_limits::max()) + ctrl_seg.imm = HtoBE32(imm); + + static_assert(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); + static_assert(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); + static_assert(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4"); + detail::store_relaxed_na_vec4(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); + detail::store_relaxed_na_vec4(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); + detail::store_relaxed_na(reinterpret_cast(inl_seg_ptr), *reinterpret_cast(&inl_seg)); + detail::store_relaxed_na(reinterpret_cast(wqe_data_ptr), *reinterpret_cast(val)); +} + +TL_DEVICE uint64_t +ibgda_get_lkey_and_rkey(uint64_t laddr, __be32* lkey, uint64_t raddr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + auto log2_cumem_granularity = state->log2_cumem_granularity; + + // Local key + uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx; + auto device_key = state->constmem.lkeys[idx]; + auto lchunk_size = device_key.next_addr - laddr; + *lkey = device_key.key; + + // Remote key + uint64_t roffset = raddr - heap_start; + + idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized + + dst_pe * state->num_devices_initialized + dev_idx; + if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { + device_key = state->constmem.rkeys[idx]; + } else { + device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; + } + *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; + *out_rkey = device_key.key; + + // Return the minimum of local and remote chunk sizes + auto rchunk_size = device_key.next_addr - roffset; + return min(lchunk_size, rchunk_size); +} 
+ +TL_DEVICE void ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t* out_raddr, __be32* out_rkey, uint32_t dev_idx) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + + uint64_t roffset = addr - heap_start; + uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized) + + dst_pe * state->num_devices_initialized + dev_idx; + nvshmemi_ibgda_device_key_t device_key; + if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) + device_key = state->constmem.rkeys[idx]; + else + device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; + *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; + *out_rkey = device_key.key; +} + +TL_DEVICE uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t* qp, uint32_t num_wqes) { + auto mvars = &qp->mvars; + return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); +} + +TL_DEVICE void* ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { + uint16_t cnt = qp->tx_wq.nwqes; + uint16_t idx = wqe_idx & (cnt - 1); + return reinterpret_cast(reinterpret_cast(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); +} + +TL_DEVICE void nvshmemi_ibgda_rma_p( + int* rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits::max()) { + // Get rkey + // NOTES: the `p` operation will not cross multiple remote chunks + __be32 rkey; + uint64_t raddr; + auto qp = ibgda_get_rc(dst_pe, qp_id); + ibgda_get_rkey(reinterpret_cast(rptr), dst_pe, &raddr, &rkey, qp->dev_idx); + + // Write WQEs + uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + void* wqe_ptrs; + wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx); + ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm); + + // Submit requests + ibgda_submit_requests(qp, base_wqe_idx, 1); +} + +TL_DEVICE void 
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t* qp,
+                           uint64_t laddr,
+                           __be32 lkey,
+                           uint64_t raddr,
+                           __be32 rkey,
+                           uint32_t bytes,
+                           uint16_t wqe_idx,
+                           void** out_wqes) {
+  ibgda_ctrl_seg_t ctrl_seg;
+  struct mlx5_wqe_raddr_seg raddr_seg;
+  struct mlx5_wqe_data_seg data_seg;
+
+  auto* ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
+  void* av_seg_ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
+  struct mlx5_wqe_raddr_seg* raddr_seg_ptr;
+  struct mlx5_wqe_data_seg* data_seg_ptr;
+
+  raddr_seg_ptr = reinterpret_cast<struct mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(av_seg_ptr));
+  data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
+
+  raddr_seg.raddr = HtoBE64(raddr);
+  raddr_seg.rkey = rkey;
+  raddr_seg.reserved = 0;
+
+  data_seg.byte_count = HtoBE32(bytes);
+  data_seg.lkey = lkey;
+  data_seg.addr = HtoBE64(laddr);
+
+  ctrl_seg = {0};
+  ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3);
+  ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
+  ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE);
+
+  static_assert(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16");
+  static_assert(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16");
+  static_assert(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16");
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
+}
+
+TL_DEVICE void ibgda_write_empty_recv_wqe(void* out_wqe) {
+  auto* data_seg_ptr = reinterpret_cast<struct mlx5_wqe_data_seg*>(out_wqe);
+  struct mlx5_wqe_data_seg data_seg;
+
+  // Make the first segment in the WQE invalid, then the entire list will be invalid
+  data_seg.byte_count = 0;
+  data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY);
+  data_seg.addr = 0;
+
+  static_assert(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length");
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
+}
+
+template <bool kAlwaysDoPostSend = false>
+TL_DEVICE void nvshmemi_ibgda_put_nbi_warp(
+    uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
+  // Get lkey and rkey, store them into lanes
+  uint32_t num_wqes = 0;
+  __be32 my_lkey = 0;
+  uint64_t my_laddr = 0;
+  __be32 my_rkey = 0;
+  uint64_t my_raddr = 0;
+  uint64_t my_chunk_size = 0;
+
+  auto qp = ibgda_get_rc(dst_pe, qp_id);
+
+  // Decide how many messages (theoretically 3 for maximum)
+  auto remaining_bytes = bytes;
+  while (remaining_bytes > 0) {
+    if (lane_id == num_wqes) {
+      my_chunk_size = min(remaining_bytes,
+                          ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey, qp->dev_idx));
+    }
+
+    // Move one more message
+    auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
+    remaining_bytes -= chunk_size;
+    req_lptr += chunk_size;
+    req_rptr += chunk_size;
+    ++num_wqes;
+  }
+  if (!(num_wqes <= 32)) {
+    printf("Assertion failed: %s:%d, condition: num_wqes <= 32\n", __FILE__, __LINE__);
+    trap();
+  }
+
+  // Process WQE
+  uint64_t base_wqe_idx = 0;
+  if (lane_id == 0)
+    base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
+  base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
+  if (lane_id < num_wqes) {
+    auto wqe_idx = base_wqe_idx + lane_id;
+    auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
+    ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size, wqe_idx, &wqe_ptr);
+  }
+  __syncwarp();
+
+  // Submit
+  if (lane_id == 0)
+    ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
+  __syncwarp();
+}
+
+TL_DEVICE void ibgda_write_amo_add_wqe(nvshmemi_ibgda_device_qp_t* qp,
+                                       const int& value,
+                                       uint64_t laddr,
+                                       __be32 lkey,
+                                       uint64_t raddr,
+                                       __be32 rkey,
+                                       uint16_t wqe_idx,
+                                       void** out_wqes) {
+  ibgda_ctrl_seg_t ctrl_seg = {0};
+  struct mlx5_wqe_raddr_seg raddr_seg;
+  struct mlx5_wqe_atomic_seg atomic_seg_1;
+  struct mlx5_wqe_data_seg data_seg;
+
+  auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
+  auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
+  auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
+  auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));
+
+  raddr_seg.raddr = HtoBE64(raddr);
+  raddr_seg.rkey = rkey;
+  raddr_seg.reserved = 0;
+
+  // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
+  ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
+  auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
+  atomic_32_masked_fa_seg->add_data = HtoBE32(value);
+  atomic_32_masked_fa_seg->field_boundary = 0;
+
+  ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
+  ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
+
+  data_seg.byte_count = HtoBE32(sizeof(int));
+  data_seg.lkey = lkey;
+  data_seg.addr = HtoBE64(laddr);
+
+  static_assert(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  static_assert(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  static_assert(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  static_assert(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<const int4*>(&atomic_seg_1));
+  detail::store_relaxed_na_vec4(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
+}
+
+TL_DEVICE void nvshmemi_ibgda_amo_nonfetch_add(
+    void* rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
+  if (is_local_copy) {
+    atomicAdd(static_cast<int*>(rptr), value);
+  } else {
+    nvshmemi_ibgda_device_qp_t* qp = ibgda_get_rc(pe, qp_id);
+
+    __be32 rkey;
+    uint64_t raddr;
+    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);
+
+    uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+    void* wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+
+    ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf), qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
+
+    ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+  }
+}
+
+TL_DEVICE uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
+  // Local rank, no need for mapping
+  if (rank == dst_rank)
+    return ptr;
+  auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);
+
+  // RDMA connected
+  if (peer_base == 0)
+    return 0;
+
+  // NVLink P2P is enabled
+  return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
+}
+
+// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
+// Note that this implementation does not guarantee thread safety,
+// so we must ensure that no other threads are concurrently using the same QP.
+TL_DEVICE void ibgda_poll_cq(nvshmemi_ibgda_device_cq_t* cq, uint64_t idx) {
+  const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);
+  const uint32_t ncqes = cq->ncqes;
+  memory_fence_cta();
+  if (*cq->cons_idx >= idx)
+    return;
+  // NOTES: this while loop is part of do-while below.
+  // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`.
+  // To be able to compare with the index, we need to use `wqe_counter + 1`.
+  // Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know for
+  // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less than
+  // idx, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`
+  // That's why we use `- 2` here to make this case overflow.
+  uint16_t wqe_counter;
+  do {
+    wqe_counter = HtoBE16(detail::load_relaxed_na(&cqe64->wqe_counter));
+  } while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));
+  *cq->cons_idx = idx;
+
+  // Prevent reordering of this function and later instructions
+  memory_fence_cta();
+}
+
+// Wait until wqe `idx - 1` is completed.
+TL_DEVICE void nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
+  auto qp = ibgda_get_rc(dst_pe, qp_id);
+  auto state = ibgda_get_state();
+  uint64_t prod_idx = state->use_async_postsend ? detail::load_relaxed_na(qp->tx_wq.prod_idx)
+                                                : detail::load_relaxed_na(&qp->mvars.tx_wq.ready_head);
+  ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
+}
+
+} // namespace tl
diff --git a/tilelang/language/distributed/multi_device/ibgda.py b/tilelang/language/distributed/multi_device/ibgda.py
new file mode 100644
index 000000000..58da42d32
--- /dev/null
+++ b/tilelang/language/distributed/multi_device/ibgda.py
@@ -0,0 +1,22 @@
+"""The language interface for tl programs."""
+
+"""This file provides interface for inter-node comm via IBGDA.
+For now, we rely on NVSHMEM to implement IBGDA."""
+
+from tvm import tir
+
+
+def ibgda_get_qps_per_rdma_rank():
+    return tir.call_intrin("int32", tir.op.Op.get("tl.IbgdaGetQpsPerRdmaRank"))
+
+
+def ibgda_quiet(dst_pe, qp_id):
+    return tir.call_intrin("handle", tir.op.Op.get("tl.IbgdaQuiet"), dst_pe, qp_id)
+
+
+def ibgda_put_nbi_warp(req_rptr, req_lptr, bytes, dst_pe, qp_id, lane_id, message_idx, always_do_post_send=False):
+    return tir.call_intrin("handle", tir.op.Op.get("tl.IbgdaPutNbiWarp"), req_rptr, req_lptr, bytes, dst_pe, qp_id, lane_id, message_idx, always_do_post_send)
+
+
+def ibgda_amo_nonfetch_add(rptr, value, pe, qp_id, is_local_copy=False):
+    return tir.call_intrin("handle", tir.op.Op.get("tl.IbgdaAmoNonfetchAdd"), rptr, value, pe, qp_id, is_local_copy)