Commit ba0224f

Added LoRA support to CB, SD, PL (#1452)
(CB: continuous batching; SD: speculative decoding; PL: prompt lookup)
CVS-159960
1 parent 6c56a7b commit ba0224f

18 files changed (+214, -81 lines). Four of the 18 changed files are shown below.

src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp

Lines changed: 3 additions & 2 deletions
@@ -52,8 +52,9 @@ struct PipelineMetrics {
 
 class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
 protected:
-    class ImplInterface;
+    class IContinuousBatchingPipeline;
     class ContinuousBatchingImpl;
+
     class ContinuousBatchingForSpeculativeDecodingImpl;
     class ContinuousBatchingForPromptLookupImpl;
     class SpeculativeDecodingImpl;
@@ -64,7 +65,7 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
     friend class SpeculativeDecodingImpl;
     friend class PromptLookupImpl;
 
-    std::shared_ptr<ImplInterface> m_impl;
+    std::shared_ptr<IContinuousBatchingPipeline> m_impl;
 
     ContinuousBatchingPipeline() = default;
src/cpp/include/openvino/genai/lora_adapter.hpp

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ class OPENVINO_GENAI_EXPORTS AdapterController {
     AdapterController(std::shared_ptr<ov::Model> model, const AdapterConfig& config, std::string device);
 
     // Apply adapters configured in the current config set last time, or set and use new config given as optional `config` argument
-    void apply(ov::InferRequest& request, const std::optional<AdapterConfig>& config = std::nullopt);
+    void apply(ov::InferRequest request, const std::optional<AdapterConfig>& config = std::nullopt);
 
     // Returns true if a given name is one of the state names created by this adapter controller for dynamic LoRA
     // Helps to distinguish LoRA states from other states (e.g. KV cache state) in the model for a partial state reset.
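
This signature change is needed by the new call site added in continuous_batching_impl.cpp below: apply() now receives the temporary returned by m_model_runner->get_infer_request(), and a temporary cannot bind to a non-const ov::InferRequest&. Since ov::InferRequest is a lightweight handle to shared state, passing it by value is cheap. A minimal caller-side sketch, not part of this commit; the file paths are placeholders:

    #include "openvino/genai/lora_adapter.hpp"
    #include "openvino/openvino.hpp"

    // Sketch: applying LoRA state to a request handle via AdapterController.
    ov::Core core;
    std::shared_ptr<ov::Model> model = core.read_model("openvino_model.xml");  // placeholder path
    ov::genai::Adapter adapter("adapter.safetensors");                         // placeholder path
    ov::genai::AdapterConfig config(adapter);

    // The controller rewrites the model, so create it before compiling.
    ov::genai::AdapterController controller(model, config, "CPU");
    ov::CompiledModel compiled_model = core.compile_model(model, "CPU");

    ov::InferRequest request = compiled_model.create_infer_request();
    controller.apply(request);                                           // reuse the config set above
    controller.apply(request, ov::genai::AdapterConfig(adapter, 0.75f)); // or set a new alpha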

src/cpp/src/continuous_batching_impl.cpp

Lines changed: 60 additions & 27 deletions
@@ -5,6 +5,7 @@
 #include "continuous_batching_impl.hpp"
 #include "utils.hpp"
 #include "utils/paged_attention_transformations.hpp"
+#include "lora_helper.hpp"
 
 namespace ov::genai {
 template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
@@ -17,8 +18,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     const std::string& device,
     const ov::AnyMap& properties,
     const ov::genai::GenerationConfig& generation_config,
-    bool is_validation_mode_enabled
-    ) {
+    bool is_validation_mode_enabled) {
     m_tokenizer = tokenizer;
     m_generation_config = generation_config;
     m_is_validation_mode_enabled = is_validation_mode_enabled;
@@ -33,22 +33,33 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
     bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
     utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
 
-    init(model, scheduler_config, compile_properties, device_config, core);
+    initialize_pipeline(model, scheduler_config, compile_properties, device_config, core);
 }
 
 void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests() {
     std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
     m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
     m_awaiting_requests.clear();
+    m_pipeline_metrics.requests = m_requests.size();
 }
 
-void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
+void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
     std::shared_ptr<ov::Model> model,
     const SchedulerConfig& scheduler_config,
     const ov::AnyMap& properties,
     const DeviceConfig& device_config,
     ov::Core& core) {
-    auto compiled_model = core.compile_model(model, device_config.get_device(), properties);
+    ov::CompiledModel compiled_model;
+
+    // apply LoRA
+    if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
+        m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
+        m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device_config.get_device()); // TODO: Make the prefix name configurable
+        compiled_model = core.compile_model(model, device_config.get_device(), *filtered_properties);
+    } else {
+        compiled_model = core.compile_model(model, device_config.get_device(), properties);
+    }
+
     ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
     ov::InferRequest infer_request = compiled_model.create_infer_request();

@@ -68,9 +79,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
         can_use_partial_preemption = false;
     }
     m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), m_cache_manager, updated_config, device_config.get_num_layers(), can_use_partial_preemption);
-    // and finally create model runner
+
+    // model runner
     bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
     m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction);
+
+    // sampler
     m_sampler = std::make_shared<Sampler>(m_tokenizer);
     m_sampler->set_seed(m_generation_config.rng_seed);

@@ -94,6 +108,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
                                                                 m_scheduler->get_block_size(),
                                                                 m_scheduler->get_config().enable_prefix_caching);
     sequence_group->set_sequence_group_ptr(sequence_group);
+
     if (m_scheduler->get_config().enable_prefix_caching) {
         m_scheduler->restore_cached_blocks(sequence_group);
     }
@@ -102,6 +117,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
         std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
         m_awaiting_requests.push_back(sequence_group);
     }
+
     return std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
 };
@@ -113,6 +129,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request_id,
     timer.start();
     ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
     timer.end();
+
     return add_request(request_id, input_ids, sampling_params);
 }

@@ -127,24 +144,26 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
 
     _pull_awaiting_requests();
 
-    m_pipeline_metrics.requests = m_requests.size();
     Scheduler::Output scheduler_output;
     {
-        static ManualTimer timer("scheduling");
-        timer.start();
-        m_scheduler->clean_empty_blocks(m_requests);
+        static ManualTimer scheduling_timer("scheduling");
+        scheduling_timer.start();
         scheduler_output = m_scheduler->schedule(m_requests);
+        scheduling_timer.end();
+
         m_pipeline_metrics.scheduled_requests = scheduler_output.m_scheduled_sequence_groups_ids.size();
         m_pipeline_metrics.cache_usage = scheduler_output.m_cache_usage;
-        m_pipeline_metrics.max_cache_usage =
-            std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
+        m_pipeline_metrics.max_cache_usage = std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
         _register_step_cache_usage(scheduler_output.m_cache_usage);
         m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage();
+
+        static ManualTimer copy_blocks_timer("scheduling");
+        copy_blocks_timer.start();
         m_cache_manager->copy_blocks(scheduler_output.m_block_copy_map);
-        timer.end();
+        copy_blocks_timer.end();
     }
 
-    // if no tokens were scheduled, we are out of memory
+    // if no tokens were scheduled, we are out of memory => free all requests and return
     if (scheduler_output.m_total_num_scheduled_tokens == 0) {
         for (size_t i = 0; i < m_requests.size(); ++i) {
             SequenceGroup::Ptr sequence_group = m_requests[i];
@@ -166,15 +185,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
     }
 
 #ifdef DEBUG_CACHE_STATE_DUMP
-
     CacheStateDumper dumper(CacheStateDumper::get_run_id_for_generation_step(step_count, "before_eviction"));
     dumper.dump_cache_state(*m_scheduler, m_requests, step_count);
 #endif
-    const auto& sched_config = m_scheduler->get_config();
 
     // evict unimportant blocks from KV cache, if requested
+    const auto& sched_config = m_scheduler->get_config();
     if (sched_config.use_cache_eviction) {
-        maybe_evict_cache_blocks(sched_config);
+        _maybe_evict_cache_blocks(sched_config);
     }
 
 #ifdef DEBUG_CACHE_STATE_DUMP
@@ -183,6 +201,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
     step_count++;
 #endif
 
+    // process generation_config.echo parameter
     _fill_prompt_log_probs(m_requests, logits);
 
     SamplerOutput sampler_output;
@@ -195,8 +214,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
 
     // process sampler_output (e.g. fork or drop sequences from BlockScheduler)
     {
-        static ManualTimer timer("fork / free sequence");
-        timer.start();
+        static ManualTimer free_fork_timer("fork / free sequence");
+        free_fork_timer.start();
 
         for (const auto& pair : sampler_output.m_forked_sequences) {
             uint64_t parent_id = pair.first;
@@ -208,35 +227,49 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
         for (auto seq_id : sampler_output.m_dropped_sequences)
             m_scheduler->free_sequence(seq_id);
 
-        timer.end();
+        free_fork_timer.end();
     }
 
     // notify requests dropped by handle
     {
-        static ManualTimer timer("notify requests dropped by handle");
-        timer.start();
+        static ManualTimer report_tokens_timer("notify requests dropped by handle");
+        report_tokens_timer.start();
         _notify_requests_dropped_by_handle();
-        timer.end();
+        report_tokens_timer.end();
     }
 
     // free non running requests for current step
 
     {
-        static ManualTimer timer("free non running requests");
-        timer.start();
+        static ManualTimer clean_up_requests_timer("free non running requests");
+        clean_up_requests_timer.start();
         _free_non_running_requests();
-        timer.end();
+        clean_up_requests_timer.end();
     }
 
     step_timer.end();
 }
 
+void ContinuousBatchingPipeline::ContinuousBatchingImpl::set_adapters(const std::optional<AdapterConfig>& adapters) {
+    if (m_adapter_controller) {
+        m_adapter_controller->apply(m_model_runner->get_infer_request(), adapters);
+    }
+}
+
 std::vector<EncodedGenerationResult>
 ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
                                                              const std::vector<GenerationConfig>& sampling_params,
                                                              const StreamerVariant& streamer) {
     OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
     OPENVINO_ASSERT(input_ids.size() == sampling_params.size());
+
+    // check that all requests have the same LoRA adapters property value
+    for (size_t i = 1; i < sampling_params.size(); ++i) {
+        OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters,
+                        "LoRA adapters value must be the same for all requests");
+    }
+    set_adapters(sampling_params[0].adapters);
+
     const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
         [](std::monostate) -> std::shared_ptr<StreamerBase> {
             return nullptr;
@@ -375,7 +408,7 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_average_cache_usage
     return std::accumulate(m_previous_step_cache_usages.begin(), m_previous_step_cache_usages.end(), 0.0) / m_previous_step_cache_usages.size();
 }
 
-void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
+void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
     std::unordered_map<SequenceGroup::Ptr, size_t> seq_group_to_num_blocks_evicted_map;
     auto sequence_attention_scores = m_model_runner->get_last_attention_scores();
     for (auto& seq_id_and_attention_scores : sequence_attention_scores) {
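
Because set_adapters() is invoked once at the top of generate() with sampling_params[0].adapters, all requests in a batch must carry an identical adapters value, which the new assertion enforces. A hedged usage sketch under the same assumptions as the pipeline sketch above (pipe and adapter are reused from it; the prompts and alpha are placeholders):

    // Sketch: every GenerationConfig in the batch carries the same adapters value.
    ov::genai::GenerationConfig cfg;
    cfg.max_new_tokens = 100;
    cfg.adapters = ov::genai::AdapterConfig(adapter, 0.75f);

    std::vector<std::string> prompts = {"What is OpenVINO?", "Explain LoRA."};
    auto results = pipe.generate(
        prompts, std::vector<ov::genai::GenerationConfig>(prompts.size(), cfg));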

src/cpp/src/continuous_batching_impl.hpp

Lines changed: 45 additions & 15 deletions
@@ -3,16 +3,19 @@
 
 #pragma once
 
-#include "continuous_batching_impl_interface.hpp"
-#include "openvino/genai/continuous_batching_pipeline.hpp"
+#include "icontinuous_batching.hpp"
+
+#include "openvino/genai/lora_adapter.hpp"
 #include "cache_eviction.hpp"
 
 namespace ov::genai {
-class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::ImplInterface {
+
+class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
 protected:
     std::shared_ptr<Scheduler> m_scheduler;
     std::shared_ptr<CacheManager> m_cache_manager;
     std::shared_ptr<ModelRunner> m_model_runner;
+    std::optional<AdapterController> m_adapter_controller;
     std::shared_ptr<Sampler> m_sampler;
 
     // current requests to process
@@ -26,7 +29,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
 
     static const size_t AVG_CACHE_USAGE_WINDOW_SIZE_IN_STEPS = 1000;
     std::deque<float> m_previous_step_cache_usages;
-    
+
     // flag to enable validation mode for sampler
     bool m_is_validation_mode_enabled = false;
 
@@ -37,21 +40,41 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
     // used by tests only
     ContinuousBatchingImpl() = default;
 
+    void initialize_pipeline(std::shared_ptr<ov::Model> model,
+                             const SchedulerConfig& scheduler_config,
+                             const ov::AnyMap& plugin_config,
+                             const DeviceConfig& device_config,
+                             ov::Core& core);
+
+    /**
+     * Pulls requests from the awaiting queue to the running queue
+     * Should be called within each call of step()
+     */
+    virtual void _pull_awaiting_requests();
+
+    /**
+     * Releases non-running (finished, dropped or OOM) requests from the running queue
+     */
     void _free_non_running_requests();
+
+    /**
+     * Notifies dropped requests by pushing empty output
+     */
     void _notify_requests_dropped_by_handle();
-    void _register_step_cache_usage(float step_cache_usage);
-    float _get_current_running_average_cache_usage() const;
-    void maybe_evict_cache_blocks(const SchedulerConfig& sched_config);
 
-    void init(std::shared_ptr<ov::Model> model,
-              const SchedulerConfig& scheduler_config,
-              const ov::AnyMap& plugin_config,
-              const DeviceConfig& device_config,
-              ov::Core& core);
+    /**
+     * Handles the 'echo' generation parameter
+     */
+    void _fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits);
 
-    virtual void _pull_awaiting_requests();
+    /**
+     * Performs KV cache eviction if enabled / required
+     */
+    void _maybe_evict_cache_blocks(const SchedulerConfig& sched_config);
+
+    void _register_step_cache_usage(float step_cache_usage);
+    float _get_current_running_average_cache_usage() const;
 
-    void _fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits);
 public:
     ContinuousBatchingImpl(const std::shared_ptr<ov::Model>& model,
                            const Tokenizer& tokenizer,
@@ -64,6 +87,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
     GenerationHandle add_request(uint64_t request_id,
                                  const ov::Tensor& input_ids,
                                  ov::genai::GenerationConfig sampling_params) override;
+
    GenerationHandle add_request(uint64_t request_id,
                                 const std::string& prompt,
                                 ov::genai::GenerationConfig sampling_params) override;
@@ -76,5 +100,11 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
     generate(const std::vector<ov::Tensor>& input_ids,
              const std::vector<GenerationConfig>& sampling_params,
              const StreamerVariant& streamer) override;
+
+    /**
+     * Updates LoRA adapters for the current generation call
+     */
+    void set_adapters(const std::optional<AdapterConfig>& adapters);
 };
-}
+
+} // namespace ov::genai
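
Since set_adapters() re-applies LoRA state on the single shared infer request, switching adapters or blending weights is a between-calls operation, never a per-request one. A brief sketch continuing the assumptions above:

    // Sketch: the adapters value may change across generate() calls, not within one batch.
    cfg.adapters = ov::genai::AdapterConfig(adapter, 0.75f);
    pipe.generate(prompts, std::vector<ov::genai::GenerationConfig>(prompts.size(), cfg));

    cfg.adapters = ov::genai::AdapterConfig(adapter, 0.2f);   // new blending alpha
    pipe.generate(prompts, std::vector<ov::genai::GenerationConfig>(prompts.size(), cfg));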
