From dfba8711c2c6075f9a1d11307e2659232bd46e5b Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Tue, 31 Dec 2024 15:20:49 +0000
Subject: [PATCH 1/5] Fix decoder inputs for static pipeline

---
 src/cpp/src/whisper_pipeline_static.cpp | 27 ++++++++++++++++++---------
 src/cpp/src/whisper_pipeline_static.hpp |  1 +
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 01fe882187..c47f1a8af5 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -136,9 +136,9 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
     // attention_mask [1, 1, 1, 0]
     auto input_ids_data = input_ids_tensor.data<int32_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
-    std::fill(input_ids_data + init_ids.size(),
-              input_ids_data + input_ids_tensor.get_size(),
-              static_cast<int32_t>(pad_token));
+    // std::fill(input_ids_data + init_ids.size(),
+    //           input_ids_data + input_ids_tensor.get_size(),
+    //           static_cast<int32_t>(pad_token));
 
     auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
     std::fill_n(attention_mask_data, init_ids.size(), 1u);
@@ -210,13 +210,13 @@ void zero_past_key_values(ov::InferRequest& request) {
     }
 }
 
-void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
+void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
     // NB: Prepare attention mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
     // Mask should be inverted for decoder_with_past
     auto attention_mask = decoder_with_past.get_tensor("attention_mask");
     auto* attention_mask_ptr = attention_mask.data<ov::float16>();
-    std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
-    std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+    std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+    std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
     attention_mask_ptr[attention_mask.get_size() - 2] = 0;
     attention_mask_ptr[attention_mask.get_size() - 1] = 1;
     // NB: Zero past_key_values.*.decoder.value tensors
@@ -318,7 +318,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
         return {false, output_tokens};
     }
 
-    prepare_decoder_with_past(models.decoder_with_past, models.decoder);
+    prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());
 
     for (size_t i = 0; i < max_new_tokens - 1; i++) {
         auto output_token = decode_with_past(models.decoder_with_past,
@@ -489,7 +489,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
             preprocessor.input("attention_mask").preprocess().convert_element_type();
         } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
             preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-            preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);  // ()
+            preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
         } else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
             preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
             preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -563,7 +563,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
-    reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
+    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);  // for detect_language()
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -577,9 +577,12 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     compiled_model = core.compile_model(encoder_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
     m_models.encoder = compiled_model.create_infer_request();
+
+    m_decoder_model = decoder_model;  // for reshape in generate() when we get number of input tokens
     compiled_model = core.compile_model(decoder_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
     m_models.decoder = compiled_model.create_infer_request();
+
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
     m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -654,7 +657,13 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
+            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
             init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
+
+            // Reshape decoder model for the number of input tokens
+            ov::Core core = utils::singleton_core();
+            reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
+            m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
         }
 
         auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index cbd57beb18..35395f5693 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -25,6 +25,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
 
 private:
     WhisperInitializedModels m_models;
+    std::shared_ptr<ov::Model> m_decoder_model;
 };
 
 }  // namespace genai

From 7e8bbfeeb1b84847b5c62fd106891ccd67f2d349 Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Tue, 7 Jan 2025 11:10:41 +0000
Subject: [PATCH 2/5] Whisper: address comments

---
 src/cpp/src/whisper_pipeline_static.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index c47f1a8af5..a0c7c3a68b 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -136,9 +136,6 @@ void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
     // attention_mask [1, 1, 1, 0]
     auto input_ids_data = input_ids_tensor.data<int32_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
-    // std::fill(input_ids_data + init_ids.size(),
-    //           input_ids_data + input_ids_tensor.get_size(),
-    //           static_cast<int32_t>(pad_token));
 
     auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
     std::fill_n(attention_mask_data, init_ids.size(), 1u);
@@ -489,7 +486,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
         preprocessor.input("attention_mask").preprocess().convert_element_type();
     } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
         preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-        preprocessor.input("encoder_hidden_states").preprocess().convert_element_type();
+        preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
     } else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
         preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
         preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();

From ac1655834f7340bafa6bd05956fe29c3da1a07de Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Thu, 9 Jan 2025 15:54:22 +0000
Subject: [PATCH 3/5] Added decoder cache, removed decoder's attn mask

---
 src/cpp/src/whisper_pipeline_static.cpp | 93 ++++++++++---------------
 src/cpp/src/whisper_pipeline_static.hpp | 13 +++-
 2 files changed, 48 insertions(+), 58 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index a0c7c3a68b..3e09099c28 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -121,25 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
     }
 }
 
-void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
-                                          const std::vector<int32_t>& init_ids,
-                                          const int64_t pad_token) {
+void set_decoder_input_ids(ov::InferRequest& decoder,
+                           const std::vector<int32_t>& init_ids) {
     auto input_ids_tensor = decoder.get_tensor("input_ids");
-    auto attention_mask_tensor = decoder.get_tensor("attention_mask");
-
     const size_t seq_length = input_ids_tensor.get_shape()[1];
     OPENVINO_ASSERT(seq_length >= init_ids.size());
 
-    // pad right
-    // input_ids [token, token, token, pad_token]
-    // attention_mask [1, 1, 1, 0]
     auto input_ids_data = input_ids_tensor.data<int32_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
-
-    auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
-    std::fill_n(attention_mask_data, init_ids.size(), 1u);
-    std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
 }
 
 int64_t decode(ov::Tensor& encoder_hidden_state,
@@ -151,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
                const bool return_timestamps = false) {
     // NB: Fill decoder inputs
     encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
-    set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
+    set_decoder_input_ids(decoder, init_ids);
 
     ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
@@ -230,7 +220,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
-    set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
+    set_decoder_input_ids(decoder, init_ids);
 
     const auto infer_start = std::chrono::steady_clock::now();
     decoder.infer();
@@ -350,36 +340,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
     return false;
 }
 
-void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
-    using namespace ov::pass::pattern;
-    using namespace ov::op;
-    class AttentionMaskInput : public ov::pass::MatcherPass {
-    public:
-        OPENVINO_RTTI("AttentionMaskInput");
-
-        AttentionMaskInput(std::shared_ptr<ov::Model> model) {
-            auto range = wrap_type<v4::Range>();
-            auto convert = wrap_type<v0::Convert>({range});
-            auto convert1 = wrap_type<v0::Convert>({convert});
-            auto greater = wrap_type<v1::Greater>({convert1, any_input()});
-            auto convert2 = wrap_type<v0::Convert>({greater});
-
-            register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
-                auto node = m.get_match_root();
-                auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
-                attention_mask->get_output_tensor(0).set_names({"attention_mask"});
-                model->add_parameters({attention_mask});
-                ov::replace_node(node, attention_mask);
-                return false;
-            });
-        }
-    };
-
-    ov::pass::Manager pm;
-    pm.register_pass<AttentionMaskInput>(model);
-    pm.run_passes(model);
-}
-
 void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     using namespace ov::pass::pattern;
     using namespace ov::op;
@@ -464,6 +424,22 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
     model->reshape(new_shapes);
 }
 
+void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
+    std::map<std::string, ov::PartialShape> new_shapes;
+    for (auto input : model->inputs()) {
+        const auto& input_name = input.get_any_name();
+        ov::PartialShape new_shape;
+        if (input_name.find("input_ids") != std::string::npos) {
+            new_shape = ov::PartialShape({1, input_size});
+        } else {
+            new_shape = input.get_partial_shape();
+        }
+        new_shapes.emplace(input_name, new_shape);
+    }
+
+    model->reshape(new_shapes);
+}
+
 void preprocess_encoder(std::shared_ptr<ov::Model> model) {
     ov::preprocess::PrePostProcessor preprocessor(model);
@@ -538,6 +514,17 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
 
 namespace ov {
 namespace genai {
 
+ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
+    if (m_cache.find(input_ids_size) == m_cache.cend()) {
+        reshape_input_ids(m_decoder_model, input_ids_size);
+
+        ov::Core core = utils::singleton_core();
+        m_cache.insert({input_ids_size, core.compile_model(m_decoder_model, "NPU")});
+    }
+
+    return m_cache.at(input_ids_size);
+}
+
 WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& model_path,
@@ -581,9 +568,9 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
     m_models.encoder = compiled_model.create_infer_request();
 
-    m_decoder_model = decoder_model;  // for reshape in generate() when we get number of input tokens
-    compiled_model = core.compile_model(decoder_model, "NPU");
+    m_decoder_cache = DecoderCache(decoder_model);
+    compiled_model = m_decoder_cache.get_model(1);
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
     m_models.decoder = compiled_model.create_infer_request();
 
@@ -660,13 +647,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
-            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_shape().back() == 1);
+            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_size() == 1);
             init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
 
-            // Reshape decoder model for the number of input tokens
-            ov::Core core = utils::singleton_core();
-            reshape_to_static(m_decoder_model, init_ids.size(), init_ids.size(), hidden_state_tensor.get_shape());
-            m_models.decoder = core.compile_model(m_decoder_model, "NPU").create_infer_request();
+            // Get decoder with size of input_ids
+            m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
         }
 
         auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index 35395f5693..9913760792 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -17,6 +17,17 @@ namespace genai {
 
+class DecoderCache {
+public:
+    DecoderCache() {}
+    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
+
+    ov::CompiledModel get_model(uint8_t input_ids_size);
+private:
+    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
+    std::shared_ptr<ov::Model> m_decoder_model;
+};
+
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
 public:
     StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
@@ -25,7 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
 
 private:
     WhisperInitializedModels m_models;
-    std::shared_ptr<ov::Model> m_decoder_model;
+    DecoderCache m_decoder_cache;
 };
 
 }  // namespace genai

From 614e6d9b7802d89b00c480e20b27c0f5b7913b04 Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Thu, 9 Jan 2025 19:08:11 +0000
Subject: [PATCH 4/5] Address comments

---
 src/cpp/src/whisper_pipeline_static.cpp | 33 +++++++++----------
 src/cpp/src/whisper_pipeline_static.hpp |  7 ++++--
 2 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 3e09099c28..5e0f2ef20e 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -425,19 +425,7 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
 }
 
 void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
-    std::map<std::string, ov::PartialShape> new_shapes;
-    for (auto input : model->inputs()) {
-        const auto& input_name = input.get_any_name();
-        ov::PartialShape new_shape;
-        if (input_name.find("input_ids") != std::string::npos) {
-            new_shape = ov::PartialShape({1, input_size});
-        } else {
-            new_shape = input.get_partial_shape();
-        }
-        new_shapes.emplace(input_name, new_shape);
-    }
-
-    model->reshape(new_shapes);
+    model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
 }
 
 void preprocess_encoder(std::shared_ptr<ov::Model> model) {
@@ -516,10 +504,16 @@ namespace genai {
 
 ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
     if (m_cache.find(input_ids_size) == m_cache.cend()) {
-        reshape_input_ids(m_decoder_model, input_ids_size);
+        if (m_decoder_model->is_dynamic()) {  // model is dynamic, reshaping it to static
+            reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
+        } else {
+            reshape_input_ids(m_decoder_model, input_ids_size);
+        }
 
         ov::Core core = utils::singleton_core();
-        m_cache.insert({input_ids_size, core.compile_model(m_decoder_model, "NPU")});
+        ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
+        m_cache.emplace(input_ids_size, compiled_model);
     }
 
     return m_cache.at(input_ids_size);
@@ -552,7 +546,6 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
-    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);  // for detect_language()
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -568,11 +561,9 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
     m_models.encoder = compiled_model.create_infer_request();
 
-    m_decoder_cache = DecoderCache(decoder_model);
-    compiled_model = m_decoder_cache.get_model(1);
-    ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
-    m_models.decoder = compiled_model.create_infer_request();
+    // Will compile decoder model when it's needed
+    m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);
 
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
     m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -647,7 +638,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
-            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_size() == 1);
+            m_models.decoder = m_decoder_cache.get_model(1).create_infer_request();  // for detect_language()
             init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
 
             // Get decoder with size of input_ids
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index 9913760792..c0e4aa8220 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -17,13 +17,16 @@ namespace genai {
 
 class DecoderCache {
 public:
-    DecoderCache() {}
-    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
+    DecoderCache() = default;
+    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
+        : m_decoder_model(model)
+        , m_lhs_shape(shape) {}
 
     ov::CompiledModel get_model(uint8_t input_ids_size);
 private:
     std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
     std::shared_ptr<ov::Model> m_decoder_model;
+    ov::PartialShape m_lhs_shape;
 };
 
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {

From 9e91e32e85891c1bf14d214f978bb1eb25a1dc2a Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Fri, 10 Jan 2025 10:40:56 +0000
Subject: [PATCH 5/5] Address comments

---
 src/cpp/src/whisper_pipeline_static.cpp | 26 ++++++++++++-------------
 src/cpp/src/whisper_pipeline_static.hpp |  9 +++------
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 5e0f2ef20e..91de478b1c 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -214,13 +214,15 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
 };
 
 int64_t detect_language(ov::Tensor& encoder_hidden_state,
-                        ov::InferRequest decoder,
+                        ov::genai::DecoderCache& decoder_cache,
                         const ov::genai::WhisperGenerationConfig& config,
                         ov::genai::RawPerfMetrics& raw_metrics) {
+    auto decoder = decoder_cache.get_model(1);
+
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
     set_decoder_input_ids(decoder, init_ids);
 
     const auto infer_start = std::chrono::steady_clock::now();
     decoder.infer();
@@ -246,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
 }
 
 std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
-                                      ov::InferRequest& decoder,
+                                      ov::genai::DecoderCache& decoder_cache,
                                       const ov::genai::WhisperGenerationConfig& config,
                                       const bool return_timestamps,
                                       ov::genai::RawPerfMetrics& raw_metrics) {
@@ -266,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
             language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
         }
     } else {
-        language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
+        language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
     }
 
     int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
@@ -502,18 +504,14 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
 
 namespace ov {
 namespace genai {
 
-ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
+ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
     if (m_cache.find(input_ids_size) == m_cache.cend()) {
-        if (m_decoder_model->is_dynamic()) {  // model is dynamic, reshaping it to static
-            reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
-        } else {
-            reshape_input_ids(m_decoder_model, input_ids_size);
-        }
+        reshape_input_ids(m_decoder_model, input_ids_size);
 
         ov::Core core = utils::singleton_core();
         ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
         ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
-        m_cache.emplace(input_ids_size, compiled_model);
+        m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
     }
 
     return m_cache.at(input_ids_size);
@@ -546,6 +544,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
+    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -561,7 +560,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     m_models.encoder = compiled_model.create_infer_request();
 
     // Will compile decoder model when it's needed
-    m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);
+    m_decoder_cache = DecoderCache(decoder_model);
 
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
     m_models.decoder_with_past = compiled_model.create_infer_request();
@@ -638,11 +637,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
-            m_models.decoder = m_decoder_cache.get_model(1).create_infer_request();  // for detect_language()
-            init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
+            init_ids = prepare_init_ids(hidden_state_tensor, m_decoder_cache, config, return_timestamps, raw_metrics);
 
             // Get decoder with size of input_ids
-            m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
+            m_models.decoder = m_decoder_cache.get_model(init_ids.size());
         }
 
         auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index c0e4aa8220..b0618452d4 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -18,15 +18,12 @@ namespace genai {
 class DecoderCache {
 public:
     DecoderCache() = default;
-    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
-        : m_decoder_model(model)
-        , m_lhs_shape(shape) {}
+    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
 
-    ov::CompiledModel get_model(uint8_t input_ids_size);
+    ov::InferRequest get_model(uint8_t input_ids_size);
 private:
-    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
+    std::unordered_map<uint8_t, ov::InferRequest> m_cache;
     std::shared_ptr<ov::Model> m_decoder_model;
-    ov::PartialShape m_lhs_shape;
 };
 
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
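
Note on the final design: across patches 3-5 the series converges on a small compile-on-demand cache keyed by prompt length — the decoder is compiled for 1 input token only when language detection is actually needed, and for the full init_ids length once it is known, with each compiled variant reused on subsequent chunks. The standalone sketch below illustrates that pattern under stated assumptions: it targets the public OpenVINO C++ runtime API; the model path, the main() scaffolding, and the locally owned ov::Core are hypothetical (the real pipeline obtains its core from utils::singleton_core()), so this is an illustration of the technique, not the literal file contents.

    // decoder_cache_sketch.cpp — minimal sketch of the per-length decoder cache.
    #include <cstdint>
    #include <memory>
    #include <unordered_map>

    #include <openvino/openvino.hpp>

    class DecoderCache {
    public:
        explicit DecoderCache(std::shared_ptr<ov::Model> model)
            : m_decoder_model(std::move(model)) {}

        // Compile the decoder for a given number of prompt tokens on first
        // request; afterwards, return the cached infer request for that size.
        ov::InferRequest get_model(uint8_t input_ids_size) {
            if (m_cache.find(input_ids_size) == m_cache.cend()) {
                // Reshape only input_ids; inputs omitted from the map keep their shapes.
                m_decoder_model->reshape({{"input_ids", ov::PartialShape{1, input_ids_size}}});
                ov::CompiledModel compiled = m_core.compile_model(m_decoder_model, "NPU");
                m_cache.emplace(input_ids_size, compiled.create_infer_request());
            }
            return m_cache.at(input_ids_size);
        }

    private:
        ov::Core m_core;  // the real code uses a shared singleton core instead
        std::unordered_map<uint8_t, ov::InferRequest> m_cache;  // keyed by prompt length
        std::shared_ptr<ov::Model> m_decoder_model;
    };

    int main() {
        ov::Core core;
        // Hypothetical path; any decoder IR with an "input_ids" input would do.
        auto decoder = core.read_model("whisper_decoder.xml");
        DecoderCache cache(decoder);

        auto lang_detect = cache.get_model(1);  // 1 token: language detection
        auto transcribe  = cache.get_model(4);  // e.g. SOT + language + task + notimestamps
        (void)lang_detect;
        (void)transcribe;
    }

The choice mirrored from patch 5 — caching ov::InferRequest rather than ov::CompiledModel — means callers never recompile or re-create a request for a length they have already seen, which matters on NPU where each static-shape compilation is expensive.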