From 1d5f1069d82d496923390507f27cae7ac4db0e6b Mon Sep 17 00:00:00 2001 From: shangyuan-ant Date: Fri, 19 Sep 2025 15:56:01 +0800 Subject: [PATCH] feat: Add Expert Affinity Aware EPLB algorithm. Signed-off-by: shangyuan-ant --- .../srt/eplb/eplb_algorithms/__init__.py | 15 +- .../eplb/eplb_algorithms/deepseek_commun.py | 314 ++++++++++++++++++ python/sglang/srt/eplb/eplb_manager.py | 3 +- python/sglang/srt/eplb/expert_distribution.py | 50 ++- python/sglang/srt/eplb/expert_location.py | 33 +- .../srt/eplb/utils/comm_matrix_process.py | 41 +++ .../srt/layers/moe/token_dispatcher/deepep.py | 8 + 7 files changed, 452 insertions(+), 12 deletions(-) create mode 100644 python/sglang/srt/eplb/eplb_algorithms/deepseek_commun.py create mode 100644 python/sglang/srt/eplb/utils/comm_matrix_process.py diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index e2a2678104a..7a603d2204b 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,7 +3,7 @@ import torch -from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec +from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, deepseek_commun class EplbAlgorithm(Enum): @@ -11,6 +11,7 @@ class EplbAlgorithm(Enum): deepseek_hierarchical = auto() deepseek_vec = auto() deepseek_vec_hierarchical = auto() + deepseek_commun = auto() # TODO may have more algorithm later @@ -21,6 +22,7 @@ def rebalance_experts( num_groups: Optional[int], num_nodes: int, algorithm: EplbAlgorithm, + comm_matrix: Optional[torch.Tensor], ): if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]: return deepseek.rebalance_experts( @@ -45,6 +47,17 @@ def rebalance_experts( enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical, ) + if algorithm == EplbAlgorithm.deepseek_commun: + """Using DeepSeek-Commun algorithm for expert rebalancing.""" + return deepseek_commun.rebalance_experts( + weight=tokens_per_expert.sum(dim=0), + num_replicas=num_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=num_physical_experts // num_local_physical_experts, + comm_matrix=comm_matrix, + ) + raise NotImplementedError diff --git a/python/sglang/srt/eplb/eplb_algorithms/deepseek_commun.py b/python/sglang/srt/eplb/eplb_algorithms/deepseek_commun.py new file mode 100644 index 00000000000..b2b64152a4d --- /dev/null +++ b/python/sglang/srt/eplb/eplb_algorithms/deepseek_commun.py @@ -0,0 +1,314 @@ +from typing import Tuple +import time +import torch + +def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs + are as balanced as possible. 
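    Illustrative example (a minimal hand-traced sketch of the greedy packing below;
    the weights are made up for the example):

        weight = torch.tensor([[30., 10., 20., 40.]])   # one layer, four groups
        pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)
        # pack_index   -> tensor([[1, 0, 1, 0]])
        # rank_in_pack -> tensor([[0, 1, 1, 0]])
        # Groups are visited in decreasing order of weight and each one is placed on
        # the lightest pack that still has room, so both packs end up with weight 50.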
+ + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu') + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def optimize_group_placement(pphy2log, comm_matrix, num_nodes, num_gpus, group_size): + """ + Optimize the placement of expert groups to minimize inter-node communication cost. 
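    In outline (a summary of the heuristic below, in the spirit of Kernighan-Lin
    partitioning): each group is represented by its leader expert, i.e. the first
    physical expert of the group; the affinity C[g1, g2] between two groups is read
    from comm_matrix[layer] at the two leaders' logical ids; and, for at most 20
    iterations per layer, the single best-gain swap of two groups living on different
    nodes is applied, where

        gain(g1 on node1, g2 on node2) =
              sum over o in node1, o != g1 of (C[g1, o] - C[g2, o])
            + sum over o in node2, o != g2 of (C[g2, o] - C[g1, o])

    The loop stops early once no swap with positive gain remains.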
+ + Parameters: + pphy2log: [num_layers, num_physical_experts], physical to logical expert mapping + comm_matrix: [num_logical_experts, num_logical_experts], communication cost between experts + num_nodes: number of server nodes + num_gpus: number of GPUs, must be a multiple of `num_nodes` + group_size: number of experts in each group + + Returns: + optimized_pphy2log: [num_layers, num_physical_experts], optimized physical to logical expert mapping + """ + num_layers, num_physical_experts = pphy2log.shape + num_groups = num_physical_experts // group_size + groups_per_node = num_groups // num_nodes + optimized_pphy2log = pphy2log.clone() + + # compute group start indices before-hand + group_start_indices = [g * group_size for g in range(num_groups)] + + for layer in range(num_layers): + # compute group to node mapping before-hand + group_to_node = torch.zeros(num_groups, dtype=torch.int64) + for g in range(num_groups): + group_to_node[g] = (g * group_size) // (num_physical_experts // num_nodes) + + # get the leader expert of each group + leader_experts = torch.zeros(num_groups, dtype=torch.int64) + for g in range(num_groups): + leader_idx = g * group_size + leader_experts[g] = pphy2log[layer, leader_idx].item() + + # Only consider leader experts for inter-group communication cost + group_comm_cost = torch.zeros((num_groups, num_groups), dtype=torch.float32) + for g1 in range(num_groups): + leader_expert_g1 = leader_experts[g1] + for g2 in range(num_groups): + if g1 != g2: + leader_expert_g2 = leader_experts[g2] + group_comm_cost[g1, g2] = comm_matrix[layer, leader_expert_g1, leader_expert_g2] + + # construct initial node groups + node_groups = [[] for _ in range(num_nodes)] + for g in range(num_groups): + node_idx = group_to_node[g].item() + node_groups[node_idx].append(g) + + # compute initial node pair costs + node_pair_costs = {} + for node1 in range(num_nodes): + for node2 in range(node1 + 1, num_nodes): + cost = 0 + for g1 in node_groups[node1]: + for g2 in node_groups[node2]: + cost += group_comm_cost[g1, g2] + node_pair_costs[(node1, node2)] = cost + + # Do the optimization iterations + improved = True + iterations = 0 + max_iterations = 20 + + while improved and iterations < max_iterations: + improved = False + iterations += 1 + + # Find the best swap + best_gain = 0 + best_swap = None + + for node1 in range(num_nodes): + for node2 in range(node1 + 1, num_nodes): + current_cost = node_pair_costs[(node1, node2)] + + # Try all pairs of groups between node1 and node2 + for g1_idx, g1 in enumerate(node_groups[node1]): + for g2_idx, g2 in enumerate(node_groups[node2]): + gain = 0 + + # Compute the gain from swapping g1 and g2 + for other_g1 in node_groups[node1]: + if other_g1 != g1: + gain += group_comm_cost[g1, other_g1] + gain -= group_comm_cost[g2, other_g1] + + for other_g2 in node_groups[node2]: + if other_g2 != g2: + gain += group_comm_cost[g2, other_g2] + gain -= group_comm_cost[g1, other_g2] + + if gain > best_gain: + best_gain = gain + best_swap = (node1, g1_idx, node2, g2_idx) + + if best_gain > 0 and best_swap: + node1, g1_idx, node2, g2_idx = best_swap + g1 = node_groups[node1][g1_idx] + g2 = node_groups[node2][g2_idx] + + # update node groups + node_groups[node1][g1_idx] = g2 + node_groups[node2][g2_idx] = g1 + + # update node pair costs + for n1, n2 in node_pair_costs: + if n1 == node1 or n1 == node2 or n2 == node1 or n2 == node2: + cost = 0 + for g_n1 in node_groups[n1]: + for g_n2 in node_groups[n2]: + cost += group_comm_cost[g_n1, g_n2] + node_pair_costs[(n1, n2)] = 
cost + + # swap physical expert mapping + for offset in range(group_size): + idx1 = g1 * group_size + offset + idx2 = g2 * group_size + offset + optimized_pphy2log[layer, idx1], optimized_pphy2log[layer, idx2] = \ + optimized_pphy2log[layer, idx2].item(), optimized_pphy2log[layer, idx1].item() + + improved = True + + print(f"Layer {layer} optimized in {iterations} iterations") + + return optimized_pphy2log + + +def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int, + num_groups: int, num_nodes: int, num_gpus: int, + comm_matrix: torch.Tensor = None): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + comm_matrix: [num_logical_experts, num_logical_experts], communication cost between experts + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape)) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + + + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, + device=group_pack_index.device).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + # Step 4: rearrange group placement according to communication cost + start_time = time.perf_counter() + if comm_matrix is not None: + pphy2log = optimize_group_placement(pphy2log, comm_matrix, num_nodes, num_gpus, group_size) + print(pphy2log) + print(pphy2log.shape) + end_time = 
time.perf_counter() + print(f"Group placement optimization time: {end_time - start_time:.4f} seconds") + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + comm_matrix: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer with communication optimization. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + comm_matrix: [num_logical_experts, num_logical_experts], communication cost between experts + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + + if comm_matrix is not None: + comm_matrix = comm_matrix.float().cpu() + + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy with communication awareness + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus, comm_matrix) + else: + # use global load-balance policy with communication awareness + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus, comm_matrix) + + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt), + -1, dtype=torch.int64, device=logcnt.device) + log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1)) + return phy2log, log2phy, logcnt \ No newline at end of file diff --git a/python/sglang/srt/eplb/eplb_manager.py b/python/sglang/srt/eplb/eplb_manager.py index e88a3d28e0f..3f2e31ffc3e 100644 --- a/python/sglang/srt/eplb/eplb_manager.py +++ b/python/sglang/srt/eplb/eplb_manager.py @@ -65,13 +65,14 @@ def rebalance(self): average_utilization_rate_over_window = dump_record_output[ "average_utilization_rate_over_window" ] + topk_history_data = dump_record_output["topk_history_data"] # Check whether rebalancing is needed if not self._check_rebalance_needed(average_utilization_rate_over_window): return expert_location_metadata = ExpertLocationMetadata.init_by_eplb( - self._server_args, self._model_runner.model_config, logical_count + self._server_args, self._model_runner.model_config, logical_count, topk_history_data ) update_layer_ids_chunks = self._compute_update_layer_ids_chunks() diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index 3faf981ef38..76b17bb6b4e 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -84,6 +84,9 @@ def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch): def on_select_experts(self, topk_ids: torch.Tensor): pass + def record_topk_ids(self, topk_ids: torch.Tensor): + pass + def 
on_deepep_dispatch_normal( self, local_physical_count_of_layer: List[int], @@ -192,6 +195,9 @@ def _on_forward_pass_end(self, forward_pass_id: int): def on_select_experts(self, topk_ids: torch.Tensor): self._on_hook("on_select_experts", topk_ids=topk_ids) + def record_topk_ids(self, topk_ids: torch.Tensor): + self._on_hook("record_topk_ids", topk_ids=topk_ids) + def on_deepep_dispatch_normal( self, local_physical_count_of_layer: List[int], @@ -285,6 +291,8 @@ def set_global_expert_distribution_recorder(value): class _SinglePassGatherer(ABC): + _TOP_K_NUM = 8 + @staticmethod def init_new( server_args: ServerArgs, @@ -326,6 +334,9 @@ def on_forward_pass_start(self, forward_batch: ForwardBatch): def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): pass + def record_topk_ids(self, layer_idx:int, topk_ids: torch.Tensor): + pass + def on_deepep_dispatch_normal( self, layer_idx: int, @@ -350,7 +361,7 @@ def collect(self) -> Dict: class _DetailSinglePassGatherer(_SinglePassGatherer): # DeepSeek V3 has this value; should generalize later - _TOP_K_NUM = 8 + # _TOP_K_NUM = 8 def __init__( self, @@ -365,7 +376,7 @@ def __init__( expert_location_metadata.num_layers, # TODO determine the max number server_args.chunked_prefill_size * 8, - self._TOP_K_NUM, + super()._TOP_K_NUM, ), dtype=torch.int32, device=server_args.device, @@ -472,9 +483,18 @@ def __init__(self, *args, enable_global_physical_experts: bool, **kwargs): dtype=torch.int, device=device, ) + self._topk_ids_data = torch.zeros( + ( + self._expert_location_metadata.num_layers, + super()._TOP_K_NUM, + ), + dtype=torch.int, + device=device, + ) def reset(self): self._data[...] = 0 + self._topk_ids_data[...] = -1 def collect(self) -> Dict: if self._enable_global_physical_experts: @@ -488,7 +508,7 @@ def collect(self) -> Dict: num_physical_experts=self._expert_location_metadata.num_physical_experts, ) - return dict(global_physical_count=global_physical_count) + return dict(global_physical_count=global_physical_count, topk_ids=self._topk_ids_data.clone().cpu()) class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer): @@ -547,6 +567,15 @@ def on_deepep_dispatch_low_latency( # Most naive implementation, can optimize later self._data[layer_idx, :] += local_physical_count_of_layer + def record_topk_ids(self, layer_idx: int, topk_ids: torch.Tensor): + topk_ids = topk_ids.flatten() + if (topk_ids.shape[0] == super()._TOP_K_NUM): + self._topk_ids_data[layer_idx] = topk_ids + else: + logger.info(f"Expected shape: {super()._TOP_K_NUM}, got {topk_ids.shape[0]}") + # TODO should record full topk result for batch inference + self._topk_ids_data[layer_idx] = topk_ids[:super()._TOP_K_NUM] + def _convert_local_to_global_physical_count( local_physical_count: torch.Tensor, @@ -622,6 +651,7 @@ def dump(self, output_mode: _OutputMode): class _UtilizationRateAccumulatorMixin(_Accumulator): + _TOP_K_NUM = 8 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -764,6 +794,15 @@ def __init__(self, *args, **kwargs): dtype=torch.int32, device=self._server_args.device, ) + self._topk_ids_of_buffered_step = _Buffer.init_new( + item_shape=( + self._expert_location_metadata.num_layers, + super()._TOP_K_NUM, + ), + buffer_size=self._server_args.expert_distribution_recorder_buffer_size, + dtype=torch.int32, + device=self._server_args.device, + ) self._first_dump = True def append( @@ -777,10 +816,14 @@ def append( self._global_physical_count_of_buffered_step.append( single_pass_data["global_physical_count"] ) + 
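# Shape walk-through of the top-k recording path (a sketch of the intended data flow,
# based on the buffers defined in this accumulator; sizes are symbolic, not new API):
#   per forward pass : _topk_ids_data             -> [num_layers, _TOP_K_NUM]
#                      (only the first _TOP_K_NUM routed ids of each pass are kept
#                       per layer, see the TODO in record_topk_ids above)
#   buffered window  : _topk_ids_of_buffered_step -> [buffer_size, num_layers, _TOP_K_NUM]
#   dump() output    : topk_ids                   -> the [num_samples, num_layers, top_k]
#                      history later consumed by generate_comm_matrix().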
self._topk_ids_of_buffered_step.append( + single_pass_data["topk_ids"] + ) def reset(self): super().reset() self._global_physical_count_of_buffered_step.reset() + self._topk_ids_of_buffered_step.reset() def dump(self, output_mode: _OutputMode): logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count( @@ -802,6 +845,7 @@ def dump(self, output_mode: _OutputMode): rank=self._rank, logical_count=logical_count_of_buffered_step, average_utilization_rate_over_window=self._get_global_average_utilization_rate(), + topk_ids=self._topk_ids_of_buffered_step.get_all().clone().cpu(), ) if output_mode == "file": diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index ee5f2c7ca8b..cab6bd3acf7 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -27,6 +27,7 @@ from sglang.srt.eplb import eplb_algorithms from sglang.srt.model_loader import get_model_architecture +from sglang.srt.eplb.utils.comm_matrix_process import generate_comm_matrix if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig @@ -138,7 +139,7 @@ def init_by_mapping( @staticmethod def init_by_eplb( - server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor + server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor, topk_history_data: torch.Tensor ): if not isinstance(logical_count, torch.Tensor): logical_count = torch.tensor(logical_count) @@ -146,6 +147,16 @@ def init_by_eplb( logical_count = logical_count.unsqueeze(0) logical_count = logical_count.to(server_args.device) + if topk_history_data is None or topk_history_data.numel() == 0: + logger.info("No topk_history_data provided, skipping communication matrix computation.") + comm_matrix = None + else: + if not isinstance(topk_history_data, torch.Tensor): + topk_history_data = torch.tensor(topk_history_data) + topk_history_data = topk_history_data.to(server_args.device) + comm_matrix = generate_comm_matrix(topk_history_data, num_experts=256) + + common = ExpertLocationMetadata._init_common(server_args, model_config) if common is None: @@ -168,6 +179,7 @@ def init_by_eplb( num_groups=num_groups, num_nodes=num_nodes, ), + comm_matrix=comm_matrix, ) ) @@ -456,12 +468,19 @@ def compute_initial_expert_location_metadata( server_args, model_config, **data_dict ) elif "logical_count" in data_dict: - logger.info( - "init_expert_location from init_by_eplb using ServerArgs.init_expert_location" - ) - return ExpertLocationMetadata.init_by_eplb( - server_args, model_config, logical_count=data_dict["logical_count"] - ) + if "topk_ids" in data_dict: + logger.info("init_expert_location with topk_history_data") + return ExpertLocationMetadata.init_by_eplb( + server_args, + model_config, + logical_count=data_dict["logical_count"], + topk_history_data=data_dict["topk_ids"], + ) + else: + logger.info("init_expert_location without topk_history_data") + return ExpertLocationMetadata.init_by_eplb( + server_args, model_config, logical_count=data_dict["logical_count"], topk_history_data=None + ) else: raise NotImplementedError( f"Unknown init_expert_location format ({list(data_dict.keys())=})" diff --git a/python/sglang/srt/eplb/utils/comm_matrix_process.py b/python/sglang/srt/eplb/utils/comm_matrix_process.py new file mode 100644 index 00000000000..72f38e15ef2 --- /dev/null +++ b/python/sglang/srt/eplb/utils/comm_matrix_process.py @@ -0,0 +1,41 @@ +import numpy as np +import torch +#import numba + +# Could enable numba for 
performance optimization if needed +# @numba.jit(nopython=True, parallel=True, cache=True) +def compute_expert_co_occurrence_matrix(history_data, num_experts): + """Compute expert co-occurrence matrix from history data.""" + history_data = history_data.cpu().numpy() + num_samples, num_layers, top_k = history_data.shape + expert_co_occurrence = np.zeros((num_layers, num_experts, num_experts), dtype=np.int64) + + for sample_idx in range(num_samples): + for layer_idx in range(num_layers): + experts = history_data[sample_idx, layer_idx] + if (-1 in experts) or (len(set(experts)) < top_k): + continue + for i in range(top_k): + for j in range(i+1, top_k): + expert_i = experts[i] + expert_j = experts[j] + + if expert_i < num_experts and expert_j < num_experts: + expert_co_occurrence[layer_idx, expert_i, expert_j] += 1 + expert_co_occurrence[layer_idx, expert_j, expert_i] += 1 + co_occurrence = torch.tensor(expert_co_occurrence, dtype=torch.int64) + + return co_occurrence + +def generate_comm_matrix(history_data, num_experts): + """ + Process input tensor to compute expert co-occurrence matrix and generate communication matrix + """ + + if history_data.numel() == 0: + return None + co_occurrence = compute_expert_co_occurrence_matrix(history_data, num_experts) + comm_matrix = co_occurrence.float() + comm_matrix = comm_matrix / comm_matrix.max() # Normalize to [0,1] + + return comm_matrix diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index 598f513316d..21d8d2197e4 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -510,6 +510,14 @@ def dispatch_a( ): buffer = self._get_buffer() topk_idx = topk_idx.to(torch.int64) + + if topk_idx.numel() > 0: + get_global_expert_distribution_recorder().record_topk_ids( + topk_idx + ) + else: + logger.warning("topk_idx is empty in DeepEP low latency dispatch.") + expected_m = ( hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + self.num_experts
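A minimal sanity check of the communication-matrix path added above (an illustrative
sketch; the tiny history tensor and the num_experts value are made up for the example,
only the import path comes from this patch):

    import torch
    from sglang.srt.eplb.utils.comm_matrix_process import generate_comm_matrix

    # Two recorded passes, one MoE layer, top_k = 8 routed experts per pass.
    history = torch.tensor([
        [[0, 1, 2, 3, 4, 5, 6, 7]],
        [[0, 1, 2, 3, 4, 5, 6, 9]],
    ])
    comm = generate_comm_matrix(history, num_experts=16)
    # comm has shape [1, 16, 16]; experts routed together in both passes
    # (e.g. 0 and 1) get the maximum normalized affinity 1.0, while pairs that
    # co-occur only once (e.g. 0 and 7, or 0 and 9) get 0.5.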