diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp index a0c7c3a68b..3e09099c28 100644 --- a/src/cpp/src/whisper_pipeline_static.cpp +++ b/src/cpp/src/whisper_pipeline_static.cpp @@ -121,25 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con } } -void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder, - const std::vector& init_ids, - const int64_t pad_token) { +void set_decoder_input_ids(ov::InferRequest& decoder, + const std::vector& init_ids) { auto input_ids_tensor = decoder.get_tensor("input_ids"); - auto attention_mask_tensor = decoder.get_tensor("attention_mask"); - const size_t seq_length = input_ids_tensor.get_shape()[1]; OPENVINO_ASSERT(seq_length >= init_ids.size()); - // pad right - // input_ids [token, token, token, pad_token] - // attention_mask [1, 1, 1, 0] auto input_ids_data = input_ids_tensor.data(); std::copy(init_ids.begin(), init_ids.end(), input_ids_data); - - auto attention_mask_data = attention_mask_tensor.data(); - std::fill_n(attention_mask_data, init_ids.size(), 1u); - std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u); } int64_t decode(ov::Tensor& encoder_hidden_state, @@ -151,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state, const bool return_timestamps = false) { // NB: Fill decoder inputs encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states")); - set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id); + set_decoder_input_ids(decoder, init_ids); ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics); @@ -230,7 +220,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state, decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state}); std::vector init_ids{static_cast(config.decoder_start_token_id)}; - set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id); + set_decoder_input_ids(decoder, init_ids); const auto infer_start = std::chrono::steady_clock::now(); decoder.infer(); @@ -350,36 +340,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr& decoder return false; } -void add_attention_mask_input_for_decoder(std::shared_ptr model) { - using namespace ov::pass::pattern; - using namespace ov::op; - class AttentionMaskInput : public ov::pass::MatcherPass { - public: - OPENVINO_RTTI("AttentionMaskInput"); - - AttentionMaskInput(std::shared_ptr model) { - auto range = wrap_type(); - auto convert = wrap_type({range}); - auto convert1 = wrap_type({convert}); - auto greater = wrap_type({convert1, any_input()}); - auto convert2 = wrap_type({greater}); - - register_matcher(std::make_shared(convert2, this->get_type_info().name), [model](Matcher& m) { - auto node = m.get_match_root(); - auto attention_mask = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1}); - attention_mask->get_output_tensor(0).set_names({"attention_mask"}); - model->add_parameters({attention_mask}); - ov::replace_node(node, attention_mask); - return false; - }); - } - }; - - ov::pass::Manager pm; - pm.register_pass(model); - pm.run_passes(model); -} - void add_attention_mask_input(std::shared_ptr model) { using namespace ov::pass::pattern; using namespace ov::op; @@ -464,6 +424,22 @@ void reshape_to_static_encoder(std::shared_ptr model, const size_t fe model->reshape(new_shapes); } +void reshape_input_ids(std::shared_ptr model, const uint32_t input_size) { + std::map new_shapes; + for (auto input : model->inputs()) { + const auto& input_name = input.get_any_name(); + ov::PartialShape new_shape; + if (input_name.find("input_ids") != std::string::npos) { + new_shape = ov::PartialShape({1, input_size}); + } else { + new_shape = input.get_partial_shape(); + } + new_shapes.emplace(input_name, new_shape); + } + + model->reshape(new_shapes); +} + void preprocess_encoder(std::shared_ptr model) { ov::preprocess::PrePostProcessor preprocessor(model); @@ -538,6 +514,17 @@ std::shared_ptr redirect_new_kv_to_output(const std::shared_ptr model) : m_decoder_model(model) {} + + ov::CompiledModel get_model(uint8_t input_ids_size); +private: + std::unordered_map m_cache; + std::shared_ptr m_decoder_model; +}; + class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase { public: StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties); @@ -25,7 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi private: WhisperInitializedModels m_models; - std::shared_ptr m_decoder_model; + DecoderCache m_decoder_cache; }; } // namespace genai