Add slice before matmul transformation for CB scenario #1261

Merged: 12 commits, Jan 4, 2025
14 changes: 7 additions & 7 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -28,6 +28,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
utils::apply_gather_before_matmul_transformation(model);

initialize_pipeline(model, scheduler_config, properties, device_config, core);
}
@@ -444,7 +445,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
size_t vocab_size = logits_shape[2];
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
// requests not scheduled, in decoding phase or not echoing are not processed
@@ -454,26 +455,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

size_t num_running_sequences = sequence_group->num_running_seqs();
OPENVINO_ASSERT(num_running_sequences == 1);
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens();
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t output_seq_len = sequence_group->get_output_seq_len();

const float * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;

size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
OPENVINO_ASSERT(num_prompt_tokens_processed + output_seq_len <= sequence_group->get_prompt_len());

// if we processed the whole prompt, we don't include the last logprob, as it will be processed by the sampler (it already belongs to the completion)
// otherwise we include it, as it will be used in the next part of the prompt
int exclude_last_logprob = 1;
if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
if (num_prompt_tokens_processed + output_seq_len < sequence_group->get_prompt_len())
exclude_last_logprob = 0;

// if we start processing the prompt, we add a "fake" log prob for the first position (beginning of the sequence)
if (num_prompt_tokens_processed == 0)
sequence_group->append_prompt_log_prob(1.0);

for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
token_logits_offset < actual_seq_len - exclude_last_logprob;
token_logits_offset < output_seq_len - exclude_last_logprob;
token_logits_offset++, token_id_offset++) {

const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
@@ -498,7 +498,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

sequence_group->append_prompt_log_prob(token_logit - max_value - log_sum);
}
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
// For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
sequence_group->notify_handle_echo_only();
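Note: with gathering enabled the logits buffer is packed rather than padded to a batch-wide sequence length, so the per-group offset now advances by each group's own output_seq_len. A minimal standalone sketch of that offset arithmetic (plain C++ with illustrative names, not the pipeline's API):

#include <cstddef>
#include <vector>

// Sketch: walking a packed logits buffer of shape [1, sum(output_seq_len), vocab_size].
struct GroupInfo {
    std::size_t num_running_sequences;
    std::size_t output_seq_len;   // rows actually produced for this group after gathering
};

// Flat float offset of each group's first logit row.
std::vector<std::size_t> group_offsets(const std::vector<GroupInfo>& groups, std::size_t vocab_size) {
    std::vector<std::size_t> offsets;
    std::size_t currently_processed_tokens = 0;
    for (const auto& g : groups) {
        offsets.push_back(vocab_size * currently_processed_tokens);
        // before this PR: padded_amount_of_processed_tokens * num_running_sequences
        currently_processed_tokens += g.output_seq_len * g.num_running_sequences;
    }
    return offsets;
}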
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_stateful.cpp
@@ -38,7 +38,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::slice_matmul_stateful_model(model);
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
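The renamed helper keeps the existing stateful-pipeline behaviour: only the last sequence position of the hidden state reaches the LM-head MatMul, so the logits output shrinks from [batch, seq_len, vocab] to [batch, 1, vocab]. A rough standalone illustration of that last-position selection on a flat buffer (plain C++, not the actual OpenVINO Slice op):

#include <cstddef>
#include <vector>

// Keep only the last sequence position of a row-major [batch, seq_len, hidden] buffer,
// mimicking the slice-before-matmul transformation on the sequence axis.
std::vector<float> keep_last_position(const std::vector<float>& x,
                                      std::size_t batch, std::size_t seq_len, std::size_t hidden) {
    std::vector<float> out(batch * hidden);
    for (std::size_t b = 0; b < batch; ++b)
        for (std::size_t h = 0; h < hidden; ++h)
            out[b * hidden + h] = x[(b * seq_len + (seq_len - 1)) * hidden + h];
    return out;   // conceptually [batch, 1, hidden]
}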
51 changes: 42 additions & 9 deletions src/cpp/src/model_runner.hpp
@@ -114,28 +114,54 @@ class ModelRunner {
subsequence_begins_data[0] = 0;
block_indices_begins_data[0] = 0;

bool matmul_gathering_is_available = false;
size_t gathering_current_index = 0;
std::vector<int64_t> gather_indices_values;
try {
std::ignore = m_request.get_tensor("sampled_tokens_indices");
matmul_gathering_is_available = true;
} catch (const ov::Exception&) {}


for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
size_t num_running_sequences = running_sequences.size();
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();
size_t prompt_len = sequence_group->get_prompt_len();

// spec: In case of multiple input tokens for current sequence (prompt_len > 1),
// context_len corresponds to first token within subgroup of scheduled tokens
size_t group_context_len = group_position_id;
// The following variables are only needed for the sliced-matmul case
size_t output_seq_len = 0;
const bool echo_output = sequence_group->get_sampling_parameters().echo;
const bool sampling_is_required = sequence_group->requires_sampling();
const size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
output_seq_len = 0;
Sequence::CPtr sequence = running_sequences[seq_id];

for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
input_ids_data[token_id] = position_id < prompt_len ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
sequence->get_generated_ids()[position_id - prompt_len];

position_ids_data[token_id] = position_id;

// Check if token gathering is required for the entire sequence group
if (matmul_gathering_is_available && (sampling_is_required || echo_output)) {
// Determine if the current token should be gathered
if (echo_output ||
// Skip gathering for prompt tokens
group_position_id + token_id >= prompt_len - 1 &&
// Gather only the last scheduled token, or 1 + num_tokens_to_validate tokens for speculative decoding (SD)
// In SD, tokens_to_sample_per_sequence may exceed num_scheduled_tokens
token_id + tokens_to_sample_per_sequence >= num_scheduled_tokens) {
gather_indices_values.push_back(gathering_current_index);
output_seq_len++;
}
}
}

size_t expected_kv_cache_size = sequence_group->get_num_processed_tokens() - sequence_group->get_num_evicted_tokens();
@@ -153,6 +179,7 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_output_seq_len(matmul_gathering_is_available ? output_seq_len : num_scheduled_tokens);
}

// typical LLM parameters
@@ -168,6 +195,12 @@ class ModelRunner {
m_request.set_tensor("block_indices_begins", block_indices_begins);
m_request.set_tensor("max_context_len", max_context_len);

if (matmul_gathering_is_available) {
ov::Tensor gather_indices(ov::element::i64, {gather_indices_values.size()});
std::memcpy(gather_indices.data(), gather_indices_values.data(), gather_indices_values.size() * sizeof(int64_t));
m_request.set_tensor("sampled_tokens_indices", gather_indices);
}

// print_tensor("input_ids", input_ids);
// print_tensor("position_ids", position_ids);

13 changes: 6 additions & 7 deletions src/cpp/src/sampler.cpp
@@ -749,7 +749,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
size_t vocab_size = logits_shape[2];

SamplerOutput sampler_output;
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
@@ -758,8 +758,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t output_seq_len = sequence_group->get_output_seq_len();
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

const auto request_id = sequence_group->get_request_id();
@@ -774,13 +773,13 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
auto& stop_strings = m_stop_strings.at(request_id);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, output_seq_len, vocab_size}, (void *)sequence_group_logits_data);
size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits<size_t>::max(), updated_validation_len = 0;
if (sequence_group->requires_sampling()) {
// get the number of tokens to be validated
auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
if (num_tokens_to_process > actual_seq_len - 1) {
auto delta = num_tokens_to_process - (actual_seq_len - 1);
if (num_tokens_to_process > output_seq_len - 1) {
auto delta = num_tokens_to_process - (output_seq_len - 1);
updated_validation_len = std::max(updated_validation_len, delta);
num_tokens_to_process -= delta;
}
@@ -914,7 +913,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// accumulate a number of processed tokens
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
}

return sampler_output;
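A small worked example of the validation-length clamp above, now expressed against the unpadded output_seq_len (illustrative numbers, hypothetical helper):

#include <algorithm>
#include <cstddef>

// Sketch of the clamp above: in speculative decoding the number of draft tokens to
// validate may exceed what the gathered logits contain (output_seq_len >= 1 here,
// since the group requires sampling).
struct ValidationClamp {
    std::size_t num_tokens_to_process;
    std::size_t updated_validation_len;
};

ValidationClamp clamp_validation(std::size_t num_tokens_to_validate, std::size_t output_seq_len) {
    ValidationClamp r{num_tokens_to_validate, 0};
    if (r.num_tokens_to_process > output_seq_len - 1) {
        const std::size_t delta = r.num_tokens_to_process - (output_seq_len - 1);
        r.updated_validation_len = std::max(r.updated_validation_len, delta);
        r.num_tokens_to_process -= delta;
    }
    return r;
}
// Example: output_seq_len = 3 (1 sampled + 2 validated rows), num_tokens_to_validate = 4
// -> delta = 2, so 2 tokens are validated this step and updated_validation_len = 2.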
13 changes: 13 additions & 0 deletions src/cpp/src/sequence_group.hpp
@@ -222,6 +222,8 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
size_t m_num_validation_tokens = 0;
// flag to enable/disable token generation, e.g. in speculative decoding scenario
bool m_is_gen_paused = false;
// output seq len at current iteration
size_t m_output_seq_len = 0;

size_t m_num_streamed_tokens = 0, m_stream_window_size = 0;

@@ -394,6 +396,14 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return m_num_processed_tokens;
}

size_t get_output_seq_len() const {
return m_output_seq_len;
}

void set_output_seq_len(size_t len) {
m_output_seq_len = len;
}

/**
* Registers within the sequence group that a given amount of tokens
* has been evicted from the underlying KV cache.
@@ -436,11 +446,14 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

void schedule_tokens(size_t num_tokens) {
m_num_scheduled_tokens = num_tokens;
// Unless otherwise specified, the sampler will process all scheduled tokens.
m_output_seq_len = num_tokens;
}

void clear_scheduled_tokens() {
m_num_scheduled_tokens = 0;
m_num_validation_tokens = 0;
m_output_seq_len = 0;
}

bool is_scheduled() const {
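Taken together, the lifecycle of the new field over one step appears to be: schedule_tokens() sets the default, the model runner overrides it when gathering is available, the sampler reads it, and clear_scheduled_tokens() resets it. A tiny mock (not the real class) capturing that flow:

#include <cstddef>

// Minimal mock of the new accessors, to show the default-vs-override flow.
class SequenceGroupMock {
    std::size_t m_num_scheduled_tokens = 0;
    std::size_t m_output_seq_len = 0;
public:
    void schedule_tokens(std::size_t n) { m_num_scheduled_tokens = n; m_output_seq_len = n; }
    void set_output_seq_len(std::size_t len) { m_output_seq_len = len; }  // model-runner override
    std::size_t get_output_seq_len() const { return m_output_seq_len; }   // read by the sampler
    void clear_scheduled_tokens() { m_num_scheduled_tokens = 0; m_output_seq_len = 0; }
};
// With gathering:    schedule_tokens(128); set_output_seq_len(1);  -> sampler sees 1 logit row.
// Without gathering: schedule_tokens(128);                         -> sampler sees all 128 rows, as before.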
@@ -33,6 +33,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con

utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_gather_before_matmul_transformation(main_model);
utils::apply_gather_before_matmul_transformation(draft_model);

std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device;
bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig();
53 changes: 39 additions & 14 deletions src/cpp/src/utils.cpp
@@ -4,9 +4,11 @@
#include "utils.hpp"

#include <fstream>
#include <memory>

#include "openvino/op/add.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/slice.hpp"
@@ -230,23 +232,34 @@ ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::Token
return {new_input_ids, new_attention_mask};
}

void slice_matmul_stateful_model(std::shared_ptr<ov::Model> model) {
auto last_node = model->output(0).get_node()->input_value(0).get_node();
ov::Node* matmul = dynamic_cast<ov::op::v0::MatMul*>(last_node);
if (matmul) {
// we have found matmul, do nothing
} else if(auto add = dynamic_cast<ov::op::v1::Add*>(last_node)) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(add->input_value(0).get_node());
} else if (auto transpose = dynamic_cast<ov::op::v1::Transpose*>(last_node)) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(transpose->input_value(0).get_node());
} else if (auto multiply = dynamic_cast<ov::op::v1::Multiply*>(last_node)) {
if (auto tanh = dynamic_cast<ov::op::v0::Tanh*>(multiply->input_value(0).get_node())) {
if (auto divide = dynamic_cast<ov::op::v1::Divide*>(tanh->input_value(0).get_node())) {
matmul = dynamic_cast<ov::op::v0::MatMul*>(divide->input_value(0).get_node());
namespace {
std::shared_ptr<ov::Node> find_llm_matmul(const std::shared_ptr<ov::Model>& model) {
auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
std::shared_ptr<ov::Node> matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(last_node);
// There are several patterns for matmul we are looking for:
// Matmul -> Result
// Matmul -> Add -> Result
// Matmul -> Transpose -> Result
// MatMul -> Divide -> Tanh -> Multiply -> Result
if (!matmul) {
if(auto add = std::dynamic_pointer_cast<ov::op::v1::Add>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(add->input_value(0).get_node_shared_ptr());
} else if (auto transpose = std::dynamic_pointer_cast<ov::op::v1::Transpose>(last_node)) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(transpose->input_value(0).get_node_shared_ptr());
} else if (auto multiply = std::dynamic_pointer_cast<ov::op::v1::Multiply>(last_node)) {
if (auto tanh = std::dynamic_pointer_cast<ov::op::v0::Tanh>(multiply->input_value(0).get_node_shared_ptr())) {
if (auto divide = std::dynamic_pointer_cast<ov::op::v1::Divide>(tanh->input_value(0).get_node_shared_ptr())) {
matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(divide->input_value(0).get_node_shared_ptr());
}
}
}
}
return matmul;
}
} // namespace

void apply_slice_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
auto matmul = find_llm_matmul(model);
if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-2});
@@ -257,6 +270,19 @@ void slice_matmul_stateful_model(std::shared_ptr<ov::Model> model) {
}
}

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model) {
auto matmul = ov::genai::utils::find_llm_matmul(model);
if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) {
auto indices = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
indices->set_friendly_name("sampled_tokens_indices");
indices->output(0).get_tensor().set_names({"sampled_tokens_indices"});
auto axis = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{0});
auto gather = std::make_shared<ov::op::v8::Gather>(matmul->input_value(0), indices, axis);
matmul->input(0).replace_source_output(gather);
model->add_parameters({indices});
}
}
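At runtime the inserted Gather reduces the LM-head input from one row per scheduled token to one row per entry in sampled_tokens_indices. A standalone sketch of that row selection on a flattened [num_tokens, 1, hidden] buffer (plain C++, not the OpenVINO op itself):

#include <cstddef>
#include <cstdint>
#include <vector>

// Gather rows along axis 0 of a [num_tokens, 1, hidden] buffer stored flat,
// as the inserted Gather node does with the "sampled_tokens_indices" input.
std::vector<float> gather_rows(const std::vector<float>& hidden_states,
                               std::size_t hidden_size,
                               const std::vector<int64_t>& sampled_tokens_indices) {
    std::vector<float> out;
    out.reserve(sampled_tokens_indices.size() * hidden_size);
    for (const int64_t idx : sampled_tokens_indices) {
        const float* row = hidden_states.data() + static_cast<std::size_t>(idx) * hidden_size;
        out.insert(out.end(), row, row + hidden_size);
    }
    return out;   // conceptually [indices.size(), 1, hidden]; the MatMul then yields that many logit rows
}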

template <typename T>
void read_rt_info(std::shared_ptr<ov::Model>& model, const char* name, T& value) {
if (!model)
@@ -396,7 +422,6 @@ void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const ch
}
}
}

} // namespace utils
} // namespace genai
} // namespace ov
4 changes: 3 additions & 1 deletion src/cpp/src/utils.hpp
@@ -99,7 +99,9 @@ std::pair<ov::AnyMap, SchedulerConfig> split_scheduler_config(const ov::AnyMap&

ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& minuend, const ov::genai::TokenizedInputs& subtrahend);

void slice_matmul_stateful_model(std::shared_ptr<ov::Model> model);
void apply_slice_before_matmul_transformation(std::shared_ptr<ov::Model> model);

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model);

ov::Core singleton_core();

2 changes: 2 additions & 0 deletions src/cpp/src/utils/paged_attention_transformations.hpp
@@ -27,6 +27,8 @@ size_t get_hidden_size(const std::shared_ptr<ov::Model> model);

void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig& device_config);

void apply_gather_before_matmul_transformation(std::shared_ptr<ov::Model> model);

} // namespace utils
} // namespace genai
} // namespace ov