From 7d78ca7c439bd653aed3b0ff2b868fc20c59d8f8 Mon Sep 17 00:00:00 2001 From: Ekaterina Shiryaeva Date: Mon, 11 Nov 2024 14:36:47 +0000 Subject: [PATCH] Address review comments --- src/cpp/src/whisper_pipeline_static.cpp | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp index ada19ae74c..792c90f865 100644 --- a/src/cpp/src/whisper_pipeline_static.cpp +++ b/src/cpp/src/whisper_pipeline_static.cpp @@ -398,16 +398,11 @@ void add_attention_mask_input(std::shared_ptr model) { } -uint32_t get_encoder_hidden_state_size(const std::shared_ptr& encoder) { - for (auto output : encoder->outputs()) { - if (output.get_any_name() == "last_hidden_state") { - return output.get_partial_shape()[1].get_length(); - } - } - return 0; +ov::PartialShape get_encoder_hidden_state_shape(const std::shared_ptr& encoder) { + return encoder->output("last_hidden_state").get_partial_shape(); } -void reshape_to_static(std::shared_ptr model, const uint32_t input_size, const uint32_t kvcache_size, const uint32_t lhstate_size) { +void reshape_to_static(std::shared_ptr model, const uint32_t input_size, const uint32_t kvcache_size, const ov::PartialShape& lhstate_shape) { std::map new_shapes; for (auto input : model->inputs()) { const auto& input_name = input.get_any_name(); @@ -424,14 +419,15 @@ void reshape_to_static(std::shared_ptr model, const uint32_t input_si const auto& partial_shape = input.get_partial_shape(); new_shape = partial_shape; new_shape[0] = 1; // batch_dim - new_shape[1] = lhstate_size; // from encoder output{'last_hidden_state'} + new_shape[1] = lhstate_shape[1]; // from encoder output{'last_hidden_state'} + new_shape[2] = lhstate_shape[2]; } else if (input_name.find("past_key_values") != std::string::npos) { const auto& partial_shape = input.get_partial_shape(); new_shape = partial_shape; new_shape[0] = 1; // Use batch dim here new_shape[2] = input_name.find(".decoder") != std::string::npos ? kvcache_size - input_size // kv_size for decoder - : lhstate_size; // hidden state size for encoder + : lhstate_shape[1]; // hidden state size for encoder } new_shapes.emplace(input_name, new_shape); } @@ -549,9 +545,9 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys reshape_to_static_encoder(encoder_model); - size_t last_hidden_state_size = get_encoder_hidden_state_size(encoder_model); - reshape_to_static(decoder_model, 4, 4, last_hidden_state_size); - reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_size); + auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model); + reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape); + reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape); // Replace KV-tensors for the entire cache to tensors only for new token decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);