Skip to content

Commit dfba871

Browse files
committed
Fix decoder inputs for static pipeline
1 parent 34dc469 commit dfba871

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/cpp/src/whisper_pipeline_static.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
136136
// attention_mask [1, 1, 1, 0]
137137
auto input_ids_data = input_ids_tensor.data<int32_t>();
138138
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
139-
std::fill(input_ids_data + init_ids.size(),
140-
input_ids_data + input_ids_tensor.get_size(),
141-
static_cast<int32_t>(pad_token));
139+
// std::fill(input_ids_data + init_ids.size(),
140+
// input_ids_data + input_ids_tensor.get_size(),
141+
// static_cast<int32_t>(pad_token));
142142

143143
auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
144144
std::fill_n(attention_mask_data, init_ids.size(), 1u);
@@ -210,13 +210,13 @@ void zero_past_key_values(ov::InferRequest& request) {
210210
}
211211
}
212212

213-
void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
213+
void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
214214
// NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
215215
// Mask should be inverted for decoder_with_past
216216
auto attention_mask = decoder_with_past.get_tensor("attention_mask");
217217
auto* attention_mask_ptr = attention_mask.data<ov::float16>();
218-
std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
219-
std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
218+
std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
219+
std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
220220
attention_mask_ptr[attention_mask.get_size() - 2] = 0;
221221
attention_mask_ptr[attention_mask.get_size() - 1] = 1;
222222
// NB: Zero past_key_values.*.decoder.value tensors
@@ -318,7 +318,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
318318
return {false, output_tokens};
319319
}
320320

321-
prepare_decoder_with_past(models.decoder_with_past, models.decoder);
321+
prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());
322322

323323
for (size_t i = 0; i < max_new_tokens - 1; i++) {
324324
auto output_token = decode_with_past(models.decoder_with_past,
@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
489489
preprocessor.input("attention_mask").preprocess().convert_element_type();
490490
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
491491
preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
492-
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32); // ()
492+
preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
493493
} else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
494494
preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
495495
preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -563,7 +563,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
563563
reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
564564

565565
auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
566-
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
566+
reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape); // for detect_language()
567567
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
568568

569569
// Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +577,12 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
577577
compiled_model = core.compile_model(encoder_model, "NPU");
578578
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
579579
m_models.encoder = compiled_model.create_infer_request();
580+
581+
m_decoder_model = decoder_model; // for reshape in generate() when we get number of input tokens
580582
compiled_model = core.compile_model(decoder_model, "NPU");
581583
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
582584
m_models.decoder = compiled_model.create_infer_request();
585+
583586
compiled_model = core.compile_model(decoder_with_past_model, "NPU");
584587
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
585588
m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
654657

655658
// prepare init_ids just once for whole input
656659
if (init_ids.empty()) {
660+
OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
657661
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
662+
663+
// Reshape decoder model for the number of input tokens
664+
ov::Core core = utils::singleton_core();
665+
reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
666+
m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
658667
}
659668

660669
auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,

src/cpp/src/whisper_pipeline_static.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
2525

2626
private:
2727
WhisperInitializedModels m_models;
28+
std::shared_ptr<ov::Model> m_decoder_model;
2829
};
2930

3031
} // namespace genai

0 commit comments

Comments
 (0)