#include "continuous_batching_impl.hpp"
#include "utils.hpp"
#include "utils/paged_attention_transformations.hpp"
+#include "lora_helper.hpp"

namespace ov::genai {
template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
@@ -17,8 +18,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
    const std::string& device,
    const ov::AnyMap& properties,
    const ov::genai::GenerationConfig& generation_config,
-    bool is_validation_mode_enabled
-    ) {
+    bool is_validation_mode_enabled) {
    m_tokenizer = tokenizer;
    m_generation_config = generation_config;
    m_is_validation_mode_enabled = is_validation_mode_enabled;
@@ -33,22 +33,33 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
    bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
    utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);

-    init(model, scheduler_config, compile_properties, device_config, core);
+    initialize_pipeline(model, scheduler_config, compile_properties, device_config, core);
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests() {
    std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
    m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
    m_awaiting_requests.clear();
+    m_pipeline_metrics.requests = m_requests.size();
}

-void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
+void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
    std::shared_ptr<ov::Model> model,
    const SchedulerConfig& scheduler_config,
    const ov::AnyMap& properties,
    const DeviceConfig& device_config,
    ov::Core& core) {
-    auto compiled_model = core.compile_model(model, device_config.get_device(), properties);
+    ov::CompiledModel compiled_model;
+
+    // apply LoRA
+    if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
+        m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
+        m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device_config.get_device());   // TODO: Make the prefix name configurable
+        compiled_model = core.compile_model(model, device_config.get_device(), *filtered_properties);
+    } else {
+        compiled_model = core.compile_model(model, device_config.get_device(), properties);
+    }
+
    ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
    ov::InferRequest infer_request = compiled_model.create_infer_request();

@@ -68,9 +79,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
        can_use_partial_preemption = false;
    }
    m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), m_cache_manager, updated_config, device_config.get_num_layers(), can_use_partial_preemption);
-    // and finally create model runner
+
+    // model runner
    bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
    m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction);
+
+    // sampler
    m_sampler = std::make_shared<Sampler>(m_tokenizer);
    m_sampler->set_seed(m_generation_config.rng_seed);

@@ -94,6 +108,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
                                                                        m_scheduler->get_block_size(),
                                                                        m_scheduler->get_config().enable_prefix_caching);
    sequence_group->set_sequence_group_ptr(sequence_group);
+
    if (m_scheduler->get_config().enable_prefix_caching) {
        m_scheduler->restore_cached_blocks(sequence_group);
    }
@@ -102,6 +117,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
        std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
        m_awaiting_requests.push_back(sequence_group);
    }
+
    return std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
};

@@ -113,6 +129,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
    timer.start();
    ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
    timer.end();
+
    return add_request(request_id, input_ids, sampling_params);
}

@@ -127,24 +144,26 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

    _pull_awaiting_requests();

-    m_pipeline_metrics.requests = m_requests.size();
    Scheduler::Output scheduler_output;
    {
-        static ManualTimer timer("scheduling");
-        timer.start();
-        m_scheduler->clean_empty_blocks(m_requests);
+        static ManualTimer scheduling_timer("scheduling");
+        scheduling_timer.start();
        scheduler_output = m_scheduler->schedule(m_requests);
+        scheduling_timer.end();
+
        m_pipeline_metrics.scheduled_requests = scheduler_output.m_scheduled_sequence_groups_ids.size();
        m_pipeline_metrics.cache_usage = scheduler_output.m_cache_usage;
-        m_pipeline_metrics.max_cache_usage =
-            std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
+        m_pipeline_metrics.max_cache_usage = std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
        _register_step_cache_usage(scheduler_output.m_cache_usage);
        m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage();
+
+        static ManualTimer copy_blocks_timer("scheduling");
+        copy_blocks_timer.start();
        m_cache_manager->copy_blocks(scheduler_output.m_block_copy_map);
-        timer.end();
+        copy_blocks_timer.end();
    }

-    // if no tokens were scheduled, we are out of memory
+    // if no tokens were scheduled, we are out of memory => free all requests and return
    if (scheduler_output.m_total_num_scheduled_tokens == 0) {
        for (size_t i = 0; i < m_requests.size(); ++i) {
            SequenceGroup::Ptr sequence_group = m_requests[i];
@@ -166,15 +185,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
    }

#ifdef DEBUG_CACHE_STATE_DUMP
-
    CacheStateDumper dumper(CacheStateDumper::get_run_id_for_generation_step(step_count, "before_eviction"));
    dumper.dump_cache_state(*m_scheduler, m_requests, step_count);
#endif
-    const auto& sched_config = m_scheduler->get_config();

    // evict unimportant blocks from KV cache, if requested
+    const auto& sched_config = m_scheduler->get_config();
    if (sched_config.use_cache_eviction) {
-        maybe_evict_cache_blocks(sched_config);
+        _maybe_evict_cache_blocks(sched_config);
    }

#ifdef DEBUG_CACHE_STATE_DUMP
@@ -183,6 +201,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
    step_count++;
#endif

+    // process generation_config.echo parameter
    _fill_prompt_log_probs(m_requests, logits);

    SamplerOutput sampler_output;
@@ -195,8 +214,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

    // process sampler_output (e.g. fork or drop sequences from BlockScheduler)
    {
-        static ManualTimer timer("fork / free sequence");
-        timer.start();
+        static ManualTimer free_fork_timer("fork / free sequence");
+        free_fork_timer.start();

        for (const auto& pair : sampler_output.m_forked_sequences) {
            uint64_t parent_id = pair.first;
@@ -208,35 +227,49 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
        for (auto seq_id : sampler_output.m_dropped_sequences)
            m_scheduler->free_sequence(seq_id);

-        timer.end();
+        free_fork_timer.end();
    }

    // notify requests dropped by handle
    {
-        static ManualTimer timer("notify requests dropped by handle");
-        timer.start();
+        static ManualTimer report_tokens_timer("notify requests dropped by handle");
+        report_tokens_timer.start();
        _notify_requests_dropped_by_handle();
-        timer.end();
+        report_tokens_timer.end();
    }

    // free non running requests for current step

    {
-        static ManualTimer timer("free non running requests");
-        timer.start();
+        static ManualTimer clean_up_requests_timer("free non running requests");
+        clean_up_requests_timer.start();
        _free_non_running_requests();
-        timer.end();
+        clean_up_requests_timer.end();
    }

    step_timer.end();
}

+void ContinuousBatchingPipeline::ContinuousBatchingImpl::set_adapters(const std::optional<AdapterConfig>& adapters) {
+    if (m_adapter_controller) {
+        m_adapter_controller->apply(m_model_runner->get_infer_request(), adapters);
+    }
+}
+
std::vector<EncodedGenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
                                                             const std::vector<GenerationConfig>& sampling_params,
                                                             const StreamerVariant& streamer) {
    OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
    OPENVINO_ASSERT(input_ids.size() == sampling_params.size());
+
+    // check that all requests have the same LoRA adapters property value
+    for (size_t i = 1; i < sampling_params.size(); ++i) {
+        OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters,
+                        "LoRA adapters value must be the same for all requests");
+    }
+    set_adapters(sampling_params[0].adapters);
+
    const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
        [](std::monostate) -> std::shared_ptr<StreamerBase> {
            return nullptr;
@@ -375,7 +408,7 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_a
    return std::accumulate(m_previous_step_cache_usages.begin(), m_previous_step_cache_usages.end(), 0.0) / m_previous_step_cache_usages.size();
}

-void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
+void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
    std::unordered_map<SequenceGroup::Ptr, size_t> seq_group_to_num_blocks_evicted_map;
    auto sequence_attention_scores = m_model_runner->get_last_attention_scores();
    for (auto& seq_id_and_attention_scores : sequence_attention_scores) {
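
For context, a minimal usage sketch of what this change enables on the application side, assuming the public ov::genai::ContinuousBatchingPipeline and LoRA adapter APIs; the model directory, adapter file name, and the exact constructor/generate overloads below are illustrative assumptions, not taken from this diff:

    #include "openvino/genai/continuous_batching_pipeline.hpp"
    #include "openvino/genai/lora_adapter.hpp"

    int main() {
        ov::genai::SchedulerConfig scheduler_config;               // default scheduling / KV-cache settings
        ov::genai::Adapter adapter("adapter_model.safetensors");   // hypothetical LoRA adapter file

        // Passing the adapter as a pipeline property lets initialize_pipeline() pick it up via
        // extract_adapters_from_properties() and create the AdapterController before compiling the model.
        ov::genai::ContinuousBatchingPipeline pipe("TinyLlama-1.1B-Chat-v1.0",   // hypothetical model dir
                                                   scheduler_config,
                                                   "CPU",
                                                   ov::AnyMap{ov::genai::adapters(adapter)});

        // All requests within one generate() call must use the same adapters value,
        // matching the OPENVINO_ASSERT added in generate() above.
        ov::genai::GenerationConfig config;
        config.max_new_tokens = 32;
        config.adapters = ov::genai::AdapterConfig(adapter);

        auto results = pipe.generate({"Hello, how are you?"}, {config});
        return 0;
    }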