diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp index 699e5bfccc..3462f2566a 100644 --- a/src/cpp/src/model_runner.hpp +++ b/src/cpp/src/model_runner.hpp @@ -129,8 +129,8 @@ class ModelRunner { size_t num_running_sequences = running_sequences.size(); size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens(); size_t group_position_id = sequence_group->get_num_processed_tokens(); - auto prompt_len = sequence_group->get_prompt_len(); - size_t tokens_num_to_sample = 0; + size_t prompt_len = sequence_group->get_prompt_len(); + size_t seq_len_after_gather = 0; // spec: In case of multiple input tokens for current sequence (prompt_len > 1), // context_len corresponds to first token within subgroup of scheduled tokens @@ -148,7 +148,7 @@ class ModelRunner { if (matmul_gathering_is_required) { if (group_position_id + token_id >= prompt_len - 1) { gather_indice_values.push_back(gathering_current_index); - tokens_num_to_sample++; + seq_len_after_gather++; } } position_ids_data[token_id] = position_id; @@ -169,7 +169,7 @@ class ModelRunner { subsequence_begins_data += 1; block_indices_begins_data += 1; } - sequence_group->set_seq_len_to_sample(tokens_num_to_sample); + sequence_group->set_seq_len_to_sample(matmul_gathering_is_required ? std::min(seq_len_after_gather, num_scheduled_tokens) : num_scheduled_tokens); } // typical LLM parameters diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index a6cbab44f5..c8b68dbe88 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -757,7 +757,7 @@ SamplerOutput Sampler::sample(std::vector & sequence_groups, size_t num_running_sequences = sequence_group->num_running_seqs(); size_t actual_seq_len = sequence_group->is_matmul_sliced() ? sequence_group->get_seq_len_to_sample() : sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled - size_t padded_amount_of_processed_tokens = std::max(sequence_group->get_num_scheduled_tokens(), batch_seq_len); + size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len); const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters(); const auto request_id = sequence_group->get_request_id();