[Speculative Decoding] Support multi-step mtp with cudagraph #5886
Conversation
Thanks for your contribution!
Pull request overview
This pull request adds CUDA graph support for multi-step MTP (Multi-Token Prediction) in speculative decoding. The changes ensure correct CUDA graph capture behavior when executing multiple model steps during draft token generation.
Key Changes
- Modified `_initialize_forward_meta` to accept new parameters (`is_dummy_run`, `substep`) for controlling CUDA graph capture during multi-step execution
- Updated the CUDA graph logic to capture only the first substep during dummy runs (graph warmup phase)
- Replaced `paddle.clone()` with `copy_()` for `last_seq_lens_this_time` to avoid error 700 in CUDA graph mode
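The capture-gating rule from the second bullet can be sketched as a small predicate. This is an illustrative sketch, not FastDeploy's actual code; only the `is_dummy_run` and `substep` names come from the PR, the helper name is hypothetical:

```python
# Hypothetical helper illustrating the rule "only capture the first
# substep during dummy (warmup) runs"; real runs replay the captured
# graph rather than re-capturing it.
def should_capture(is_dummy_run: bool, substep: int) -> bool:
    # Capture happens only during the warmup pass, and only once,
    # on the first substep of the multi-step MTP loop.
    return is_dummy_run and substep == 0
```

Under this rule, substeps 1..N-1 of the warmup run execute eagerly, so the graph is captured exactly once.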
```diff
  # paddle.clone would raise error 700 in cudaGraph mode
  if self.num_model_steps > 1:
-     self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+     self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
```
Is this a newly added model input?
No, it is not new; it has always existed in multi-step mode. Using `clone` here would oddly couple it with cudagraph.
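The reason for the `clone` → `copy_` swap can be illustrated with a plain-Python analogy (not Paddle code): CUDA graph replay requires stable memory addresses, so `paddle.clone()`, which allocates a fresh tensor on every call, can fail inside a captured region (error 700, an illegal memory access), while `copy_()` writes into a preallocated buffer whose address never changes. The `Buffer` class below is purely illustrative:

```python
# Plain-Python sketch of out-of-place clone vs. in-place copy_.
class Buffer:
    def __init__(self, data):
        self.data = list(data)

    def clone(self):
        # out-of-place: returns a brand-new object (a new "address")
        return Buffer(self.data)

    def copy_(self, other):
        # in-place: overwrites contents; same object, same address
        self.data[:] = other.data
        return self

src = Buffer([3, 1, 4])
dst = Buffer([0, 0, 0])
addr_before = id(dst)
dst.copy_(src)
assert id(dst) == addr_before      # address stays stable across copy_
assert id(src.clone()) != id(src)  # clone allocates a new object
```

This is why preallocating `last_seq_lens_this_time` once and refreshing it in place is safe to capture, while re-cloning it each step is not.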
gongshaotian left a comment:
LGTM
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@ Coverage Diff @@
##        develop    #5886   +/-   ##
==========================================
  Coverage        ?   65.96%
==========================================
  Files           ?      347
  Lines           ?    44410
  Branches        ?     6818
==========================================
  Hits            ?    29294
  Misses          ?    12956
  Partials        ?     2160
==========================================
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run `pre-commit` before commit.
- For the `release` branch, make sure the PR has been submitted to the `develop` branch first, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.