Skip to content

Commit afde65c

Browse files
authored
[Serving] Introduce DraftTokenWorkspaceManager (mlc-ai#2250)
Using DraftTokenWorkspaceManager to maintain workspace for draft probs and hidden states (if needed). This allows states of the draft token to be kept fully on GPU.
1 parent 2489964 commit afde65c

26 files changed

+627
-231
lines changed
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*!
2+
* Copyright (c) 2024 by Contributors
3+
* \file serve/draft_token_workspace_manager.cc
4+
*/
5+
6+
#include "draft_token_workspace_manager.h"
7+
8+
#include "model.h"
9+
10+
namespace mlc {
11+
namespace llm {
12+
namespace serve {
13+
14+
DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,
15+
int hidden_size,
16+
DLDataType hidden_states_dtype,
17+
DLDevice device,
18+
const FunctionTable& ft)
19+
: max_num_tokens_(max_num_tokens),
20+
vocab_size_(vocab_size),
21+
hidden_size_(hidden_size),
22+
hidden_states_dtype_(hidden_states_dtype),
23+
device_(device),
24+
ft_(ft) {
25+
free_slots_.resize(max_num_tokens);
26+
std::iota(free_slots_.begin(), free_slots_.end(), 0);
27+
}
28+
29+
void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
30+
ICHECK_LE(num_slots, free_slots_.size());
31+
result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
32+
std::vector<int> allocated(free_slots_.begin(), free_slots_.begin() + num_slots);
33+
free_slots_.resize(free_slots_.size() - num_slots);
34+
}
35+
36+
void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
37+
std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_));
38+
}
39+
40+
void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
41+
bool require_hidden_states) {
42+
workspace->draft_probs =
43+
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
44+
workspace->draft_probs_storage =
45+
NDArray::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);
46+
if (require_hidden_states) {
47+
workspace->draft_hidden_states_storage =
48+
NDArray::Empty({max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_);
49+
}
50+
}
51+
52+
} // namespace serve
53+
} // namespace llm
54+
} // namespace mlc
+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*!
2+
* Copyright (c) 2024 by Contributors
3+
* \file serve/draft_token_workspace_manager.h
4+
*/
5+
6+
#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
7+
#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_
8+
#include <tvm/runtime/device_api.h>
9+
10+
#include <numeric>
11+
#include <optional>
12+
#include <vector>
13+
14+
#include "data.h"
15+
#include "function_table.h"
16+
namespace mlc {
17+
namespace llm {
18+
namespace serve {
19+
20+
using tvm::Device;
21+
using namespace tvm::runtime;
22+
23+
struct ModelWorkspace;
24+
25+
/*!
26+
* \brief Managing the workspace for draft token generation.
27+
*
28+
* The workspace is used to store the associated states for each draft token, including the
29+
* probability distribution of the draft token, the hidden states, etc. The workspace manager
30+
* maintains a pool of slots for the draft tokens to store the states.
31+
*/
32+
class DraftTokenWorkspaceManagerObj : public Object {
33+
public:
34+
/*!
35+
* \brief Constructor
36+
* \param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.
37+
* \param vocab_size The size of the vocabulary.
38+
* \param hidden_size The size of the hidden states.
39+
* \param hidden_states_dtype The data type of the hidden states.
40+
* \param device The device running the model.
41+
* \param ft The function table.
42+
*/
43+
DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,
44+
DLDataType hidden_states_dtype, DLDevice device,
45+
const FunctionTable& ft);
46+
47+
/*!
48+
* \brief Allocate the workspace for draft tokens and update `ModelWorkspace` data structure.
49+
* \param workspace The object to stored the allocated draft token workspace.
50+
* \param require_hidden_states Whether to allocate workspace for the hidden states.
51+
*/
52+
void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states);
53+
54+
/*!
55+
* \brief Allocate slots for the draft tokens.
56+
* \param num_slots The number of slots to allocate.
57+
* \param result The vector to store the allocated slots.
58+
*/
59+
void AllocSlots(int num_slots, std::vector<int>* result);
60+
61+
/*!
62+
* \brief Free the slots.
63+
* \param slots The slots to free.
64+
*/
65+
void FreeSlots(const std::vector<int>& slots);
66+
67+
static constexpr const char* _type_key = "mlc.serve.DraftTokenWorkspaceManager";
68+
69+
private:
70+
std::vector<int> free_slots_;
71+
int max_num_tokens_;
72+
int vocab_size_;
73+
int hidden_size_;
74+
DataType hidden_states_dtype_;
75+
DLDevice device_;
76+
const FunctionTable& ft_;
77+
};
78+
79+
class DraftTokenWorkspaceManager : public ObjectRef {
80+
public:
81+
DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size,
82+
DLDataType hidden_states_dtype, DLDevice device,
83+
const FunctionTable& ft) {
84+
data_ = make_object<DraftTokenWorkspaceManagerObj>(max_num_tokens, vocab_size, hidden_size,
85+
hidden_states_dtype, device, ft);
86+
}
87+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DraftTokenWorkspaceManager, ObjectRef,
88+
DraftTokenWorkspaceManagerObj);
89+
};
90+
91+
} // namespace serve
92+
} // namespace llm
93+
} // namespace mlc
94+
95+
#endif // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_

