8
8
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_
9
9
10
10
#include "../config.h"
11
+ #include "../draft_token_workspace_manager.h"
11
12
#include "../engine_state.h"
12
13
#include "../event_trace_recorder.h"
13
14
#include "../model.h"
@@ -72,15 +73,16 @@ class EngineAction : public ObjectRef {
72
73
* \param logit_processor The logit processor.
73
74
* \param sampler The sampler to sample new tokens.
74
75
* \param model_workspaces The workspace of each model.
76
+ * \param draft_token_workspace_manager The draft token workspace manager.
75
77
* \param engine_config The engine config.
76
78
* \param trace_recorder The event trace recorder for requests.
77
79
* \return The created action object.
78
80
*/
79
- static EngineAction EagleNewRequestPrefill (Array<Model> models, LogitProcessor logit_processor,
80
- Sampler sampler,
81
- std::vector<ModelWorkspace> model_workspaces,
82
- EngineConfig engine_config,
83
- Optional<EventTraceRecorder> trace_recorder);
81
+ static EngineAction EagleNewRequestPrefill (
82
+ Array<Model> models, LogitProcessor logit_processor, Sampler sampler,
83
+ std::vector<ModelWorkspace> model_workspaces,
84
+ DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,
85
+ Optional<EventTraceRecorder> trace_recorder);
84
86
/*!
85
87
* \brief Create the action that runs one-step decode for requests in the
86
88
* `running_queue` of engine state. Preempt low-priority requests
@@ -104,13 +106,16 @@ class EngineAction : public ObjectRef {
104
106
* \param models The model to run decode in. When there are multiple
105
107
* models, the `Step` function of the created action will not take effect.
106
108
* \param sampler The sampler to sample new tokens.
109
+ * \param model_workspaces The workspace of each model.
110
+ * \param draft_token_workspace_manager The draft token workspace manager.
107
111
* \param trace_recorder The event trace recorder for requests.
108
112
* \param draft_length The number of draft proposal rounds.
109
113
* \return The created action object.
110
114
*/
111
115
static EngineAction BatchDraft (Array<Model> models, LogitProcessor logit_processor,
112
- Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
113
- int draft_length = 4 );
116
+ Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
117
+ DraftTokenWorkspaceManager draft_token_workspace_manager,
118
+ Optional<EventTraceRecorder> trace_recorder, int draft_length = 4 );
114
119
115
120
/*!
116
121
* \brief Create the action that runs one-step speculative draft proposal for
@@ -120,12 +125,14 @@ class EngineAction : public ObjectRef {
120
125
* models, the `Step` function of the created action will not take effect.
121
126
* \param sampler The sampler to sample new tokens.
122
127
* \param model_workspaces The workspace of each model.
128
+ * \param draft_token_workspace_manager The draft token workspace manager.
123
129
* \param trace_recorder The event trace recorder for requests.
124
130
* \param draft_length The number of draft proposal rounds.
125
131
* \return The created action object.
126
132
*/
127
133
static EngineAction EagleBatchDraft (Array<Model> models, LogitProcessor logit_processor,
128
134
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
135
+ DraftTokenWorkspaceManager draft_token_workspace_manager,
129
136
Optional<EventTraceRecorder> trace_recorder,
130
137
int draft_length = 4 );
131
138
@@ -135,13 +142,17 @@ class EngineAction : public ObjectRef {
135
142
* accordingly when it is impossible to decode all the running requests.
136
143
* \param models The model to run decode in. When there are multiple
137
144
* models, the `Step` function of the created action will not take effect.
145
+ * \param model_workspaces The workspace of each model.
146
+ * \param draft_token_workspace_manager The draft token workspace manager.
138
147
* \param sampler The sampler to sample new tokens.
139
148
* \param engine_config The engine config.
140
149
* \param trace_recorder The event trace recorder for requests.
141
150
* \return The created action object.
142
151
*/
143
152
static EngineAction BatchVerify (Array<Model> models, LogitProcessor logit_processor,
144
- Sampler sampler, EngineConfig engine_config,
153
+ Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
154
+ DraftTokenWorkspaceManager draft_token_workspace_manager,
155
+ EngineConfig engine_config,
145
156
Optional<EventTraceRecorder> trace_recorder);
146
157
147
158
/*!
@@ -152,13 +163,15 @@ class EngineAction : public ObjectRef {
152
163
* models, the `Step` function of the created action will not take effect.
153
164
* \param sampler The sampler to sample new tokens.
154
165
* \param model_workspaces The workspace of each model.
166
+ * \param draft_token_workspace_manager The draft token workspace manager.
155
167
* \param engine_config The engine config.
156
168
* \param trace_recorder The event trace recorder for requests.
157
169
* \return The created action object.
158
170
*/
159
171
static EngineAction EagleBatchVerify (Array<Model> models, LogitProcessor logit_processor,
160
172
Sampler sampler,
161
173
std::vector<ModelWorkspace> model_workspaces,
174
+ DraftTokenWorkspaceManager draft_token_workspace_manager,
162
175
EngineConfig engine_config,
163
176
Optional<EventTraceRecorder> trace_recorder);
164
177
0 commit comments