Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added LoRA support to CB, SD, PL #1452

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ struct PipelineMetrics {

class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
protected:
class ImplInterface;
class IContinuousBatchingPipeline;
class ContinuousBatchingImpl;

class ContinuousBatchingForSpeculativeDecodingImpl;
class ContinuousBatchingForPromptLookupImpl;
class SpeculativeDecodingImpl;
Expand All @@ -64,7 +65,7 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
friend class SpeculativeDecodingImpl;
friend class PromptLookupImpl;

std::shared_ptr<ImplInterface> m_impl;
std::shared_ptr<IContinuousBatchingPipeline> m_impl;

ContinuousBatchingPipeline() = default;

Expand Down
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/lora_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class OPENVINO_GENAI_EXPORTS AdapterController {
AdapterController(std::shared_ptr<ov::Model> model, const AdapterConfig& config, std::string device);

// Apply adapters configured in the current config set last time, or set and use new config given as optional `config` argument
void apply(ov::InferRequest& request, const std::optional<AdapterConfig>& config = std::nullopt);
void apply(ov::InferRequest request, const std::optional<AdapterConfig>& config = std::nullopt);

// Returns true if a given name is one of the state names created by this adapter controller for dynamic LoRA
// Helps to distinguish LoRA states from other states (e.g. KV cache state) in the model for a partial state reset.
Expand Down
87 changes: 60 additions & 27 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "continuous_batching_impl.hpp"
#include "utils.hpp"
#include "utils/paged_attention_transformations.hpp"
#include "lora_helper.hpp"

namespace ov::genai {
template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
Expand All @@ -17,8 +18,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
const std::string& device,
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config,
bool is_validation_mode_enabled
) {
bool is_validation_mode_enabled) {
m_tokenizer = tokenizer;
m_generation_config = generation_config;
m_is_validation_mode_enabled = is_validation_mode_enabled;
Expand All @@ -33,22 +33,33 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);

init(model, scheduler_config, compile_properties, device_config, core);
initialize_pipeline(model, scheduler_config, compile_properties, device_config, core);
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests() {
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_requests.insert(m_requests.end(), m_awaiting_requests.begin(), m_awaiting_requests.end());
m_awaiting_requests.clear();
m_pipeline_metrics.requests = m_requests.size();
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
std::shared_ptr<ov::Model> model,
const SchedulerConfig& scheduler_config,
const ov::AnyMap& properties,
const DeviceConfig& device_config,
ov::Core& core) {
auto compiled_model = core.compile_model(model, device_config.get_device(), properties);
ov::CompiledModel compiled_model;

// apply LoRA
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device_config.get_device()); // TODO: Make the prefix name configurable
compiled_model = core.compile_model(model, device_config.get_device(), *filtered_properties);
} else {
compiled_model = core.compile_model(model, device_config.get_device(), properties);
}

ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
ov::InferRequest infer_request = compiled_model.create_infer_request();

Expand All @@ -68,9 +79,12 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
can_use_partial_preemption = false;
}
m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), m_cache_manager, updated_config, device_config.get_num_layers(), can_use_partial_preemption);
// and finally create model runner

// model runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction);

// sampler
m_sampler = std::make_shared<Sampler>(m_tokenizer);
m_sampler->set_seed(m_generation_config.rng_seed);

Expand All @@ -94,6 +108,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
m_scheduler->get_block_size(),
m_scheduler->get_config().enable_prefix_caching);
sequence_group->set_sequence_group_ptr(sequence_group);

if (m_scheduler->get_config().enable_prefix_caching) {
m_scheduler->restore_cached_blocks(sequence_group);
}
Expand All @@ -102,6 +117,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
std::lock_guard<std::mutex> lock{m_awaiting_requests_mutex};
m_awaiting_requests.push_back(sequence_group);
}

return std::make_shared<GenerationHandleImpl>(sequence_group->get_generation_stream(), sampling_params);
};

Expand All @@ -113,6 +129,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
timer.start();
ov::Tensor input_ids = m_tokenizer.encode(prompt).input_ids;
timer.end();

return add_request(request_id, input_ids, sampling_params);
}

Expand All @@ -127,24 +144,26 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

_pull_awaiting_requests();

