Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion examples/distributed/deepseek_deepep/deepep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):

# Check: DeepEP/tests/test_intranode.py:test_main
def gen_inputs(num_tokens: int, hidden: int, num_topk: int, num_experts: int, num_ranks: int):
"""Generate random inputs for testing purpose.
"""Generate random inputs for intranode testing purpose.
Args:
num_tokens: the number of tokens.
hidden: the hidden dimension.
Expand Down Expand Up @@ -156,6 +156,77 @@ def gen_inputs(num_tokens: int, hidden: int, num_topk: int, num_experts: int, nu
return x, topk_idx, topk_weights, rank_idx



def gen_internode_inputs(num_tokens: int, hidden: int, num_topk_groups: int, num_topk: int, num_experts: int, num_ranks: int, num_nodes: int):
    """
    Generate random MoE routing inputs with DeepSeek-style node-group
    restriction, for internode testing.

    Args:
        num_tokens: the number of tokens.
        hidden: the hidden dimension.
        num_topk_groups: the number of top-k node groups a token may route to.
        num_topk: the number of top-k experts to select for each token.
        num_experts: the number of experts.
        num_ranks: the number of total ranks.
        num_nodes: the number of nodes.

    Returns:
        x: `[num_tokens, hidden]` with `torch.bfloat16`, the input to the MoE layer.
        topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices
            selected by each token, `-1` means no selections.
        topk_weights: `[num_tokens, num_topk]` with `torch.float32`, the weights
            corresponding to each selected expert for each token.
        rank_idx: `[num_tokens, num_topk]` integer tensor (same dtype as
            `topk_idx`), the deduplicated rank indices corresponding to the
            selected experts, `-1` means no selections.
        rdma_rank_idx: `[num_tokens, num_topk]` integer tensor, the deduplicated
            RDMA rank (node) indices, `-1` means no selections.
    """
    assert num_topk <= num_experts, "num_topk must be less than or equal to num_experts"
    assert num_experts % num_ranks == 0, "num_experts must be divisible by num_ranks"

    # Token activations.
    x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')

    # Strictly-positive routing scores; restrict each token to its best
    # `num_topk_groups` node groups (group score = max expert score in group).
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
    per_group_best = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    chosen_groups = torch.topk(per_group_best, k=num_topk_groups, dim=-1, sorted=False).indices
    restricted_scores = create_grouped_scores(scores, chosen_groups, num_nodes)
    topk_idx = torch.topk(restricted_scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda')

    # Map each selected expert to the rank owning it, then deduplicate per row.
    experts_per_rank = num_experts // num_ranks
    rank_idx = topk_idx // experts_per_rank
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)

    # Map ranks to their RDMA rank (node), then deduplicate per row.
    local_ranks_per_node = num_ranks // num_nodes
    rdma_rank_idx = rank_idx // local_ranks_per_node
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)

    return x, topk_idx, topk_weights, rank_idx, rdma_rank_idx


def inplace_unique(x: torch.Tensor, num_slots: int):
    """
    In-place per-row deduplication: keep each row's distinct values (all in
    `[0, num_slots)`), left-aligned in descending order, and pad every other
    position with -1. Negative entries are treated as "no value".
    """
    assert x.dim() == 2
    num_rows = x.size(0)

    # Route negative (invalid) entries into an extra overflow bin, discarded below.
    invalid = x < 0
    clamped = x.masked_fill(invalid, num_slots)

    # Per-row histogram over `num_slots + 1` bins via scatter-add of ones.
    counts = torch.zeros((num_rows, num_slots + 1), dtype=x.dtype, device=x.device)
    counts.scatter_add_(1, clamped, torch.ones_like(clamped))
    counts = counts[:, :num_slots]  # drop the overflow bin

    # Sort bins by count so present values come first, mark absent bins with -1,
    # then a descending sort on the bin ids pushes the -1 padding to the end.
    cnt_sorted, idx_sorted = torch.sort(counts, dim=-1, descending=True)
    idx_sorted.masked_fill_(cnt_sorted == 0, -1)
    idx_sorted = torch.sort(idx_sorted, descending=True, dim=-1).values

    x.fill_(-1)
    keep = min(num_slots, x.size(1))
    x[:, :keep] = idx_sorted[:, :keep]


