Skip to content

Commit

Permalink
Address review comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
eshiryae authored and dmatveev committed Nov 11, 2024
1 parent 875f106 commit 7d78ca7
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions src/cpp/src/whisper_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,16 +398,11 @@ void add_attention_mask_input(std::shared_ptr<ov::Model> model) {
}


uint32_t get_encoder_hidden_state_size(const std::shared_ptr<ov::Model>& encoder) {
for (auto output : encoder->outputs()) {
if (output.get_any_name() == "last_hidden_state") {
return output.get_partial_shape()[1].get_length();
}
}
return 0;
// Returns the partial shape of the encoder model's output named "last_hidden_state".
// The full shape is returned rather than a single dimension, so a caller can read
// more than one of its dimensions.
// NOTE(review): no explicit fallback is provided when the model has no such output;
// what happens then depends on ov::Model::output(name) -- confirm against the OpenVINO API.
ov::PartialShape get_encoder_hidden_state_shape(const std::shared_ptr<ov::Model>& encoder) {
return encoder->output("last_hidden_state").get_partial_shape();
}

void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const uint32_t lhstate_size) {
void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_size, const uint32_t kvcache_size, const ov::PartialShape& lhstate_shape) {
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
Expand All @@ -424,14 +419,15 @@ void reshape_to_static(std::shared_ptr<ov::Model> model, const uint32_t input_si
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // batch_dim
new_shape[1] = lhstate_size; // from encoder output{'last_hidden_state'}
new_shape[1] = lhstate_shape[1]; // from encoder output{'last_hidden_state'}
new_shape[2] = lhstate_shape[2];
} else if (input_name.find("past_key_values") != std::string::npos) {
const auto& partial_shape = input.get_partial_shape();
new_shape = partial_shape;
new_shape[0] = 1; // Use batch dim here
new_shape[2] = input_name.find(".decoder") != std::string::npos
? kvcache_size - input_size // kv_size for decoder
: lhstate_size; // hidden state size for encoder
: lhstate_shape[1]; // hidden state size for encoder
}
new_shapes.emplace(input_name, new_shape);
}
Expand Down Expand Up @@ -549,9 +545,9 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys

reshape_to_static_encoder(encoder_model);

size_t last_hidden_state_size = get_encoder_hidden_state_size(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_size);
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_size);
auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
reshape_to_static(decoder_model, 4, 4, last_hidden_state_shape);
reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);

// Replace KV-tensors for the entire cache to tensors only for new token
decoder_with_past_model = redirect_new_kv_to_output(decoder_with_past_model);
Expand Down

0 comments on commit 7d78ca7

Please sign in to comment.