m_pipeline_metrics.requests = m_requests.size();
Scheduler::Output scheduler_output;
{
static ManualTimer timer("scheduling");
timer.start();
m_scheduler->clean_empty_blocks(m_requests);
static ManualTimer scheduling_timer("scheduling");
scheduling_timer.start();
scheduler_output = m_scheduler->schedule(m_requests);
scheduling_timer.end();

m_pipeline_metrics.scheduled_requests = scheduler_output.m_scheduled_sequence_groups_ids.size();
m_pipeline_metrics.cache_usage = scheduler_output.m_cache_usage;
m_pipeline_metrics.max_cache_usage =
std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
m_pipeline_metrics.max_cache_usage = std::max(m_pipeline_metrics.max_cache_usage, scheduler_output.m_cache_usage);
_register_step_cache_usage(scheduler_output.m_cache_usage);
m_pipeline_metrics.avg_cache_usage = _get_current_running_average_cache_usage();

static ManualTimer copy_blocks_timer("scheduling");
copy_blocks_timer.start();
m_cache_manager->copy_blocks(scheduler_output.m_block_copy_map);
timer.end();
copy_blocks_timer.end();
}

// if no tokens were scheduled, we are out of memory
// if no tokens were scheduled, we are out of memory => free all requests and return
if (scheduler_output.m_total_num_scheduled_tokens == 0) {
for (size_t i = 0; i < m_requests.size(); ++i) {
SequenceGroup::Ptr sequence_group = m_requests[i];
Expand All @@ -166,15 +185,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
}

#ifdef DEBUG_CACHE_STATE_DUMP

CacheStateDumper dumper(CacheStateDumper::get_run_id_for_generation_step(step_count, "before_eviction"));
dumper.dump_cache_state(*m_scheduler, m_requests, step_count);
#endif
const auto& sched_config = m_scheduler->get_config();

// evict unimportant blocks from KV cache, if requested
const auto& sched_config = m_scheduler->get_config();
if (sched_config.use_cache_eviction) {
maybe_evict_cache_blocks(sched_config);
_maybe_evict_cache_blocks(sched_config);
}

#ifdef DEBUG_CACHE_STATE_DUMP
Expand All @@ -183,6 +201,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
step_count++;
#endif

// process generation_config.echo parameetr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// process generation_config.echo parameetr
// process generation_config.echo parameter

_fill_prompt_log_probs(m_requests, logits);

SamplerOutput sampler_output;
Expand All @@ -195,8 +214,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

// process sampler_output (e.g. fork or drop sequences from BlockScheduler)
{
static ManualTimer timer("fork / free sequence");
timer.start();
static ManualTimer free_fork_timer("fork / free sequence");
free_fork_timer.start();

for (const auto& pair : sampler_output.m_forked_sequences) {
uint64_t parent_id = pair.first;
Expand All @@ -208,35 +227,49 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
for (auto seq_id : sampler_output.m_dropped_sequences)
m_scheduler->free_sequence(seq_id);

timer.end();
free_fork_timer.end();
}

// notify requests dropped by handle
{
static ManualTimer timer("notify requests dropped by handle");
timer.start();
static ManualTimer report_tokens_timer("notify requests dropped by handle");
report_tokens_timer.start();
_notify_requests_dropped_by_handle();
timer.end();
report_tokens_timer.end();
}

// free non running requests for current step

{
static ManualTimer timer("free non running requests");
timer.start();
static ManualTimer clean_up_requests_timer("free non running requests");
clean_up_requests_timer.start();
_free_non_running_requests();
timer.end();
clean_up_requests_timer.end();
}

step_timer.end();
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::set_adapters(const std::optional<AdapterConfig>& adapters) {
if (m_adapter_controller) {
m_adapter_controller->apply(m_model_runner->get_infer_request(), adapters);
}
}

std::vector<EncodedGenerationResult>
ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<ov::Tensor>& input_ids,
const std::vector<GenerationConfig>& sampling_params,
const StreamerVariant& streamer) {
OPENVINO_ASSERT(!has_non_finished_requests(), "Generate cannot be called while ContinuousBatchingPipeline is already in running state. Use ContinuousBatchingPipeline::add_request");
OPENVINO_ASSERT(input_ids.size() == sampling_params.size());

// checks that all requests has the same LoRA adapters property value
for (size_t i = 1; i < sampling_params.size(); ++i) {
OPENVINO_ASSERT(sampling_params[i - 1].adapters == sampling_params[i].adapters,
"LoRA adapters value must be the same for all requests");
}
set_adapters(sampling_params[0].adapters);

const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
Expand Down Expand Up @@ -375,7 +408,7 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_a
return std::accumulate(m_previous_step_cache_usages.begin(), m_previous_step_cache_usages.end(), 0.0) / m_previous_step_cache_usages.size();
}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
std::unordered_map<SequenceGroup::Ptr, size_t> seq_group_to_num_blocks_evicted_map;
auto sequence_attention_scores = m_model_runner->get_last_attention_scores();
for (auto& seq_id_and_attention_scores : sequence_attention_scores) {
Expand Down
60 changes: 45 additions & 15 deletions src/cpp/src/continuous_batching_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@

#pragma once

#include "continuous_batching_impl_interface.hpp"
#include "openvino/genai/continuous_batching_pipeline.hpp"
#include "icontinuous_batching.hpp"

#include "openvino/genai/lora_adapter.hpp"
#include "cache_eviction.hpp"

namespace ov::genai {
class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::ImplInterface {

class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatchingPipeline::IContinuousBatchingPipeline {
protected:
std::shared_ptr<Scheduler> m_scheduler;
std::shared_ptr<CacheManager> m_cache_manager;
std::shared_ptr<ModelRunner> m_model_runner;
std::optional<AdapterController> m_adapter_controller;
std::shared_ptr<Sampler> m_sampler;

// current requests to process
Expand All @@ -26,7 +29,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc

static const size_t AVG_CACHE_USAGE_WINDOW_SIZE_IN_STEPS = 1000;
std::deque<float> m_previous_step_cache_usages;

// flag to enable validation mode for sampler
bool m_is_validation_mode_enabled = false;

Expand All @@ -37,21 +40,41 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
// used by tests only
ContinuousBatchingImpl() = default;

void initialize_pipeline(std::shared_ptr<ov::Model> model,
const SchedulerConfig& scheduler_config,
const ov::AnyMap& plugin_config,
const DeviceConfig& device_config,
ov::Core& core);

/**
* Pulls requests from awaiting queue to running queue
* Should be called within each call of step()
*/
virtual void _pull_awaiting_requests();

/**
* Releases non-running (finished, dropped or OOM) requests from running queue
*/
void _free_non_running_requests();

/**
* Notify dropped requests by pushing empty output
*/
void _notify_requests_dropped_by_handle();
void _register_step_cache_usage(float step_cache_usage);
float _get_current_running_average_cache_usage() const;
void maybe_evict_cache_blocks(const SchedulerConfig& sched_config);

void init(std::shared_ptr<ov::Model> model,
const SchedulerConfig& scheduler_config,
const ov::AnyMap& plugin_config,
const DeviceConfig& device_config,
ov::Core& core);
/**
* Handles 'echo' generation parameter
*/
void _fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits);

virtual void _pull_awaiting_requests();
/**
* Performs KV cache eviction is enabled / requireed
*/
void _maybe_evict_cache_blocks(const SchedulerConfig& sched_config);

void _register_step_cache_usage(float step_cache_usage);
float _get_current_running_average_cache_usage() const;

void _fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits);
public:
ContinuousBatchingImpl(const std::shared_ptr<ov::Model>& model,
const Tokenizer& tokenizer,
Expand All @@ -64,6 +87,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
GenerationHandle add_request(uint64_t request_id,
const ov::Tensor& input_ids,
ov::genai::GenerationConfig sampling_params) override;

GenerationHandle add_request(uint64_t request_id,
const std::string& prompt,
ov::genai::GenerationConfig sampling_params) override;
Expand All @@ -76,5 +100,11 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
generate(const std::vector<ov::Tensor>& input_ids,
const std::vector<GenerationConfig>& sampling_params,
const StreamerVariant& streamer) override;

/**
* Updates LoRA adapters for current generation call
*/
void set_adapters(const std::optional<AdapterConfig>& adapters);
};
}

} // namespace ov::genai
Loading
Loading