
Commit 759e737

Squashed commit of lwilkison/dbo-plus-plus changes relative to nm/sage/dbo-full-cudagraphs
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 2def98d commit 759e737

17 files changed: +845 -299 lines changed

docs/design/fused_moe_modular_kernel.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
 The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
 The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, the `prepare_no_receive` function is like `prepare` except that it does not wait to receive results from other workers. Instead, it returns a "receiver" callback that must be invoked to wait for the final results from the other workers. Not all `FusedMoEPrepareAndFinalize` classes are required to support this method, but when it is available it can be used to interleave work with the initial all-to-all communication, e.g. interleaving shared experts with the fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally, the `finalize` function may or may not do the TopK weight application and reduction (please refer to the TopKWeightAndReduce section).
 
+
 ![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks")
 
 ### FusedMoEPermuteExpertsUnpermute
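
To make the prepare_no_receive behaviour described above concrete, here is a minimal, self-contained sketch of the pattern, using hypothetical names (ToyPrepareAndFinalize and plain Python lists) rather than vLLM's real classes: the call kicks off the dispatch and hands back a receiver callback, so independent work such as the shared experts can run before the callback is invoked to collect the dispatched tokens.

# Illustrative sketch only; not vLLM's actual FusedMoEPrepareAndFinalize API.
from typing import Callable


class ToyPrepareAndFinalize:

    def prepare_no_receive(self, tokens: list[float]) -> Callable[[], list[float]]:
        # "Send" side of the dispatch: start the all-to-all without waiting.
        in_flight = [t * 2.0 for t in tokens]  # stand-in for quantize + dispatch

        def receiver() -> list[float]:
            # "Receive" side: invoked later to collect the dispatched tokens.
            return in_flight

        return receiver

    def finalize(self, expert_out: list[float]) -> list[float]:
        # Stand-in for the all-to-all combine (and optional weight application).
        return [t / 2.0 for t in expert_out]


pf = ToyPrepareAndFinalize()
tokens = [1.0, 2.0, 3.0]

receiver = pf.prepare_no_receive(tokens)     # start dispatch, do not block
shared_out = [t + 1.0 for t in tokens]       # overlap independent (shared expert) work
dispatched = receiver()                      # now wait for the dispatch results
fused_out = [t * 10.0 for t in dispatched]   # stand-in for the fused experts
print(shared_out, pf.finalize(fused_out))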

examples/offline_inference/data_parallel.py

Lines changed: 7 additions & 1 deletion
@@ -90,7 +90,13 @@ def parse_args():
     parser.add_argument(
         "--enable-microbatching",
         action="store_true",
-        help=("Enable microbatched execution"),
+        help=("Enable microbatched execution")
+    )
+    parser.add_argument(
+        "--compilation-config",
+        type=int,
+        default=0,
+        help=("Compilation optimization (O) level 0-3."),
     )
     parser.add_argument(
         "--compilation-config",

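For context, the new --compilation-config flag carries an integer optimization (O) level. A hedged sketch of how such a value is typically forwarded in an offline-inference script follows; the model name is a placeholder and the exact wiring in data_parallel.py may differ.

# Illustrative sketch only: forwarding the O level into the LLM constructor.
import argparse

from vllm import LLM

parser = argparse.ArgumentParser()
parser.add_argument("--compilation-config", type=int, default=0,
                    help="Compilation optimization (O) level 0-3.")
args = parser.parse_args()

llm = LLM(
    model="facebook/opt-125m",                    # placeholder model
    compilation_config=args.compilation_config,   # O level from the new flag
)
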
vllm/config/parallel.py

Lines changed: 7 additions & 0 deletions
@@ -135,6 +135,13 @@ class ParallelConfig:
     request is greater than this threshold, microbatching will be used.
     Otherwise, the request will be processed in a single batch."""
 
+    microbatch_schedule: Literal["mlp_overlap", "ATTN_SHARED_OVERLAP"] = "mlp_overlap"
+    """Schedule policy for microbatch overlap coordination.
+
+    - "mlp_overlap": overlap MLP compute and communication across ubatches
+    - "ATTN_SHARED_OVERLAP": overlap MLA attention and communication across ubatches
+    """
+
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
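A minimal usage sketch for the new field, assuming ParallelConfig is constructed directly on this branch: the field name and its two allowed values come from the diff above, while the other kwargs and the branch on the value are illustrative.

# Illustrative sketch only; assumes this branch's ParallelConfig fields.
from vllm.config import ParallelConfig

parallel_config = ParallelConfig(
    enable_expert_parallel=True,
    enable_microbatching=True,          # existing field used elsewhere in this commit
    microbatch_schedule="mlp_overlap",  # or "ATTN_SHARED_OVERLAP"
)

if parallel_config.microbatch_schedule == "mlp_overlap":
    ...  # overlap MLP compute with communication across ubatches
else:
    ...  # overlap MLA attention with communication across ubatches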

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
@@ -317,6 +317,7 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     enable_microbatching: bool = ParallelConfig.enable_microbatching
     microbatching_token_threshold: int = ParallelConfig.microbatching_token_threshold
+    microbatch_schedule: str = ParallelConfig.microbatch_schedule
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     num_redundant_experts: int = EPLBConfig.num_redundant_experts
@@ -682,6 +683,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     **parallel_kwargs["enable_microbatching"])
         parallel_group.add_argument("--microbatching-token-threshold",
                                     **parallel_kwargs["microbatching_token_threshold"])
+        parallel_group.add_argument(
+            "--microbatch-schedule",
+            dest="microbatch_schedule",
+            **parallel_kwargs["microbatch_schedule"])
+        parallel_group.add_argument("--enable-async-comms",
+                                    **parallel_kwargs["enable_async_comms"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1304,6 +1311,7 @@ def create_engine_config(
             enable_expert_parallel=self.enable_expert_parallel,
             enable_microbatching=self.enable_microbatching,
             microbatching_token_threshold=self.microbatching_token_threshold,
+            microbatch_schedule=self.microbatch_schedule,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
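
End to end, this plumbing lets the schedule be selected with --microbatch-schedule on the CLI or via EngineArgs, landing on ParallelConfig through create_engine_config. A hedged sketch follows; only the microbatching fields come from this commit and the model name is a placeholder.

# Illustrative sketch only: driving the new fields through EngineArgs.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",          # placeholder
    enable_expert_parallel=True,
    enable_microbatching=True,
    microbatch_schedule="mlp_overlap",  # new field added in this commit
)
vllm_config = engine_args.create_engine_config()
print(vllm_config.parallel_config.microbatch_schedule)  # -> "mlp_overlap"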

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 86 additions & 72 deletions
@@ -58,31 +58,63 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]:
             return None
         return deep_ep.Buffer.get_combine_config(self.dp_size)
 
-    def _do_dispatch(
+    def _create_prepare_ops(
         self,
-        tokens: torch.Tensor,
-        token_scales: Optional[torch.Tensor],
-        rank_topk_ids: torch.Tensor,
-        rank_topk_weights: torch.Tensor,
-        num_experts: int,
+        a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> Callable:
+    ) -> mk.PrepareResultType:
+
+        # Apply router weights on input if requested (only supports topk=1)
+        if apply_router_weight_on_input:
+            topk = topk_ids.size(1)
+            assert topk == 1, (
+                "apply_router_weight_on_input is only implemented for topk=1")
+            a1 = a1 * topk_weights.to(a1.dtype)
+
+        # Quantize prior to dispatch for block-quantized path, otherwise defer
+        if quant_config.is_block_quantized:
+            a1q, a1q_scale = moe_kernel_quantize_input(
+                a1,
+                a1_scale,
+                quant_dtype=quant_config.quant_dtype,
+                per_act_token_quant=quant_config.per_act_token_quant,
+                block_shape=quant_config.block_shape,
+            )
+            if a1q_scale is not None and a1q_scale.numel() == 1:
+                a1q_scale = a1q_scale.view(1, 1)
+            a1_post_scale = None
+        else:
+            a1q = a1
+            a1q_scale = None
+            a1_post_scale = a1_scale
 
-        has_scales = token_scales is not None
+        # Inline dispatch (sync send+recv)
+        has_scales = a1q_scale is not None
 
         (num_tokens_per_rank, num_tokens_per_rdma_rank,
          dispatch_expert_num_tokens, is_token_in_rank,
          event) = self.buffer.get_dispatch_layout(
-             topk_idx=rank_topk_ids,
+             topk_idx=topk_ids,
              num_experts=num_experts,
              previous_event=None,
              async_finish=False,
             allocate_on_comm_stream=False)
 
-        token_data = tokens
+        token_data: Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]
+        token_data = a1q
         if has_scales:
-            token_data = (tokens, token_scales)
+            token_data = (a1q, a1q_scale)
+
+        ########################################################################
+        yield  # Pre-dispatch done
+        ########################################################################
 
         (
             token_data, expert_topk_ids, expert_topk_weights,
@@ -94,10 +126,8 @@ def _do_dispatch(
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=dispatch_expert_num_tokens,
-            topk_idx=rank_topk_ids,
-            topk_weights=rank_topk_weights,
-            # expert_alignment rounds the number of tokens per expert
-            # to this value.
+            topk_idx=topk_ids,
+            topk_weights=topk_weights,
             expert_alignment=1,
             config=self._get_dispatch_config(),
             previous_event=None,
@@ -131,9 +161,12 @@ def _receiver(
         if self.async_prepare:
             event.current_stream_wait()
 
+        # Unpack token data
         if has_scales:
+            assert isinstance(token_data, tuple)
             expert_x, expert_x_scale = token_data
         else:
+            assert isinstance(token_data, torch.Tensor)
             expert_x, expert_x_scale = token_data, None
 
         # The existing MOE kernels assume that all entries of topk_ids are
@@ -174,58 +207,14 @@ def _receiver(
             per_act_token_quant=False,
             block_shape=quant_config.block_shape)
 
+        ########################################################################
+        yield  # Dispatch send+recv done (sync)
+        ########################################################################
+
         return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)
 
-    def supports_async(self) -> bool:
-        return True
-
-    def prepare_async(
-        self,
-        a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool,
-        quant_config: FusedMoEQuantConfig,
-    ) -> Callable:
-
-        if apply_router_weight_on_input:
-            topk = topk_ids.size(1)
-            # TODO: this only works for topK=1, will need to update for topK>1
-            assert topk == 1, (
-                "apply_router_weight_on_input is only implemented for topk=1")
-            a1 = a1 * topk_weights.to(a1.dtype)
-
-        if quant_config.is_block_quantized:
-            # Quant and Dispatch
-            a1q, a1q_scale = moe_kernel_quantize_input(
-                a1,
-                a1_scale,
-                quant_dtype=quant_config.quant_dtype,
-                per_act_token_quant=quant_config.per_act_token_quant,
-                block_shape=quant_config.block_shape,
-            )
-            if a1q_scale is not None and a1q_scale.numel() == 1:
-                a1q_scale = a1q_scale.view(1, 1)
-            a1_post_scale = None
-        else:
-            a1q = a1
-            a1q_scale = None
-            a1_post_scale = a1_scale
-
-        return self._do_dispatch(tokens=a1q,
-                                 token_scales=a1q_scale,
-                                 rank_topk_ids=topk_ids,
-                                 rank_topk_weights=topk_weights,
-                                 num_experts=num_experts,
-                                 a1_scale=a1_post_scale,
-                                 quant_config=quant_config)
-
-    def prepare(
+    def create_prepare_ops(
         self,
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
@@ -236,14 +225,14 @@ def prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> mk.PrepareResultType:
-        receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
-                                      topk_ids, num_experts, expert_map,
-                                      apply_router_weight_on_input,
-                                      quant_config)
-        return receiver()
-
-    def finalize(
+    ) -> mk.SyncPrepareOps:
+        return mk.SyncPrepareOps.from_generator(
+            self._create_prepare_ops(a1, a1_scale, a2_scale, topk_weights,
+                                     topk_ids, num_experts, expert_map,
+                                     apply_router_weight_on_input,
+                                     quant_config))
+
+    def _create_finalize_ops(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
@@ -268,6 +257,10 @@ def finalize(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
+        ########################################################################
+        yield  # Pre-combine done
+        ########################################################################
+
         combined_x, _, event = self.buffer.combine(
             x=fused_expert_output,
             handle=self.handle,
@@ -278,3 +271,24 @@
             allocate_on_comm_stream=False)
         # Respect inplace outputs.
         output.copy_(combined_x, non_blocking=True)
+
+        ########################################################################
+        yield  # Combine send-recv done
+        ########################################################################
+
+        return None
+
+    def create_finalize_ops(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> mk.SyncFinalizeOps:
+        return mk.SyncFinalizeOps.from_generator(
+            self._create_finalize_ops(output, fused_expert_output,
+                                      topk_weights, topk_ids,
+                                      apply_router_weight_on_input,
+                                      weight_and_reduce_impl))
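
The shape of this refactor: the former prepare/prepare_async pair becomes a single generator that yields at each phase boundary (pre-dispatch, dispatch send+recv, pre-combine, combine send+recv), and mk.SyncPrepareOps.from_generator / mk.SyncFinalizeOps.from_generator wrap that generator into an object the microbatch scheduler can advance one phase at a time. The wrapper implementation is not part of this diff, so the following is a hedged, self-contained sketch of the pattern with hypothetical names (StagedOps, step); it only illustrates how yield boundaries let two ubatches be interleaved.

# Hypothetical sketch of the "ops from a generator" pattern; the real
# SyncPrepareOps/SyncFinalizeOps in modular_kernel may differ.
from dataclasses import dataclass
from typing import Any, Generator


@dataclass
class StagedOps:
    """Steps a generator one phase at a time; the generator's final
    `return` value becomes the overall result."""

    _gen: Generator[None, None, Any]
    result: Any = None
    finished: bool = False

    @classmethod
    def from_generator(cls, gen: Generator[None, None, Any]) -> "StagedOps":
        return cls(_gen=gen)

    def step(self) -> None:
        # Advance to the next `yield` (phase boundary) or capture the return.
        try:
            next(self._gen)
        except StopIteration as stop:
            self.result = stop.value
            self.finished = True


def _toy_prepare_ops(name: str) -> Generator[None, None, str]:
    # phase 0: quantize inputs and build the dispatch layout
    yield  # pre-dispatch done
    # phase 1: dispatch send + recv, unpack the received tokens
    yield  # dispatch done
    return f"{name}: prepared"


# A scheduler can interleave two microbatches by alternating step() calls:
ubatch_a = StagedOps.from_generator(_toy_prepare_ops("ubatch_a"))
ubatch_b = StagedOps.from_generator(_toy_prepare_ops("ubatch_b"))
while not (ubatch_a.finished and ubatch_b.finished):
    if not ubatch_a.finished:
        ubatch_a.step()
    if not ubatch_b.finished:
        ubatch_b.step()
print(ubatch_a.result, ubatch_b.result)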
