Fix hidden state size for StaticWhisperPipeline - propagate releases/2024/5 (#1186)

Propagate #1179 to the 2024.5 release branch
ilya-lavrenov authored Nov 12, 2024
2 parents 62546c2 + 7d78ca7 commit 8341634
Showing 1 changed file with 13 additions and 5 deletions.
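
The change replaces the hard-coded encoder hidden-state dimension (1500) in reshape_to_static with values queried from the encoder model's last_hidden_state output. As a rough, illustrative sketch of the OpenVINO C++ API this relies on (the helper name and the "encoder_hidden_states" input name below are assumptions for the example, not the pipeline's actual code):

#include <map>
#include <memory>
#include <string>
#include <openvino/openvino.hpp>

// Sketch only: pin a decoder input to a static shape derived from the
// encoder output shape, the way reshape_to_static now does with lhstate_shape.
void pin_encoder_hidden_states(const std::shared_ptr<ov::Model>& decoder,
                               const ov::PartialShape& lhstate_shape) {
    std::map<std::string, ov::PartialShape> new_shapes;
    // "encoder_hidden_states" is an assumed input name for illustration.
    new_shapes.emplace("encoder_hidden_states",
                       ov::PartialShape{1, lhstate_shape[1], lhstate_shape[2]});
    decoder->reshape(new_shapes);
}
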
18 changes: 13 additions & 5 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -397,7 +397,12 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     pm.run_passes(model);
 }
 
-void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
+
+ov::PartialShape get_encoder_hidden_state_shape(const std::shared_ptr<ov::Model>& encoder) {
+    return encoder->output("last_hidden_state").get_partial_shape();
+}
+
+void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const ov::PartialShape& lhstate_shape) {
     std::map<std::string, ov::PartialShape> new_shapes;
     for (auto input : model->inputs()) {
         const auto& input_name = input.get_any_name();
@@ -414,14 +419,15 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1;     // batch_dim
-            new_shape[1] = 1500;  // FIXME: is it got from encoder output{'last_hidden_state'}
+            new_shape[1] = lhstate_shape[1];  // from encoder output{'last_hidden_state'}
+            new_shape[2] = lhstate_shape[2];
         } else if (input_name.find("past_key_values") != std::string::npos) {
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1;  // Use batch dim here
             new_shape[2] = input_name.find(".decoder") != std::string::npos
                                ? kvcache_size - input_size  // kv_size for decoder
-                               : 1500;              // for encoder
+                               : lhstate_shape[1];  // hidden state size for encoder
         }
         new_shapes.emplace(input_name, new_shape);
     }
@@ -538,8 +544,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     size_t max_sequence_length = 448;
 
     reshape_to_static_encoder(encoder_model);
-    reshape_to_static(decoder_model, 4, 4);
-    reshape_to_static(decoder_with_past_model, 1, max_sequence_length);
+
+    auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
+    reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
+    reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
     decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);
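
For reference, a minimal standalone sketch (not part of the commit) showing how the encoder's last_hidden_state shape can be queried with the same OpenVINO calls used by get_encoder_hidden_state_shape; the model path here is hypothetical:

#include <iostream>
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Hypothetical path to an exported Whisper encoder IR.
    auto encoder = core.read_model("whisper_encoder_model.xml");

    // Shape is [batch, frames, hidden], e.g. roughly {1, 1500, 512} for whisper-base.
    const ov::PartialShape lhstate_shape =
        encoder->output("last_hidden_state").get_partial_shape();
    std::cout << "last_hidden_state shape: " << lhstate_shape << std::endl;
    return 0;
}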