def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int):
    """
    Zero out all scores outside the selected groups.

    `scores` is `[num_tokens, num_experts]`; experts are partitioned into
    `num_groups` contiguous groups. `group_idx` (`[num_tokens, k]`) lists the
    groups to keep for each token. Returns a tensor of the same shape where
    scores of unselected groups are set to 0.
    """
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_groups, -1)
    keep = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    keep.scatter_(1, group_idx, True)
    masked = grouped * keep.unsqueeze(-1).expand_as(grouped)
    return masked.view(num_tokens, num_experts)


def inplace_unique(x: torch.Tensor, num_slots: int):
"""
Keep at most `num_slots` different values in each row of `x`,
Expand Down
270 changes: 270 additions & 0 deletions examples/distributed/deepseek_deepep/internode/get_dispatch_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
# For internode
# This op is non-distributed
### python get_dispatch_layout.py

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(__file__))) # add parent folder to path

import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
from typing import Tuple
from argparse import ArgumentParser
from utils import gen_internode_inputs, NUM_MAX_NVL_PEERS # noqa: F403


# TODO(wt): Add async functionality
def get_dispatch_layout(
        topk_idx: torch.Tensor, num_experts: int,
        num_ranks, num_rdma_ranks) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]:
    """Compute the communication layout needed by a later dispatch step.

    Arguments:
        topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the
            expert indices selected by each token, `-1` means no selections.
        num_experts: the number of experts.
        num_ranks: the number of ranks.
        num_rdma_ranks: the number of RDMA ranks, i.e. the number of nodes.

    Returns:
        num_tokens_per_rank: `[num_ranks]` with `torch.int`, tokens to be sent to each rank.
        num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, tokens to be
            sent to each RDMA rank (always a tensor in this internode implementation).
        num_tokens_per_expert: `[num_experts]` with `torch.int`, tokens to be sent to each expert.
        is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token
            is sent to a rank.
    """
    # Validate the routing tensor before handing it to the kernel.
    assert topk_idx.dtype == torch.int64, "topk_idx must be of dtype torch.int64"
    assert topk_idx.ndim == 2, "topk_idx must be a 2D tensor"
    assert topk_idx.is_contiguous(), "topk_idx must be a contiguous tensor"

    num_tokens, num_topk = topk_idx.shape

    # Uninitialized output buffers; every element is written by the kernel.
    # TODO(wt): Wait on previous events and allocate on comm stream when adding async functionality
    per_rank = torch.empty(num_ranks, dtype=torch.int32, device='cuda')
    per_rdma_rank = torch.empty(num_rdma_ranks, dtype=torch.int32, device='cuda')
    per_expert = torch.empty(num_experts, dtype=torch.int32, device='cuda')
    token_in_rank = torch.empty((num_tokens, num_ranks), dtype=torch.bool, device='cuda')

    # Build and launch the layout kernel.
    layout_kernel = get_dispatch_layout_kernel(num_topk, num_experts, num_ranks, num_rdma_ranks)
    layout_kernel(topk_idx, per_rank, per_rdma_rank, per_expert, token_in_rank)

    # TODO(wt): Wait streams when adding async functionality

    return per_rank, per_rdma_rank, per_expert, token_in_rank


@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def get_dispatch_layout_kernel(
    num_topk: int,
    num_experts: int,
    num_ranks: int,
    num_rdma_ranks: int,
) -> tilelang.JITKernel:
    """Build the TileLang kernel that computes dispatch-layout statistics.

    The launch grid has two disjoint groups of blocks:
      * the first `ceildiv(num_experts, experts_per_sm)` blocks fill
        `num_tokens_per_expert`;
      * the remaining `ceildiv(num_ranks, ranks_per_sm)` blocks fill
        `num_tokens_per_rank`, `num_tokens_per_rdma_rank` and
        `is_token_in_rank`.

    Args:
        num_topk: number of experts selected per token.
        num_experts: total number of experts.
        num_ranks: total number of ranks.
        num_rdma_ranks: number of RDMA ranks (nodes).

    Returns:
        The compiled `tilelang.JITKernel`; the token count is a dynamic
        dimension taken from `topk_idx`'s shape at call time.
    """
    threads = 256
    # Tile sizes: how many experts / ranks a single block is responsible for.
    experts_per_sm = 4
    ranks_per_sm = 8
    # Ranks map to RDMA ranks by integer division with NUM_MAX_NVL_PEERS, so a
    # block covering `ranks_per_sm` ranks covers this many RDMA ranks.
    rdma_ranks_per_sm = ranks_per_sm // NUM_MAX_NVL_PEERS
    num_sms = T.ceildiv(num_experts, experts_per_sm) + T.ceildiv(num_ranks, ranks_per_sm)
    experts_per_rank = num_experts // num_ranks

    # The number of tokens is only known at launch time.
    num_tokens = T.dynamic('num_tokens')

    @T.prim_func
    def get_dispatch_layout_main(
            topk_idx: T.Tensor([num_tokens, num_topk], "int64"),  # type: ignore
            num_tokens_per_rank: T.Tensor([num_ranks], "int32"),  # type: ignore
            num_tokens_per_rdma_rank: T.Tensor([num_rdma_ranks], "int32"),  # type: ignore
            num_tokens_per_expert: T.Tensor([num_experts], "int32"),  # type: ignore
            is_token_in_rank: T.Tensor([num_tokens, num_ranks], "bool"),  # type: ignore
    ):
        with T.Kernel(num_sms, threads=threads) as bx:
            tx = T.get_thread_binding()

            # ---- Phase 1: per-expert counts (first group of blocks) ----
            # Per-thread private counters in shared memory, reduced below.
            # NOTE(review): correctness relies on TileLang inserting the needed
            # shared-memory barrier between accumulation and reduction — confirm.
            tokens_per_expert_per_thread = T.alloc_shared([threads, experts_per_sm], "int32")
            T.clear(tokens_per_expert_per_thread)

            # Expert range owned by this block; empty for the rank-phase blocks.
            expert_begin_idx = T.alloc_var("int32")
            expert_begin_idx = bx * experts_per_sm
            expert_end_idx = T.alloc_var("int32")
            expert_end_idx = T.min(expert_begin_idx + experts_per_sm, num_experts)

            if expert_begin_idx < expert_end_idx:
                # Threads stride over tokens; each counts selections landing in
                # this block's expert range.
                for i in T.serial(tx, num_tokens, threads):
                    for j in T.serial(num_topk):
                        expert_idx = T.alloc_var("int32")
                        expert_idx = topk_idx[i, j]
                        if expert_begin_idx <= expert_idx and expert_idx < expert_end_idx:
                            tokens_per_expert_per_thread[tx,
                                                         expert_idx - expert_begin_idx] += 1

            # One thread per expert reduces the per-thread counters.
            if expert_begin_idx + tx < expert_end_idx:
                # NOTE(review): `sum` shadows the Python builtin inside the DSL body.
                sum = T.alloc_var("int32")
                sum = 0
                for i in T.serial(threads):
                    sum += tokens_per_expert_per_thread[i, tx]
                num_tokens_per_expert[expert_begin_idx + tx] = sum

            # ---- Phase 2: per-(RDMA-)rank counts (remaining blocks) ----
            sm_begin = T.alloc_var("int32")
            sm_begin = T.ceildiv(num_experts, experts_per_sm)
            # Negative for the expert-phase blocks, which skip this phase.
            rank_begin_idx = T.alloc_var("int32")
            rank_begin_idx = (bx - sm_begin) * ranks_per_sm
            rank_end_idx = T.alloc_var("int32")
            rank_end_idx = T.min(rank_begin_idx + ranks_per_sm, num_ranks)
            rdma_rank_begin_idx = rank_begin_idx // NUM_MAX_NVL_PEERS
            rdma_rank_end_idx = rank_end_idx // NUM_MAX_NVL_PEERS

            if rank_begin_idx >= 0 and rank_begin_idx < rank_end_idx:
                tokens_per_rank_per_thread = T.alloc_shared([threads, ranks_per_sm], "int32")
                tokens_per_rdma_rank_per_thread = T.alloc_shared([threads, rdma_ranks_per_sm], "int32")
                T.clear(tokens_per_rank_per_thread)
                T.clear(tokens_per_rdma_rank_per_thread)

                # Expert-id range owned by this block's rank range (experts are
                # partitioned contiguously across ranks).
                expert_begin = T.alloc_var("int32")
                expert_begin = rank_begin_idx * experts_per_rank
                expert_end = T.alloc_var("int32")
                expert_end = rank_end_idx * experts_per_rank

                for i in T.serial(tx, num_tokens, threads):
                    # Per-token selection counts for each (RDMA) rank in range.
                    is_in_rank = T.alloc_local([ranks_per_sm], "int32")
                    is_in_rdma_rank = T.alloc_local([rdma_ranks_per_sm], "int32")
                    T.clear(is_in_rank)
                    T.clear(is_in_rdma_rank)

                    for j in T.serial(num_topk):
                        expert_idx = T.alloc_var("int32")
                        rank_idx = T.alloc_var("int32")
                        expert_idx = topk_idx[i, j]
                        if expert_begin <= expert_idx and expert_idx < expert_end:
                            rank_idx = expert_idx // experts_per_rank - rank_begin_idx

                            is_in_rank[rank_idx] += 1
                            is_in_rdma_rank[rank_idx // NUM_MAX_NVL_PEERS] += 1

                    # A token belongs to a rank iff it selected >= 1 expert there.
                    for j in T.serial(rank_begin_idx, rank_end_idx):
                        if is_in_rank[j - rank_begin_idx] > 0:
                            is_token_in_rank[i, j] = True
                            tokens_per_rank_per_thread[tx, j - rank_begin_idx] += 1
                        else:
                            is_token_in_rank[i, j] = False

                    for j in T.serial(rdma_rank_begin_idx, rdma_rank_end_idx):
                        if is_in_rdma_rank[j - rdma_rank_begin_idx] > 0:
                            tokens_per_rdma_rank_per_thread[tx, j - rdma_rank_begin_idx] += 1

                # Reduce per-thread counters: one thread per rank...
                if rank_begin_idx + tx < rank_end_idx:
                    sum = T.alloc_var("int32")
                    sum = 0
                    for i in T.serial(threads):
                        sum += tokens_per_rank_per_thread[i, tx]
                    num_tokens_per_rank[rank_begin_idx + tx] = sum

                # ...and one thread per RDMA rank.
                if rdma_rank_begin_idx + tx < rdma_rank_end_idx:
                    sum = T.alloc_var("int32")
                    sum = 0
                    for i in T.serial(threads):
                        sum += tokens_per_rdma_rank_per_thread[i, tx]
                    num_tokens_per_rdma_rank[rdma_rank_begin_idx + tx] = sum

    return get_dispatch_layout_main


def test_get_dispatch_layout(
    num_tokens: int,
    num_topk_groups: int,
    num_topk: int,
    num_experts: int,
    num_local_ranks: int,
    num_nodes: int,
):
    """Validate the TileLang layout kernel against DeepEP, then benchmark both.

    Requires a CUDA device and an installed DeepEP (`deep_ep_cpp`).

    Args:
        num_tokens: number of tokens to route.
        num_topk_groups: number of node groups each token may route to.
        num_topk: number of experts selected per token.
        num_experts: total number of experts.
        num_local_ranks: number of ranks per node.
        num_nodes: number of nodes.

    Raises:
        ModuleNotFoundError: if DeepEP's C++ extension is not importable.
        AssertionError: if any output disagrees with the DeepEP reference.
    """
    try:
        import deep_ep_cpp  # noqa: F403
    except ModuleNotFoundError as e:
        # Chain the original error so the missing-module details are preserved.
        raise ModuleNotFoundError("Please install DeepEP to run this test.") from e

    num_ranks = num_local_ranks * num_nodes
    num_rdma_ranks = num_ranks // NUM_MAX_NVL_PEERS

    # Validate correctness. Only `topk_idx` is needed; hidden dim is 1 to keep
    # the unused activation tensor tiny.
    _, topk_idx, _, _, _ = gen_internode_inputs(
        num_tokens, 1, num_topk_groups, num_topk, num_experts, num_ranks, num_nodes)
    buffer = deep_ep_cpp.Buffer(
        0,  # rank
        num_ranks,
        0,  # num_nvl_bytes = 0 to bypass internode sync
        0,  # num_rdma_bytes = 0 to bypass internode sync
        False,  # low_latency_mode
        False,  # explicit_destroy
        False,  # enable_shrink
        False,  # use fabric
    )
    buffer.sync(list(range(num_ranks)), [None] * num_ranks, None)  # fake a buffer to run on single rank

    ref_num_tokens_per_rank, ref_num_tokens_per_rdma_rank, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = \
        buffer.get_dispatch_layout(topk_idx, num_experts, None, False, False)

    num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank = \
        get_dispatch_layout(topk_idx, num_experts, num_ranks, num_rdma_ranks)

    assert torch.equal(num_tokens_per_expert, ref_num_tokens_per_expert), \
        f"num_tokens_per_expert mismatch, max err: {(num_tokens_per_expert - ref_num_tokens_per_expert).abs().max()}"

    assert torch.equal(is_token_in_rank, ref_is_token_in_rank), \
        "is_token_in_rank mismatch"

    assert torch.equal(num_tokens_per_rank, ref_num_tokens_per_rank), \
        f"num_tokens_per_rank mismatch, max err: {(num_tokens_per_rank - ref_num_tokens_per_rank).abs().max()}"

    assert torch.equal(num_tokens_per_rdma_rank, ref_num_tokens_per_rdma_rank), \
        f"num_tokens_per_rdma_rank mismatch, max err: {(num_tokens_per_rdma_rank - ref_num_tokens_per_rdma_rank).abs().max()}"

    print("All checks passed for TileScale internode get_dispatch_layout.✅")

    # Benchmark (single warmup/repeat: this is a smoke timing, not a rigorous
    # benchmark).
    t1 = do_bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts, None, False, False),
                  _n_warmup=1,
                  _n_repeat=1,
                  )
    t2 = do_bench(lambda: get_dispatch_layout(topk_idx, num_experts, num_ranks, num_rdma_ranks),
                  _n_warmup=1,
                  _n_repeat=1,
                  )
    print(f"DeepEP: {t1:.3f} ms")
    print(f"TileScale: {t2:.3f} ms")
    print(f"Speedup: {t1 / t2:.2f}x")


def parse_args():
    """Parse CLI options for the get_dispatch_layout test.

    `--num-topk-groups` defaults to `min(num_nodes, 4)` and is resolved after
    parsing, since it depends on `--num_nodes`.
    """
    ap = ArgumentParser(description="Test get_dispatch_layout")
    ap.add_argument("--num_tokens", type=int, default=4096, help="Number of tokens")
    ap.add_argument('--num-topk-groups', type=int, default=None, help='Number of top-k groups (default: `min(num_nodes, 4)`)')
    ap.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
    ap.add_argument("--num_experts", type=int, default=256, help="Number of experts")
    ap.add_argument("--num_local_ranks", type=int, default=8, help="Number of local ranks")
    ap.add_argument("--num_nodes", type=int, default=8, help="Number of nodes")

    parsed = ap.parse_args()
    # Late-resolve the group default now that --num_nodes is known.
    if parsed.num_topk_groups is None:
        parsed.num_topk_groups = min(parsed.num_nodes, 4)

    return parsed


if __name__ == "__main__":
    # Forward every parsed CLI option to the test entry point.
    cli = parse_args()
    test_get_dispatch_layout(
        num_tokens=cli.num_tokens,
        num_topk_groups=cli.num_topk_groups,
        num_topk=cli.num_topk,
        num_experts=cli.num_experts,
        num_local_ranks=cli.num_local_ranks,
        num_nodes=cli.num_nodes,
    )
2 changes: 1 addition & 1 deletion examples/distributed/deepseek_deepep/intranode/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def intranode_dispatch(
"bfloat16",
)
kernel.initialize(allocator=allocator, stream=comm_stream.cuda_stream)
with tvm_ffi.use_torch_stream(torch.cuda.stream(comm_stream)):
with torch.cuda.stream(comm_stream):
kernel(
rank,
recv_x,
Expand Down
21 changes: 21 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,5 +395,26 @@ TIR_DEFINE_TL_BUILTIN(warp_any).set_num_inputs(2).set_attr<TCallEffectKind>(
TIR_DEFINE_TL_BUILTIN(warp_all).set_num_inputs(2).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));

// IBGDA intrinsics used by internode communication kernels.
// NOTE(review): semantics below are inferred from names and arities only —
// confirm against the corresponding codegen/lowering implementation.

// Presumably queries the number of QPs per RDMA rank. Takes no inputs and is
// marked pure, so the compiler may de-duplicate or hoist calls.
TIR_DEFINE_TL_BUILTIN(ibgda_get_qps_per_rdma_rank)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

// Presumably waits for outstanding RDMA operations (2 inputs). Opaque: has
// side effects and must not be reordered or eliminated.
TIR_DEFINE_TL_BUILTIN(ibgda_quiet)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Presumably a warp-cooperative non-blocking put (8 inputs). Opaque.
TIR_DEFINE_TL_BUILTIN(ibgda_put_nbi_warp)
    .set_num_inputs(8)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Presumably a non-fetching remote atomic add (5 inputs). Opaque.
TIR_DEFINE_TL_BUILTIN(ibgda_amo_nonfetch_add)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));


} // namespace tl
} // namespace tvm
Loading
Loading