@@ -136,9 +136,9 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
     // attention_mask [1, 1, 1, 0]
     auto input_ids_data = input_ids_tensor.data<int32_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
-    std::fill(input_ids_data + init_ids.size(),
-              input_ids_data + input_ids_tensor.get_size(),
-              static_cast<int32_t>(pad_token));
+    // std::fill(input_ids_data + init_ids.size(),
+    //           input_ids_data + input_ids_tensor.get_size(),
+    //           static_cast<int32_t>(pad_token));
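+    // NB: padding to the static input length is presumably no longer needed,
+    // since generate() now reshapes the decoder to exactly init_ids.size() tokens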
 
     auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
     std::fill_n(attention_mask_data, init_ids.size(), 1u);
@@ -210,13 +210,13 @@ void zero_past_key_values(ov::InferRequest& request) {
     }
 }
 
-void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
+void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
     // NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
     // Mask should be inverted for decoder_with_past
     auto attention_mask = decoder_with_past.get_tensor("attention_mask");
     auto* attention_mask_ptr = attention_mask.data<ov::float16>();
-    std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
-    std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+    std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+    std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
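+    // NB: the inverted-mask prefix now spans the actual number of init tokens
+    // instead of the previously hard-coded 3 positions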
     attention_mask_ptr[attention_mask.get_size() - 2] = 0;
     attention_mask_ptr[attention_mask.get_size() - 1] = 1;
     // NB: Zero past_key_values.*.decoder.value tensors
@@ -318,7 +318,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
         return {false, output_tokens};
     }
 
-    prepare_decoder_with_past(models.decoder_with_past, models.decoder);
+    prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());
 
     for (size_t i = 0; i < max_new_tokens - 1; i++) {
         auto output_token = decode_with_past(models.decoder_with_past,
@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
             preprocessor.input("attention_mask").preprocess().convert_element_type();
         } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
             preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-            preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);  // ()
+            preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
         } else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
             preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
             preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -563,7 +563,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
-    reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
+    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);  // for detect_language()
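+    // NB: generate() below recompiles the decoder once the real number of
+    // input tokens is known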
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +577,12 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     compiled_model = core.compile_model(encoder_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
     m_models.encoder = compiled_model.create_infer_request();
+
+    m_decoder_model = decoder_model;  // for reshape in generate() when we get number of input tokens
     compiled_model = core.compile_model(decoder_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
     m_models.decoder = compiled_model.create_infer_request();
+
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
     m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
+            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
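+            // NB: the decoder must still be in its initial single-token shape here,
+            // as set in the constructor for detect_language()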
             init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
+
+            // Reshape decoder model for the number of input tokens
+            ov::Core core = utils::singleton_core();
+            reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
+            m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
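+            // NB: this recompilation runs only once per generate() call, since
+            // init_ids is prepared just once for the whole input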
         }
 
         auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
0 commit comments