diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index fc321f822e..792c90f865 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -397,7 +397,12 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     pm.run_passes(model);
 }
 
-void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
+
+ov::PartialShape get_encoder_hidden_state_shape(const std::shared_ptr<ov::Model>& encoder) {
+    return encoder->output("last_hidden_state").get_partial_shape();
+}
+
+void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const ov::PartialShape& lhstate_shape) {
     std::map<std::string, ov::PartialShape> new_shapes;
     for (auto input : model->inputs()) {
         const auto& input_name = input.get_any_name();
@@ -414,14 +419,15 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1;    // batch_dim
-            new_shape[1] = 1500; // FIXME: is it got from encoder output{'last_hidden_state'}
+            new_shape[1] = lhstate_shape[1]; // from encoder output{'last_hidden_state'}
+            new_shape[2] = lhstate_shape[2];
         } else if (input_name.find("past_key_values") != std::string::npos) {
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1; // Use batch dim here
             new_shape[2] = input_name.find(".decoder") != std::string::npos ?
                 kvcache_size - input_size // kv_size for decoder
-                : 1500; // for encoder
+                : lhstate_shape[1]; // hidden state size for encoder
         }
         new_shapes.emplace(input_name, new_shape);
     }
@@ -538,8 +544,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     size_t max_sequence_length = 448;
 
     reshape_to_static_encoder(encoder_model);
-    reshape_to_static(decoder_model, 4, 4);
-    reshape_to_static(decoder_with_past_model, 1, max_sequence_length);
+
+    auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
+    reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
+    reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
     decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);