cpp/serve/engine.cc

+33-22
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,13 @@ class EngineImpl : public Engine {
101101
}
102102

103103
int max_num_tokens = engine_config->max_num_sequence;
104+
DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};
104105
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
105106
max_num_tokens *= engine_config->spec_draft_length + 1;
107+
draft_token_workspace_manager = models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens);
108+
draft_token_workspace_manager->AllocWorkspace(
109+
&model_workspaces_[0],
110+
/*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle);
106111
}
107112
LogitProcessor logit_processor =
108113
this->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);
@@ -114,30 +119,36 @@ class EngineImpl : public Engine {
114119
ICHECK_GT(this->models_.size(), 1U);
115120
switch (engine_config->speculative_mode) {
116121
case SpeculativeMode::kEagle:
117-
this->actions_ = {EngineAction::EagleNewRequestPrefill(this->models_, //
118-
logit_processor, //
119-
sampler, //
120-
this->model_workspaces_, //
121-
engine_config, //
122-
this->trace_recorder_),
123-
EngineAction::EagleBatchDraft(
124-
this->models_, logit_processor, sampler, this->model_workspaces_,
125-
this->trace_recorder_, engine_config->spec_draft_length),
126-
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
127-
this->model_workspaces_, engine_config,
128-
this->trace_recorder_)};
122+
this->actions_ = {
123+
EngineAction::EagleNewRequestPrefill(this->models_, //
124+
logit_processor, //
125+
sampler, //
126+
this->model_workspaces_, //
127+
draft_token_workspace_manager, //
128+
engine_config, //
129+
this->trace_recorder_),
130+
EngineAction::EagleBatchDraft(this->models_, logit_processor, sampler,
131+
this->model_workspaces_, draft_token_workspace_manager,
132+
this->trace_recorder_,
133+
engine_config->spec_draft_length),
134+
EngineAction::EagleBatchVerify(this->models_, logit_processor, sampler,
135+
this->model_workspaces_, draft_token_workspace_manager,
136+
engine_config, this->trace_recorder_)};
129137
break;
130138
default:
131-
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //
132-
logit_processor, //
133-
sampler, //
134-
this->model_workspaces_, //
135-
engine_config, //
136-
this->trace_recorder_),
137-
EngineAction::BatchDraft(this->models_, logit_processor, sampler,
138-
this->trace_recorder_),
139-
EngineAction::BatchVerify(this->models_, logit_processor, sampler,
140-
engine_config, this->trace_recorder_)};
139+
this->actions_ = {
140+
EngineAction::NewRequestPrefill(this->models_, //
141+
logit_processor, //
142+
sampler, //
143+
this->model_workspaces_, //
144+
engine_config, //
145+
this->trace_recorder_),
146+
EngineAction::BatchDraft(this->models_, logit_processor, sampler,
147+
this->model_workspaces_, draft_token_workspace_manager,
148+
this->trace_recorder_),
149+
EngineAction::BatchVerify(this->models_, logit_processor, sampler,
150+
this->model_workspaces_, draft_token_workspace_manager,
151+
engine_config, this->trace_recorder_)};
141152
}
142153
} else {
143154
this->actions_ = {EngineAction::NewRequestPrefill(this->models_, //

cpp/serve/engine_actions/action.h

+21-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_
99

1010
#include "../config.h"
11+
#include "../draft_token_workspace_manager.h"
1112
#include "../engine_state.h"
1213
#include "../event_trace_recorder.h"
1314
#include "../model.h"
@@ -72,15 +73,16 @@ class EngineAction : public ObjectRef {
7273
* \param logit_processor The logit processor.
7374
* \param sampler The sampler to sample new tokens.
7475
* \param model_workspaces The workspace of each model.
76+
* \param draft_token_workspace_manager The draft token workspace manager.
7577
* \param engine_config The engine config.
7678
* \param trace_recorder The event trace recorder for requests.
7779
* \return The created action object.
7880
*/
79-
static EngineAction EagleNewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,
80-
Sampler sampler,
81-
std::vector<ModelWorkspace> model_workspaces,
82-
EngineConfig engine_config,
83-
Optional<EventTraceRecorder> trace_recorder);
81+
static EngineAction EagleNewRequestPrefill(
82+
Array<Model> models, LogitProcessor logit_processor, Sampler sampler,
83+
std::vector<ModelWorkspace> model_workspaces,
84+
DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,
85+
Optional<EventTraceRecorder> trace_recorder);
8486
/*!
8587
* \brief Create the action that runs one-step decode for requests in the
8688
* `running_queue` of engine state. Preempt low-priority requests
@@ -104,13 +106,16 @@ class EngineAction : public ObjectRef {
104106
* \param models The model to run decode in. When there are multiple
105107
* models, the `Step` function of the created action will not take effect.
106108
* \param sampler The sampler to sample new tokens.
109+
* \param model_workspaces The workspace of each model.
110+
* \param draft_token_workspace_manager The draft token workspace manager.
107111
* \param trace_recorder The event trace recorder for requests.
108112
* \param draft_length The number of draft proposal rounds.
109113
* \return The created action object.
110114
*/
111115
static EngineAction BatchDraft(Array<Model> models, LogitProcessor logit_processor,
112-
Sampler sampler, Optional<EventTraceRecorder> trace_recorder,
113-
int draft_length = 4);
116+
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
117+
DraftTokenWorkspaceManager draft_token_workspace_manager,
118+
Optional<EventTraceRecorder> trace_recorder, int draft_length = 4);
114119

