diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index fc321f822e..ada19ae74c 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -397,7 +397,17 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     pm.run_passes(model);
 }
 
-void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
+
+uint32_t get_encoder_hidden_state_size(const std::shared_ptr<ov::Model>& encoder) {
+    for (auto output : encoder->outputs()) {
+        if (output.get_any_name() == "last_hidden_state") {
+            return output.get_partial_shape()[1].get_length();
+        }
+    }
+    return 0;
+}
+
+void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const uint32_t lhstate_size) {
     std::map<std::string, ov::PartialShape> new_shapes;
     for (auto input : model->inputs()) {
         const auto& input_name = input.get_any_name();
@@ -414,14 +424,14 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1; // batch_dim
-            new_shape[1] = 1500; // FIXME: is it got from encoder output{'last_hidden_state'}
+            new_shape[1] = lhstate_size; // from encoder output{'last_hidden_state'}
         } else if (input_name.find("past_key_values") != std::string::npos) {
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1; // Use batch dim here
             new_shape[2] = input_name.find(".decoder") != std::string::npos ?
                             kvcache_size - input_size // kv_size for decoder
-                            : 1500; // for encoder
+                            : lhstate_size; // hidden state size for encoder
         }
         new_shapes.emplace(input_name, new_shape);
     }
@@ -538,8 +548,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     size_t max_sequence_length = 448;
 
     reshape_to_static_encoder(encoder_model);
-    reshape_to_static(decoder_model, 4, 4);
-    reshape_to_static(decoder_with_past_model, 1, max_sequence_length);
+
+    size_t last_hidden_state_size = get_encoder_hidden_state_size(encoder_model);
+    reshape_to_static(decoder_model, 4, 4, last_hidden_state_size);
+    reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_size);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
     decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);