add perf printing for prompt lookup decoding
xufang-lisa committed Jan 3, 2025
1 parent d7ac127 commit ffdad03
Showing 3 changed files with 31 additions and 2 deletions.
@@ -34,8 +34,12 @@ int main(int argc, char* argv[]) try {
 
     // Since the streamer is set, the results will
     // be printed each time a new token is generated.
-    pipe.generate(prompt, config, streamer);
-    std::cout << std::endl;
+    int iter = 0;
+    while (iter < 10) {
+        pipe.generate(prompt, config, streamer);
+        iter++;
+        std::cout << "\npipeline finish iter:" << iter << std::endl;
+    }
 } catch (const std::exception& error) {
     try {
         std::cerr << error.what() << '\n';
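The sample change above simply re-runs generation ten times so the new per-run metrics printing is exercised repeatedly. If per-iteration wall-clock numbers are also wanted on the sample side, a minimal sketch is shown below; it is not part of the commit, and `pipe`, `prompt`, `config`, and `streamer` are assumed to be the objects already set up earlier in the sample's main().

// Illustrative only, not part of the commit: times each generate() call with
// std::chrono (add <chrono> and <iostream> to the sample's includes).
// `pipe`, `prompt`, `config`, and `streamer` are the objects already set up
// earlier in the sample's main().
for (int iter = 0; iter < 10; ++iter) {
    const auto start = std::chrono::steady_clock::now();
    pipe.generate(prompt, config, streamer);
    const auto stop = std::chrono::steady_clock::now();
    const auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
    std::cout << "\npipeline finish iter:" << (iter + 1) << " (" << ms << " ms)" << std::endl;
}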
21 changes: 21 additions & 0 deletions src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -64,6 +64,7 @@ void ContinuousBatchingPipeline::PromptLookupImpl::step() {
     }
 
     if (generated_len_after.empty() && 0) {
+        m_pipeline->get_infer_duration(m_sd_metrics.main_infer_duration, m_sd_metrics.main_infer_num);
         m_sd_metrics.print(true);
         m_sd_metrics.clean_up();
     }
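The call added here reads accumulated inference time out of the pipeline through get_infer_duration, which, judging from the call sites in this commit, returns the total infer duration and the number of infer calls via output parameters. The class below is a hypothetical sketch of such an accumulator, inferred only from those call sites; the real member is not shown in this diff.

// Hypothetical sketch, inferred from the call sites in this diff: the pipeline
// accumulates model-inference wall time plus the number of infer calls and
// hands both back through output parameters.
class PipelineWithInferStats {
public:
    // Call after every model infer with that infer's duration.
    void on_infer_finished(float duration_ms) {
        m_infer_duration_ms += duration_ms;
        ++m_infer_num;
    }

    // Matches the call shape m_pipeline->get_infer_duration(duration, num).
    void get_infer_duration(float& duration_ms, int& num) const {
        duration_ms = m_infer_duration_ms;
        num = m_infer_num;
    }

    // Matches m_pipeline->reset_infer_duration() used at the start of generate().
    void reset_infer_duration() {
        m_infer_duration_ms = 0.0f;
        m_infer_num = 0;
    }

private:
    float m_infer_duration_ms = 0.0f;
    int m_infer_num = 0;
};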
@@ -103,14 +104,25 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
     results.reserve(input_ids.size());
 
     bool continue_generation = true;
+    bool get_first_token = false;
+    float first_token_time = 0;
+    int first_tokens_num = 0;
+    m_pipeline->reset_infer_duration();
     while (has_non_finished_requests() && continue_generation) {
+        ManualTimer step_timer("speculative_decoding: step()");
+        step_timer.start();
         step();
+        step_timer.end();
+        first_token_time += step_timer.get_duration();
         if (streamer_ptr) {
             // not generated tokens like several prompt phase
             if (!main_generations.at(0).get()->can_read()) {
                 continue;
             }
             std::unordered_map<uint64_t, GenerationOutput> token = main_generations.at(0).get()->back();
+            if (!get_first_token && !token.begin()->second.generated_ids.empty()) {
+                first_tokens_num = token.begin()->second.generated_ids.size();
+            }
             OPENVINO_ASSERT(1 <= token.size());
             OPENVINO_ASSERT(1 <= token.begin()->second.generated_ids.size());
             for (const auto& gen_token : token.begin()->second.generated_ids) {
@@ -120,6 +132,12 @@
                 }
             }
         }
+        if (!get_first_token && first_tokens_num > 0) {
+            get_first_token = true;
+            m_sd_metrics.first_token_duration = first_token_time;
+            int number = 0;
+            m_pipeline->get_infer_duration(m_sd_metrics.main_infer_for_first_token, number);
+        }
     }
     if (streamer_ptr) {
         streamer_ptr->end();
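Taken together, the two hunks above implement a simple time-to-first-token measurement: every call to step() is timed with ManualTimer, the durations are summed into first_token_time, and that sum is copied into m_sd_metrics.first_token_duration (together with the main model's infer time so far) the first time a generated token is observed. A standalone sketch of the same pattern, using hypothetical names rather than the pipeline's, looks like this:

// Hypothetical, self-contained sketch of the first-token timing pattern above.
struct FirstTokenProbe {
    bool seen_first_token = false;
    float elapsed_ms = 0.0f;      // step time accumulated so far
    float first_token_ms = 0.0f;  // snapshot taken when the first token appears

    // Call once per decoding step with the step duration and whether the step
    // made at least one generated token visible to the streamer.
    void on_step(float step_ms, bool produced_token) {
        elapsed_ms += step_ms;
        if (!seen_first_token && produced_token) {
            seen_first_token = true;
            first_token_ms = elapsed_ms;
        }
    }
};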
@@ -148,6 +166,9 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
     OPENVINO_ASSERT(results.size() == input_ids.size());
     generate_timer.end();
     m_sd_metrics.total_duration = generate_timer.get_duration();
+    m_pipeline->get_infer_duration(m_sd_metrics.main_infer_duration, m_sd_metrics.main_infer_num);
+    m_sd_metrics.print(true);
+    m_sd_metrics.clean_up();
 
     return results;
 }
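At the end of generate() the prompt-lookup pipeline now copies the main model's total infer time and call count into the metrics object and prints them via m_sd_metrics.print(true). The commit does not show what print produces, so the snippet below is only an assumed illustration of the kind of summary those fields allow (total, per-call average, and time to first token); the field names are borrowed from the diff, the layout is invented.

#include <iostream>

// Hypothetical summary print; the real SpeculativeDecodingMetrics::print is not
// part of this diff, so structure and wording here are assumptions.
struct InferSummary {
    float total_duration_ms = 0.0f;  // m_sd_metrics.total_duration
    float main_infer_ms = 0.0f;      // m_sd_metrics.main_infer_duration
    int main_infer_num = 0;          // m_sd_metrics.main_infer_num
    float first_token_ms = 0.0f;     // m_sd_metrics.first_token_duration
};

void print_summary(const InferSummary& s) {
    std::cout << "total duration, ms: " << s.total_duration_ms << "\n"
              << "main model infer duration, ms: " << s.main_infer_ms
              << " over " << s.main_infer_num << " infer calls\n";
    if (s.main_infer_num > 0) {
        std::cout << "average main infer, ms: " << s.main_infer_ms / s.main_infer_num << "\n";
    }
    std::cout << "time to first token, ms: " << s.first_token_ms << std::endl;
}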
@@ -296,6 +296,10 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
 
     OPENVINO_ASSERT(results.size() == input_ids.size());
     generate_timer.end();
+    m_draft_pipeline->get_infer_duration(m_sd_metrics.draft_infer_duration, m_sd_metrics.draft_infer_num);
+    m_main_pipeline->get_infer_duration(m_sd_metrics.main_infer_duration, m_sd_metrics.main_infer_num);
+    m_sd_metrics.print(true);
+    m_sd_metrics.clean_up();
     return results;
 }
 
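The speculative-decoding pipeline gets the same treatment, but reads out both the draft and the main model before printing. With those four values one can see how inference time splits between the two models; the helper below is an illustration of that arithmetic and is not part of the commit.

#include <iostream>

// Illustrative only: given the four values read out above, report how the
// total inference time splits between draft and main model and the average
// cost per infer call.
void report_split(float draft_ms, int draft_num, float main_ms, int main_num) {
    const float total_ms = draft_ms + main_ms;
    if (total_ms <= 0.0f || draft_num <= 0 || main_num <= 0) {
        std::cout << "no inference recorded" << std::endl;
        return;
    }
    std::cout << "draft: " << draft_ms << " ms (" << 100.0f * draft_ms / total_ms << "%), "
              << draft_ms / draft_num << " ms/call\n"
              << "main:  " << main_ms << " ms (" << 100.0f * main_ms / total_ms << "%), "
              << main_ms / main_num << " ms/call" << std::endl;
}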