115120
/*!
116121
* \brief Create the action that runs one-step speculative draft proposal for
@@ -120,12 +125,14 @@ class EngineAction : public ObjectRef {
120125
* models, the `Step` function of the created action will not take effect.
121126
* \param sampler The sampler to sample new tokens.
122127
* \param model_workspaces The workspace of each model.
128+
* \param draft_token_workspace_manager The draft token workspace manager.
123129
* \param trace_recorder The event trace recorder for requests.
124130
* \param draft_length The number of draft proposal rounds.
125131
* \return The created action object.
126132
*/
127133
static EngineAction EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
128134
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
135+
DraftTokenWorkspaceManager draft_token_workspace_manager,
129136
Optional<EventTraceRecorder> trace_recorder,
130137
int draft_length = 4);
131138

@@ -135,13 +142,17 @@ class EngineAction : public ObjectRef {
135142
* accordingly when it is impossible to decode all the running requests.
136143
* \param models The model to run decode in. When there are multiple
137144
* models, the `Step` function of the created action will not take effect.
145+
* \param model_workspaces The workspace of each model.
146+
* \param draft_token_workspace_manager The draft token workspace manager.
138147
* \param sampler The sampler to sample new tokens.
139148
* \param engine_config The engine config.
140149
* \param trace_recorder The event trace recorder for requests.
141150
* \return The created action object.
142151
*/
143152
static EngineAction BatchVerify(Array<Model> models, LogitProcessor logit_processor,
144-
Sampler sampler, EngineConfig engine_config,
153+
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
154+
DraftTokenWorkspaceManager draft_token_workspace_manager,
155+
EngineConfig engine_config,
145156
Optional<EventTraceRecorder> trace_recorder);
146157

147158
/*!
@@ -152,13 +163,15 @@ class EngineAction : public ObjectRef {
152163
* models, the `Step` function of the created action will not take effect.
153164
* \param sampler The sampler to sample new tokens.
154165
* \param model_workspaces The workspace of each model.
166+
* \param draft_token_workspace_manager The draft token workspace manager.
155167
* \param engine_config The engine config.
156168
* \param trace_recorder The event trace recorder for requests.
157169
* \return The created action object.
158170
*/
159171
static EngineAction EagleBatchVerify(Array<Model> models, LogitProcessor logit_processor,
160172
Sampler sampler,
161173
std::vector<ModelWorkspace> model_workspaces,
174+
DraftTokenWorkspaceManager draft_token_workspace_manager,
162175
EngineConfig engine_config,
163176
Optional<EventTraceRecorder> trace_recorder);
164177

cpp/serve/engine_actions/action_commons.cc

+9-4
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,10 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
142142
std::move(models), max_single_sequence_length);
143143
}
144144

