Fix hidden state size for StaticWhisperPipeline
eshiryae authored and dmatveev committed Nov 11, 2024
1 parent a566870 commit 875f106
Showing 1 changed file with 17 additions and 5 deletions.
src/cpp/src/whisper_pipeline_static.cpp: 17 additions & 5 deletions (22 changes)
@@ -397,7 +397,17 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
     pm.run_passes(model);
 }
 
-void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size) {
+
+uint32_t get_encoder_hidden_state_size(const std::shared_ptr<ov::Model>& encoder) {
+    for (auto output : encoder->outputs()) {
+        if (output.get_any_name() == "last_hidden_state") {
+            return output.get_partial_shape()[1].get_length();
+        }
+    }
+    return 0;
+}
+
+void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const uint32_t lhstate_size) {
     std::map<std::string, ov::PartialShape> new_shapes;
     for (auto input : model->inputs()) {
         const auto& input_name = input.get_any_name();
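
The new helper scans the encoder's outputs for the tensor named last_hidden_state and returns its dimension 1, which for Whisper is the encoder frame count (1500 for a 30-second chunk, the value that was previously hard-coded). A minimal standalone sketch of the same shape query with the OpenVINO C++ API; the model path is a placeholder assumption, not part of this commit:

    #include <openvino/openvino.hpp>
    #include <iostream>

    int main() {
        ov::Core core;
        // Placeholder path: any exported Whisper encoder IR would do here.
        auto encoder = core.read_model("whisper_encoder.xml");

        uint32_t frames = 0;
        for (const auto& output : encoder->outputs()) {
            if (output.get_any_name() == "last_hidden_state") {
                // last_hidden_state is [batch, frames, hidden]; dim 1 is the frame count.
                // get_length() assumes this dimension is static in the exported model.
                frames = static_cast<uint32_t>(output.get_partial_shape()[1].get_length());
            }
        }
        std::cout << "encoder frames: " << frames << std::endl;  // e.g. 1500
        return 0;
    }
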
@@ -414,14 +424,14 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1;     // batch_dim
-            new_shape[1] = 1500;  // FIXME: is it got from encoder output{'last_hidden_state'}
+            new_shape[1] = lhstate_size;  // from encoder output{'last_hidden_state'}
         } else if (input_name.find("past_key_values") != std::string::npos) {
             const auto& partial_shape = input.get_partial_shape();
             new_shape = partial_shape;
             new_shape[0] = 1;  // Use batch dim here
             new_shape[2] = input_name.find(".decoder") != std::string::npos
                                ? kvcache_size - input_size  // kv_size for decoder
-                               : 1500;                      // for encoder
+                               : lhstate_size;              // hidden state size for encoder
         }
         new_shapes.emplace(input_name, new_shape);
     }
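
For reference, the new_shapes map built in this loop is the argument OpenVINO's ov::Model::reshape takes to make inputs static. A small sketch of that pattern under assumed input names and sizes (illustrative only, not taken from this commit):

    #include <openvino/openvino.hpp>
    #include <map>
    #include <string>

    // Sketch: pin two named inputs to static shapes and apply them in one call.
    // The input names and the hidden size 384 are illustrative assumptions.
    void make_inputs_static(const std::shared_ptr<ov::Model>& model, uint32_t lhstate_size) {
        std::map<std::string, ov::PartialShape> new_shapes;
        new_shapes.emplace("input_ids", ov::PartialShape{1, 4});  // [batch, tokens]
        new_shapes.emplace("encoder_hidden_states",
                           ov::PartialShape{1, int64_t(lhstate_size), 384});  // [batch, frames, hidden]
        model->reshape(new_shapes);  // every listed input becomes fully static
    }
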
@@ -538,8 +548,10 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     size_t max_sequence_length = 448;
 
     reshape_to_static_encoder(encoder_model);
-    reshape_to_static(decoder_model, 4, 4);
-    reshape_to_static(decoder_with_past_model, 1, max_sequence_length);
+
+    size_t last_hidden_state_size = get_encoder_hidden_state_size(encoder_model);
+    reshape_to_static(decoder_model, 4, 4, last_hidden_state_size);
+    reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_size);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
     decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);
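End to end, the constructor now derives the static hidden-state length from the encoder instead of assuming 1500. A hedged sketch of that flow, assuming the two helpers above are in scope and using placeholder model file names (the real pipeline builds its paths from the models directory it is given):

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        // Placeholder file names, not part of this commit.
        auto encoder_model           = core.read_model("openvino_encoder_model.xml");
        auto decoder_model           = core.read_model("openvino_decoder_model.xml");
        auto decoder_with_past_model = core.read_model("openvino_decoder_with_past_model.xml");

        const size_t max_sequence_length = 448;

        // Query once, then reuse for both decoder reshapes, mirroring the change above.
        const uint32_t last_hidden_state_size = get_encoder_hidden_state_size(encoder_model);
        reshape_to_static(decoder_model, 4, 4, last_hidden_state_size);
        reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_size);
        return 0;
    }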
