diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index 37ab95e969b..aa8f01d2073 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -708,7 +708,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
@@ -744,7 +744,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )
 
     def _initialize_forward_meta_xpu(self):
@@ -922,7 +927,9 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
             self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
         # Initialize forward meta data
-        self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+        self._initialize_forward_meta(
+            step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+        )
         self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
         # Padding inputs for cuda graph
@@ -947,9 +954,10 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
                 top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                 share_inputs=self.model_inputs,
             )
-
+            # Note(liuzichang):
+            # paddle.clone would raise error 700 in cudaGraph mode
             if self.num_model_steps > 1:
-                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+                self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
 
             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index ca3538ff264..9990241b4af 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2155,51 +2155,12 @@ def capture_model(self) -> None:
                     ),
                     batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
                     in_capturing=True,
-                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
                     accept_all_drafts=True,
                 )
                 logger.info(
-                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                    f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
                 )
-                if self.graph_opt_config.draft_model_use_cudagraph:
-                    # Capture Draft Model without bsz 1
-                    # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-                    for batch_size in sorted(capture_sizes, reverse=True):
-                        if batch_size == 1:
-                            logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                        else:
-                            assert batch_size % 2 == 0
-                            self._dummy_run(
-                                num_tokens=(
-                                    self.scheduler_config.max_num_seqs
-                                    if self.scheduler_config.splitwise_role == "decode"
-                                    else self.scheduler_config.max_num_batched_tokens
-                                ),
-                                batch_size=int(batch_size / 2),
-                                in_capturing=True,
-                                expected_decode_len=3,
-                                accept_all_drafts=True,
-                            )
-                            logger.info(
-                                f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                            )
-                    # Capture Draft Model with bsz 1
-                    if 1 in capture_sizes:
-                        self._dummy_run(
-                            num_tokens=(
-                                self.scheduler_config.max_num_seqs
-                                if self.scheduler_config.splitwise_role == "decode"
-                                else self.scheduler_config.max_num_batched_tokens
-                            ),
-                            batch_size=int(1),
-                            in_capturing=True,
-                            expected_decode_len=3,
-                            accept_all_drafts=False,
-                            reject_all_drafts=True,
-                        )
-                        logger.info(
-                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                        )
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(