diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 01fe882187..91de478b1c 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -121,28 +121,15 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
     }
 }
 
-void set_decoder_input_ids_attention_mask(ov::InferRequest& decoder,
-                                          const std::vector<int32_t>& init_ids,
-                                          const int64_t pad_token) {
+void set_decoder_input_ids(ov::InferRequest& decoder,
+                           const std::vector<int32_t>& init_ids) {
     auto input_ids_tensor = decoder.get_tensor("input_ids");
-    auto attention_mask_tensor = decoder.get_tensor("attention_mask");
-
     const size_t seq_length = input_ids_tensor.get_shape()[1];
 
     OPENVINO_ASSERT(seq_length >= init_ids.size());
 
-    // pad right
-    // input_ids [token, token, token, pad_token]
-    // attention_mask [1, 1, 1, 0]
     auto input_ids_data = input_ids_tensor.data<int32_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
-    std::fill(input_ids_data + init_ids.size(),
-              input_ids_data + input_ids_tensor.get_size(),
-              static_cast<int32_t>(pad_token));
-
-    auto attention_mask_data = attention_mask_tensor.data<ov::float16>();
-    std::fill_n(attention_mask_data, init_ids.size(), 1u);
-    std::fill(attention_mask_data + init_ids.size(), attention_mask_data + attention_mask_tensor.get_size(), 0u);
 }
 
 int64_t decode(ov::Tensor& encoder_hidden_state,
@@ -154,7 +141,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
                const bool return_timestamps = false) {
     // NB: Fill decoder inputs
     encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
-    set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
+    set_decoder_input_ids(decoder, init_ids);
 
     ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
 
@@ -210,13 +197,13 @@ void zero_past_key_values(ov::InferRequest& request) {
     }
 }
 
-void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder) {
+void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferRequest& decoder, const size_t init_ids_size) {
     // NB: Prepare attetion mask to be in a format [0, 0, 0, 1, 1, 1, 1, ..., 0, 1]
     // Mask should be inverted for decoder_with_past
     auto attention_mask = decoder_with_past.get_tensor("attention_mask");
     auto* attention_mask_ptr = attention_mask.data<ov::float16>();
-    std::fill(attention_mask_ptr, attention_mask_ptr + 3u, 0);
-    std::fill(attention_mask_ptr + 3u, attention_mask_ptr + attention_mask.get_size() - 2, 1);
+    std::fill(attention_mask_ptr, attention_mask_ptr + init_ids_size, 0);
+    std::fill(attention_mask_ptr + init_ids_size, attention_mask_ptr + attention_mask.get_size() - 2, 1);
     attention_mask_ptr[attention_mask.get_size() - 2] = 0;
     attention_mask_ptr[attention_mask.get_size() - 1] = 1;
     // NB: Zero past_key_values.*.decoder.value tensors
@@ -227,13 +214,15 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
 };
 
 int64_t detect_language(ov::Tensor& encoder_hidden_state,
-                        ov::InferRequest decoder,
+                        ov::genai::DecoderCache& decoder_cache,
                         const ov::genai::WhisperGenerationConfig& config,
                         ov::genai::RawPerfMetrics& raw_metrics) {
+    auto decoder = decoder_cache.get_model(1);
+
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
-    set_decoder_input_ids_attention_mask(decoder, init_ids, config.pad_token_id);
+    set_decoder_input_ids(decoder, init_ids);
 
     const auto infer_start = std::chrono::steady_clock::now();
     decoder.infer();
@@ -259,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
 }
 
 std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
-                                      ov::InferRequest& decoder,
+                                      ov::genai::DecoderCache& decoder_cache,
                                       const ov::genai::WhisperGenerationConfig& config,
                                       const bool return_timestamps,
                                       ov::genai::RawPerfMetrics& raw_metrics) {
@@ -279,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
             language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
         }
     } else {
-        language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
+        language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
     }
 
     int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
@@ -318,7 +307,7 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
         return {false, output_tokens};
     }
 
-    prepare_decoder_with_past(models.decoder_with_past, models.decoder);
+    prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());
 
     for (size_t i = 0; i < max_new_tokens - 1; i++) {
         auto output_token = decode_with_past(models.decoder_with_past,
@@ -353,36 +342,6 @@ bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder
     return false;
 }
 
-void add_attention_mask_input_for_decoder(std::shared_ptr<ov::Model> model) {
-    using namespace ov::pass::pattern;
-    using namespace ov::op;
-    class AttentionMaskInput : public ov::pass::MatcherPass {
-    public:
-        OPENVINO_RTTI("AttentionMaskInput");
-
-        AttentionMaskInput(std::shared_ptr<ov::Model> model) {
-            auto range = wrap_type<v4::Range>();
-            auto convert = wrap_type<v0::Convert>({range});
-            auto convert1 = wrap_type<v0::Convert>({convert});
-            auto greater = wrap_type<v1::Greater>({convert1, any_input()});
-            auto convert2 = wrap_type<v0::Convert>({greater});
-
-            register_matcher(std::make_shared<Matcher>(convert2, this->get_type_info().name), [model](Matcher& m) {
-                auto node = m.get_match_root();
-                auto attention_mask = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{-1, -1});
-                attention_mask->get_output_tensor(0).set_names({"attention_mask"});
-                model->add_parameters({attention_mask});
-                ov::replace_node(node, attention_mask);
-                return false;
-            });
-        }
-    };
-
-    ov::pass::Manager pm;
-    pm.register_pass<AttentionMaskInput>(model);
-    pm.run_passes(model);
-}
-
 void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     using namespace ov::pass::pattern;
     using namespace ov::op;
@@ -467,6 +426,10 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
     model->reshape(new_shapes);
 }
 
+void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
+    model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
+}
+
 void preprocess_encoder(std::shared_ptr<ov::Model> model) {
     ov::preprocess::PrePostProcessor preprocessor(model);
 
@@ -489,7 +452,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
         preprocessor.input("attention_mask").preprocess().convert_element_type();
     } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
         preprocessor.input("encoder_hidden_states").tensor().set_element_type(ov::element::Type_t::f16);
-        preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);  // ()
+        preprocessor.input("encoder_hidden_states").preprocess().convert_element_type(ov::element::Type_t::f32);
     } else if (tensor.get_any_name().find("past_key_values") != std::string::npos) {
         preprocessor.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16);
         preprocessor.input(tensor.get_any_name()).preprocess().convert_element_type();
@@ -541,6 +504,19 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::Model> model)
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -15,6 +15,17 @@
 namespace ov {
 namespace genai {
 
+class DecoderCache {
+public:
+    DecoderCache(std::shared_ptr<ov::Model> model)
+        : m_decoder_model(model) {}
+
+    ov::InferRequest get_model(uint8_t input_ids_size);
+private:
+    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
+    std::shared_ptr<ov::Model> m_decoder_model;
+};
+
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
 public:
     StaticWhisperPipeline(const std::filesystem::path& model_path, const ov::AnyMap& properties);
@@ -25,6 +36,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
 
 private:
     WhisperInitializedModels m_models;
+    DecoderCache m_decoder_cache;
 };
 
 }  // namespace genai
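
The body of the `@@ -541,6 +504,19 @@` hunk in `whisper_pipeline_static.cpp` (the `DecoderCache::get_model` definition) is not visible above. Below is a minimal sketch of what that body plausibly contains, inferred from the header declaration and the new `reshape_input_ids` helper; the `"NPU"` device string and the `utils::singleton_core()` helper are assumptions carried over from the rest of the pipeline, not confirmed by this diff:

```cpp
// Hypothetical reconstruction, not the literal hunk body: compile the static
// decoder once per init_ids length and cache the compiled model, so that
// language detection (1 input id) and transcription (3-4 input ids) each get
// a decoder reshaped to a matching static input_ids shape.
ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
    if (m_cache.find(input_ids_size) == m_cache.end()) {
        // Reshape input_ids to [1, input_ids_size] before compiling
        reshape_input_ids(m_decoder_model, input_ids_size);

        ov::Core core = utils::singleton_core();  // assumed helper
        m_cache.emplace(input_ids_size, core.compile_model(m_decoder_model, "NPU"));
    }
    return m_cache.at(input_ids_size).create_infer_request();
}
```

This also explains why `detect_language` now pulls its request from `decoder_cache.get_model(1)` and why `prepare_init_ids` takes the cache instead of a fixed `ov::InferRequest`: the number of initial tokens is only known after language detection, and each distinct length needs its own statically shaped compilation.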
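
Relatedly, the hardcoded `3u` that `prepare_decoder_with_past` used to fill presumably matched the usual three forced init tokens (start, language, task); with the init sequence now variable in length, the new `init_ids_size` parameter drives the mask layout. A standalone illustration of the fill pattern from the diff (sizes here are made up for display):

```cpp
// Reproduces the inverted attention-mask layout built by
// prepare_decoder_with_past: zeros over the init ids, ones over the KV-cache
// region, then a fixed 0, 1 tail for the last two positions.
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    const size_t init_ids_size = 4;  // e.g. sot + language + task + notimestamps
    const size_t mask_size = 16;     // stands in for attention_mask.get_size()

    std::vector<int> mask(mask_size);
    std::fill(mask.begin(), mask.begin() + init_ids_size, 0);
    std::fill(mask.begin() + init_ids_size, mask.end() - 2, 1);
    mask[mask_size - 2] = 0;
    mask[mask_size - 1] = 1;

    for (int v : mask) std::cout << v << ' ';  // 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 1
    std::cout << '\n';
}
```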