CB: removed handle_dropped() misuse (#1594)
ilya-lavrenov authored Jan 20, 2025
1 parent 57f32c7 commit 8aeb714
Showing 5 changed files with 47 additions and 40 deletions.
12 changes: 6 additions & 6 deletions src/cpp/src/block_manager.hpp
@@ -568,9 +568,9 @@ class BlockManager {
*/
const size_t free_group_partially(SequenceGroup::Ptr sequence_group, size_t num_required_blocks) {
size_t blocks_num = std::ceil(num_required_blocks / sequence_group->get_not_finished_sequences().size());
- auto running_sequences = sequence_group->get_not_finished_sequences();
- for (size_t idx = 0; idx < running_sequences.size(); ++idx) {
- auto seq_id = running_sequences[idx]->get_id();
+ auto not_finished_sequences = sequence_group->get_not_finished_sequences();
+ for (size_t idx = 0; idx < not_finished_sequences.size(); ++idx) {
+ auto seq_id = not_finished_sequences[idx]->get_id();
OPENVINO_ASSERT(m_block_table.count(seq_id) > 0, "Invalid sequence group.");
free_sequence_partially(seq_id, blocks_num);
}
@@ -579,9 +579,9 @@ class BlockManager {

const size_t free_last_block_from_each_sequence(SequenceGroup::Ptr sequence_group) {
size_t blocks_released = 0;
- auto running_sequences = sequence_group->get_not_finished_sequences();
- for (size_t idx = 0; idx < running_sequences.size(); ++idx) {
- auto seq_id = running_sequences[idx]->get_id();
+ auto not_finished_sequences = sequence_group->get_not_finished_sequences();
+ for (size_t idx = 0; idx < not_finished_sequences.size(); ++idx) {
+ auto seq_id = not_finished_sequences[idx]->get_id();
OPENVINO_ASSERT(m_block_table.count(seq_id) > 0, "Invalid sequence group.");
if (free_last_block(seq_id)) {
blocks_released++;
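An aside on the arithmetic in free_group_partially above: because num_required_blocks and the sequence count are both size_t, the division truncates before std::ceil ever sees the value, so the call rounds nothing. A true integer ceiling is usually written as in the following sketch (illustrative names, not the repository's code):

#include <cassert>
#include <cstddef>

// Integer ceiling division: the smallest q such that q * denominator >= numerator.
// Avoids the std::ceil-on-integer-division pitfall, where truncation happens
// before the value is ever converted to floating point.
static std::size_t ceil_div(std::size_t numerator, std::size_t denominator) {
    assert(denominator != 0 && "denominator must be non-zero");
    return (numerator + denominator - 1) / denominator;
}

int main() {
    assert(7 / 2 == 3);          // plain integer division truncates
    assert(ceil_div(7, 2) == 4); // true ceiling
    assert(ceil_div(6, 2) == 3); // exact divisions are unchanged
    return 0;
}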
2 changes: 1 addition & 1 deletion src/cpp/src/continuous_batching_impl.cpp
@@ -385,7 +385,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
const auto& request = *requests_iterator;
- if(request->has_finished() || request->out_of_memory() || request->handle_dropped()) {
+ if (request->has_finished() || request->handle_dropped()) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
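The loop above is the standard erase-while-iterating pattern: std::vector::erase returns an iterator to the element after the removed one, which becomes the new cursor, so nothing is skipped. A self-contained sketch of the same shape, with a hypothetical Request type and plain flags standing in for the pipeline's real state:

#include <memory>
#include <vector>

struct Request {
    bool finished = false;
    bool dropped = false;
};

// Erase finished or dropped requests in place; per-request cleanup
// (e.g. freeing block tables) would happen just before each erase.
void free_non_running(std::vector<std::shared_ptr<Request>>& requests) {
    auto it = requests.begin();
    while (it != requests.end()) {
        if ((*it)->finished || (*it)->dropped) {
            it = requests.erase(it); // erase() returns the next valid iterator
        } else {
            ++it;
        }
    }
}

int main() {
    std::vector<std::shared_ptr<Request>> requests;
    requests.push_back(std::make_shared<Request>());                     // keeps running
    requests.push_back(std::make_shared<Request>(Request{true, false})); // finished
    free_non_running(requests);
    return requests.size() == 1 ? 0 : 1;
}

Note that the commit drops the out_of_memory() test from this predicate; with the sequence_group.hpp changes further down, a group whose sequences ran out of memory now reports has_finished() on its own.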
15 changes: 9 additions & 6 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -672,6 +672,7 @@ enum StaticPipelineKind {
STATEFUL,
STATELESS
};

StaticPipelineKind str_to_pipeline(const std::string& str) {
if (str == "STATEFUL") {
return StaticPipelineKind::STATEFUL;
@@ -935,7 +936,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast<void*>(&input_ids_data)));
m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast<void*>(&position_ids_data)));

- while (sequence_group->is_running()) {
+ while (sequence_group->is_running() && !sequence_group->handle_dropped()) {
// KV Cache is full, no further generation is possible
if (position_ids_data + 1 == m_kvcache_total) {
sequence_group->set_out_of_memory();
@@ -959,12 +960,11 @@
raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);

- SamplerOutput sampler_output = m_sampler.sample(
- {sequence_group}, m_request.get_tensor("logits"));
+ SamplerOutput sampler_output = m_sampler.sample({sequence_group}, m_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);
}

- if (streamer_ptr) {
+ if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}

@@ -1441,7 +1441,7 @@ EncodedResults StatelessLLMPipeline::generate(
std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u);
attention_mask_data[m_kvcache_desc.total_size - 1] = 1u;

- while (sequence_group->is_running()) {
+ while (sequence_group->is_running() && !sequence_group->handle_dropped()) {
sequence_group->schedule_tokens(1);
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);
@@ -1460,6 +1460,9 @@
{sequence_group}, m_kvcache_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);

+ if (sequence_group->handle_dropped())
+ break;

// NB: KV-cache is full, further generation is impossible
if (m_kvcache_desc.num_stored_tokens == m_kvcache_desc.total_size) {
sequence_group->set_out_of_memory();
@@ -1482,7 +1485,7 @@
}
}

- if (streamer_ptr) {
+ if (streamer_ptr) { // push streamer's cache
streamer_ptr->end();
}

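Both static-pipeline loops now treat a user drop request as a first-class exit condition: it is tested in the loop header and again immediately after tokens are streamed, since the streaming callback is where a caller typically cancels. A minimal sketch of that control flow under assumed names (the real pipeline's inference and sampling are elided):

#include <atomic>
#include <cstdio>

// Hypothetical drop flag; in a real pipeline it would be set through a
// generation handle or streamer callback, possibly from another thread.
std::atomic<bool> g_dropped{false};

void stream_token(int token) {
    std::printf("token %d\n", token);
    if (token == 3) {
        g_dropped = true; // the consumer cancels mid-stream
    }
}

int main() {
    int budget = 10; // stands in for "sequence group still running"
    int token = 0;
    // The drop flag is checked in the loop condition *and* right after
    // streaming, so a cancellation raised by the callback takes effect
    // before the next forward pass is paid for.
    while (budget-- > 0 && !g_dropped) {
        ++token; // stands in for infer + sample
        stream_token(token);
        if (g_dropped) {
            break;
        }
    }
    std::printf("stopped after %d tokens\n", token);
    return 0;
}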
6 changes: 3 additions & 3 deletions src/cpp/src/lm_encoding.cpp
@@ -101,7 +101,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
auto free_non_running_requests = [&streamer_ptr, &generations, &active_sequence_groups]() {
auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(),
[](SequenceGroup::Ptr sg) -> bool {
- return sg->has_finished() || sg->out_of_memory() || sg->handle_dropped();
+ return sg->has_finished() || sg->handle_dropped();
});
active_sequence_groups.erase(removed_it, active_sequence_groups.end());
};
@@ -152,7 +152,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
beam_offets.insert({sequence_groups.at(i)->get_request_id(), i});

SamplerOutput sampler_output = sampler.sample(sequence_groups, logits);
- free_non_running_requests();
+ free_non_running_requests(); // handle sampler output

// "Generation" phase

@@ -239,7 +239,7 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
raw_perf_counters.m_batch_sizes.emplace_back(current_batch_size);

sampler_output = sampler.sample(active_sequence_groups, m_llm.get_tensor("logits"));
- free_non_running_requests();
+ free_non_running_requests(); // handle sampler output
}

stream_generated_tokens();
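The pruning here uses the remove-erase idiom instead of an explicit iterator loop: std::remove_if compacts the surviving elements to the front and returns the new logical end, and one erase call trims the tail. A sketch with a hypothetical stub type in place of SequenceGroup:

#include <algorithm>
#include <memory>
#include <vector>

struct GroupStub {
    bool finished = false;
    bool dropped = false;
};
using GroupPtr = std::shared_ptr<GroupStub>;

// Remove-erase idiom: remove_if moves survivors to the front and returns
// an iterator to the new logical end; erase() then drops the dead tail.
void free_non_running(std::vector<GroupPtr>& groups) {
    auto removed_it = std::remove_if(groups.begin(), groups.end(),
        [](const GroupPtr& sg) { return sg->finished || sg->dropped; });
    groups.erase(removed_it, groups.end());
}

int main() {
    std::vector<GroupPtr> groups = {
        std::make_shared<GroupStub>(GroupStub{false, false}), // survives
        std::make_shared<GroupStub>(GroupStub{true, false}),  // finished
        std::make_shared<GroupStub>(GroupStub{false, true}),  // dropped
    };
    free_non_running(groups);
    return groups.size() == 1 ? 0 : 1;
}

Unlike the loop in continuous_batching_impl.cpp there is no per-element cleanup to run before removal, so the bulk compaction is the simpler fit here.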
52 changes: 28 additions & 24 deletions src/cpp/src/sequence_group.hpp
@@ -233,6 +233,15 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
m_block_size(block_size),
m_generation_stream(GenerationStream::create()) { }

+ bool out_of_memory() const {
+ for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
+ if (m_sequences[seq_id]->out_of_memory()) {
+ return true;
+ }
+ }
+ return false;
+ }

public:
using Ptr = std::shared_ptr<SequenceGroup>;
using CPtr = std::shared_ptr<const SequenceGroup>;
@@ -294,22 +303,18 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return m_sequences.size();
}

- size_t num_finished_seqs() const {
- return std::count_if(m_sequences.begin(), m_sequences.end(), [this] (Sequence::CPtr seq) {
- return seq->has_finished() || seq->out_of_memory() || handle_dropped();
- });
- }

size_t num_running_seqs() const {
- return num_total_seqs() - num_finished_seqs();
+ return std::count_if(m_sequences.begin(), m_sequences.end(), [] (Sequence::CPtr seq) {
+ return seq->is_running();
+ });
}

bool has_finished() const {
- return num_running_seqs() == 0;
+ return !is_running();
}

bool is_running() const {
- return !has_finished();
+ return num_running_seqs() > 0;
}

const std::vector<Sequence::Ptr>& get_sequences() const {
@@ -336,14 +341,21 @@
return *it;
}

+ // must be used only after the sequence group generation loop has finished (either by length or OOM)
+ // or stopped / cancelled via streamer / generation_stream->drop()
std::vector<Sequence::CPtr> get_finished_sequences() const {
std::vector<Sequence::CPtr> finished_seqs;
finished_seqs.reserve(num_total_seqs());

for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory() || handle_dropped()) {
finished_seqs.push_back(m_sequences[seq_id]);
}
}

+ OPENVINO_ASSERT(finished_seqs.size() == num_total_seqs(), "Internal error: get_finished_sequences() must be called when all sequences are "
+ "either finished / ignored by OOM or dropped via GenerationStream::drop()");

std::sort(finished_seqs.begin(), finished_seqs.end(), [=] (Sequence::CPtr s1, Sequence::CPtr s2) -> bool {
bool is_beam_search = m_sampling_params.is_beam_search();
const float score_1 = is_beam_search ? s1->get_beam_search_score(m_sampling_params) : s1->get_cumulative_log_prob();
Expand All @@ -354,21 +366,22 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return finished_seqs;
}

- std::vector<Sequence::Ptr> get_running_sequences() {
+ // returns running or waiting sequences
+ std::vector<Sequence::Ptr> get_not_finished_sequences() {
std::vector<Sequence::Ptr> running_seqs;
for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
- if (m_sequences[seq_id]->is_running()) {
+ if (!m_sequences[seq_id]->has_finished()) {
running_seqs.emplace_back(m_sequences[seq_id]);
}
}

return running_seqs;
}

- std::vector<Sequence::Ptr> get_not_finished_sequences() {
+ std::vector<Sequence::Ptr> get_running_sequences() {
std::vector<Sequence::Ptr> running_seqs;
for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
- if (!m_sequences[seq_id]->has_finished()) {
+ if (m_sequences[seq_id]->is_running()) {
running_seqs.emplace_back(m_sequences[seq_id]);
}
}
@@ -559,15 +572,6 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
}
}

- bool out_of_memory() const {
- for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
- if (m_sequences[seq_id]->out_of_memory()) {
- return true;
- }
- }
- return false;
- }

bool is_waiting() const {
for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) {
if (m_sequences[seq_id]->is_waiting()) {
@@ -634,7 +638,7 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
}
// For beam search streaming is not available, so we notify only upon finishing
if (m_sampling_params.is_beam_search()) {
- if (has_finished() || out_of_memory()) {
+ if (has_finished()) {
push_outputs();
}
} else if (m_sampling_params.is_greedy_decoding() || m_sampling_params.is_multinomial()) {
@@ -656,7 +660,7 @@
size_t num_output_token_to_push = generated_len - m_num_streamed_tokens - m_stream_window_size;
push_partial_outputs(num_output_token_to_push);
m_num_streamed_tokens += (num_output_token_to_push);
- } else if (has_finished() || out_of_memory()) {
+ } else if (has_finished()) {
push_outputs();
}
}
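Taken together, the sequence_group.hpp changes invert the old derivation: num_running_seqs() now counts sequences that are actually running, is_running() asks whether that count is positive, and has_finished() is simply its negation. An out-of-memory or dropped sequence therefore reads as finished without callers testing out_of_memory() or handle_dropped() themselves, which is also why out_of_memory() could move out of the public section. A stripped-down sketch of those relationships, with a hypothetical State enum in place of the real Sequence machinery:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

enum class State { RUNNING, FINISHED, OUT_OF_MEMORY, DROPPED };

struct SequenceGroupStub {
    std::vector<State> seqs;

    // Count only sequences that are actually running; OOM and dropped
    // sequences fall out of the count with no special-casing.
    std::size_t num_running_seqs() const {
        return std::count_if(seqs.begin(), seqs.end(),
                             [](State s) { return s == State::RUNNING; });
    }
    bool is_running() const { return num_running_seqs() > 0; }
    bool has_finished() const { return !is_running(); }
};

int main() {
    SequenceGroupStub group{{State::RUNNING, State::OUT_OF_MEMORY}};
    assert(group.is_running());   // one sequence is still live

    group.seqs[0] = State::DROPPED;
    assert(group.has_finished()); // OOM and dropped both read as "not running"
    return 0;
}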