145-
RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
146-
const Array<Model>& models,
147-
Optional<EventTraceRecorder> trace_recorder) {
145+
RequestStateEntry PreemptLastRunningRequestStateEntry(
146+
EngineState estate, const Array<Model>& models,
147+
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
148+
Optional<EventTraceRecorder> trace_recorder) {
148149
ICHECK(!estate->running_queue.empty());
149150
Request request = estate->running_queue.back();
150151

@@ -168,8 +169,12 @@ RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
168169
// - Update `inputs` for future prefill.
169170
RECORD_EVENT(trace_recorder, rsentry->request->id, "preempt");
170171
rsentry->status = RequestStateStatus::kPending;
172+
std::vector<int> draft_token_slots;
171173
for (RequestModelState mstate : rsentry->mstates) {
172-
mstate->RemoveAllDraftTokens();
174+
if (draft_token_workspace_manager.defined()) {
175+
mstate->RemoveAllDraftTokens(&draft_token_slots);
176+
draft_token_workspace_manager.value()->FreeSlots(draft_token_slots);
177+
}
173178
std::vector<int32_t> committed_token_ids;
174179
committed_token_ids.reserve(mstate->committed_tokens.size());
175180
for (const SampleResult& committed_token : mstate->committed_tokens) {

cpp/serve/engine_actions/action_commons.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_
88

99
#include "../../tokenizers.h"
10+
#include "../draft_token_workspace_manager.h"
1011
#include "../engine.h"
1112
#include "../engine_state.h"
1213
#include "../event_trace_recorder.h"
@@ -52,12 +53,14 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
5253
* If it is not in the waiting request queue, add it to the waiting queue.
5354
* \param estate The engine state to update due to preemption.
5455
* \param models The models to remove preempted requests from.
55-
* \param trace_recorder The event trace recorder for requests.
56-
* \return The preempted request state.
56+
* \param draft_token_workspace_manager The draft token workspace manager for requests. Must be
57+
* provided if speculative decoding is enabled. \param trace_recorder The event trace recorder for
58+
* requests. \return The preempted request state.
5759
*/
58-
RequestStateEntry PreemptLastRunningRequestStateEntry(EngineState estate,
59-
const Array<Model>& models,
60-
Optional<EventTraceRecorder> trace_recorder);
60+
RequestStateEntry PreemptLastRunningRequestStateEntry(
61+
EngineState estate, const Array<Model>& models,
62+
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
63+
Optional<EventTraceRecorder> trace_recorder);
6164

6265
/*! \brief Get the running request entries from the engine state. */
6366
inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const EngineState& estate) {

cpp/serve/engine_actions/batch_decode.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class BatchDecodeActionObj : public EngineActionObj {
4848
running_rsentries = GetRunningRequestStateEntries(estate);
4949
while (!CanDecode(running_rsentries.size())) {
5050
RequestStateEntry preempted =
51-
PreemptLastRunningRequestStateEntry(estate, models_, trace_recorder_);
51+
PreemptLastRunningRequestStateEntry(estate, models_, NullOpt, trace_recorder_);
5252
if (preempted.same_as(running_rsentries.back())) {
5353
running_rsentries.pop_back();
5454
}

0 commit comments

Comments
 (0)