diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index e163dce2df..de1038a716 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -1,8 +1,10 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
 #include "llm_pipeline_static.hpp"
 
+#include "sampler.hpp"
+
 #include <fstream>
 #include <regex>
@@ -235,12 +237,12 @@ enum class GenerateHint {
 
 std::string to_string(GenerateHint h) {
     switch(h) {
-    case GenerateHint::FAST_COMPILE :
+    case GenerateHint::FAST_COMPILE :
         return "FAST_COMPILE";
-    case GenerateHint::BEST_PERF :
+    case GenerateHint::BEST_PERF :
         return "BEST_PERF";
     default:
-        OPENVINO_THROW("Unsupported value for type GenerateHint provided");
+        OPENVINO_THROW("Unsupported value for type GenerateHint provided");
     }
 }
@@ -632,6 +634,19 @@ void copy_columns_by_row_chunks(const ov::Tensor& src, ov::Tensor& dst) {
     }
 }
 
+void stream_generated_tokens(std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
+                             ov::genai::GenerationHandle& handle) {
+    if (streamer_ptr && handle->can_read()) {
+        std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->back();
+        for (const auto& gen_token : token.begin()->second.generated_ids) {
+            if (streamer_ptr->put(gen_token)) {
+                handle->drop();
+                break;
+            }
+        }
+    }
+}
+
 } // anonymous namespace
 
 namespace ov {
@@ -643,7 +658,8 @@ StaticLLMPipeline::StaticLLMPipeline(
     const std::string& device,
     const ov::AnyMap& config
 ) : LLMPipelineImplBase(tokenizer,
-                        utils::from_config_json_if_exists(models_path)) {
+                        utils::from_config_json_if_exists(models_path)),
+    m_sampler(m_tokenizer) {
     auto properties = config;
     /* NB: Static LLM pipeline consists of two models,
        first to process the input prompt (prefill),
@@ -672,6 +688,8 @@ StaticLLMPipeline::StaticLLMPipeline(
     if (m_generation_config.eos_token_id == -1) {
         m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
     }
+
+    m_sampler.set_seed(m_generation_config.rng_seed);
 };
 
 StaticLLMPipeline::StaticLLMPipeline(
@@ -688,8 +706,7 @@ StaticLLMPipeline::StaticLLMPipeline(
     const std::string& device,
     const ov::AnyMap& properties,
     const ov::genai::GenerationConfig& generation_config
-) : LLMPipelineImplBase(tokenizer, generation_config) {
-
+) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
     bool use_blobs = false;
     auto anyopt = get_option(properties, "USE_BLOBS");
     if (anyopt.has_value()) {
@@ -708,6 +725,8 @@ StaticLLMPipeline::StaticLLMPipeline(
     if (m_generation_config.eos_token_id == -1) {
         m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
     }
+
+    m_sampler.set_seed(m_generation_config.rng_seed);
 }
 
 void StaticLLMPipeline::setupAndCompileModels(
@@ -955,7 +974,10 @@ EncodedResults StaticLLMPipeline::generate(
         attention_mask = data->attention_mask;
     }
 
-    if (input_ids.get_shape().at(0) > 1u) {
+    ov::Shape prompts_shape = input_ids.get_shape();
+    const size_t batch_size = prompts_shape[0];
+
+    if (batch_size > 1u) {
         OPENVINO_THROW("Currently only batch size=1 is supported");
     }
@@ -974,12 +996,14 @@ EncodedResults StaticLLMPipeline::generate(
         streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
     }
 
-    if (!config.is_greedy_decoding()) {
-        OPENVINO_THROW("Currently only greedy decoding is supported");
+    if (!config.is_greedy_decoding() && !config.is_multinomial()) {
+        OPENVINO_THROW("Currently only greedy and multinomial decoding are supported");
+    }
+
+    if (config.num_return_sequences != 1u) {
+        OPENVINO_THROW("Currently only \"num_return_sequences\" equal to 1 is supported!");
     }
-
-    ov::Shape prompts_shape = input_ids.get_shape();
-    const size_t batch_size = prompts_shape[0];
 
     ov::genai::EncodedResults results;
     auto& raw_perf_counters = results.perf_metrics.raw_metrics;
     // NB: Only batch=1 is supported now
@@ -1016,11 +1040,21 @@ EncodedResults StaticLLMPipeline::generate(
     // NB: Now there are prompt_len tokens in KV-cache
     m_kvcache_desc.num_stored_tokens += static_cast<uint32_t>(prompt_len);
-    int64_t last_token = utils::argmax(m_prefill_request.get_tensor("logits"), 0);
-    results.tokens[0].push_back(last_token);
-    if (streamer_ptr && streamer_ptr->put(last_token)) {
-        return results;
-    }
+
+    auto logits = m_prefill_request.get_tensor("logits");
+    int64_t output_sequence_len = logits.get_shape().at(1);
+
+    auto sequence_group = std::make_shared<SequenceGroup>(
+        0 /* request_id */, padded_input_ids, config, 1 /* block_size */);
+    sequence_group->update_processed_tokens_num(m_kvcache_desc.max_prompt_size - output_sequence_len);
+    sequence_group->schedule_tokens(output_sequence_len);
+
+    // NB: Controls what tokens are ready to be pushed into the streamer
+    GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
+        sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());
+
+    SamplerOutput sampler_output = m_sampler.sample({sequence_group}, logits);
+    stream_generated_tokens(streamer_ptr, handle);
 
     // Outputs: logits, ...
     const auto kStartOutputKVCacheLayers = 1u;
@@ -1061,30 +1095,28 @@ EncodedResults StaticLLMPipeline::generate(
     std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u);
     attention_mask_data[m_kvcache_desc.total_size - 1] = 1u;
 
-    const size_t max_tokens = config.get_max_new_tokens(prompt_len);
-    for (int i = 0; i < max_tokens - 1; ++i) {
-        input_ids_data[0] = last_token;
+    while (sequence_group->is_running()) {
+        sequence_group->schedule_tokens(1);
+        const auto running_sequences = sequence_group->get_running_sequences();
+        OPENVINO_ASSERT(running_sequences.size() == 1u);
+
+        input_ids_data[0] = running_sequences.front()->get_generated_ids().back();
         position_ids_data[0] = m_kvcache_desc.num_stored_tokens;
         attention_mask_data[m_kvcache_desc.num_stored_tokens - 1] = 1u;
 
         m_kvcache_request.infer();
         m_kvcache_desc.num_stored_tokens += 1;
 
-        last_token = utils::argmax(m_kvcache_request.get_tensor("logits"), 0);
-        results.tokens[0].push_back(last_token);
-
         raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
         raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
 
-        if (streamer_ptr && streamer_ptr->put(last_token)) {
-            break;
-        }
-        if (last_token == config.eos_token_id && !config.ignore_eos) {
-            break;
-        }
+        SamplerOutput sampler_output = m_sampler.sample(
+            {sequence_group}, m_kvcache_request.get_tensor("logits"));
+        stream_generated_tokens(streamer_ptr, handle);
 
         // NB: KV-cache is full, further generation is impossible
         if (m_kvcache_desc.num_stored_tokens == m_kvcache_desc.total_size) {
+            sequence_group->set_out_of_memory();
             break;
         }
@@ -1108,6 +1140,12 @@ EncodedResults StaticLLMPipeline::generate(
         streamer_ptr->end();
     }
 
+    OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
+    auto sequence = sequence_group->get_finished_sequences().front();
+    results.tokens[0] = sequence->get_generated_ids();
+    results.scores[0] = sequence->get_cumulative_log_prob();
+    m_sampler.clear_request_info(sequence_group->get_request_id());
+
     auto stop_time = std::chrono::steady_clock::now();
     // If is called without tokenization then that stat will not be reported.
     auto& metrics = results.perf_metrics;
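With the hunks above, token selection in the static (NPU) pipeline goes through the shared Sampler instead of a hard-coded argmax, so a caller can request either greedy or multinomial decoding; num_return_sequences must stay at 1. A minimal usage sketch (not part of the patch; the model directory path and the sampling values are placeholders):

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("./model_dir", "NPU")  # hypothetical model directory

config = ov_genai.GenerationConfig()
config.do_sample = True            # multinomial sampling instead of greedy decoding
config.temperature = 0.8
config.top_k = 50
config.top_p = 0.9
config.max_new_tokens = 64
config.num_return_sequences = 1    # the static pipeline only supports a single returned sequence

print(pipe.generate("What is OpenVINO?", config))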
diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp
index 7acc28c684..8dc7ef49a1 100644
--- a/src/cpp/src/llm_pipeline_static.hpp
+++ b/src/cpp/src/llm_pipeline_static.hpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
 #pragma once
@@ -6,6 +6,7 @@
 #include <filesystem>
 
 #include "llm_pipeline_base.hpp"
+#include "sampler.hpp"
 
 namespace ov {
 namespace genai {
@@ -77,6 +78,8 @@ class StaticLLMPipeline final : public LLMPipelineImplBase {
         bool v_tensors_transposed;
     };
 
+    Sampler m_sampler;
+
     KVCacheDesc m_kvcache_desc;
     ov::InferRequest m_kvcache_request;
     ov::InferRequest m_prefill_request;
diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp
index b2e8add403..54850f657b 100644
--- a/src/cpp/src/sampler.cpp
+++ b/src/cpp/src/sampler.cpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2023-2024 Intel Corporation
+// Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
 #include "sampler.hpp"
@@ -743,7 +743,7 @@ process_stop_strings(const std::set<std::string>& stop_strings, Tokenizer& token
     return result;
 }
 
-SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
+SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_groups,
                               ov::Tensor logits,
                               bool is_validation_mode_enabled) {
     const float * logits_data = logits.data<float>();
diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp
index 981e11560f..7796f93d1e 100644
--- a/src/cpp/src/sampler.hpp
+++ b/src/cpp/src/sampler.hpp
@@ -1,5 +1,5 @@
-// Copyright (C) 2023-2024 Intel Corporation
+// Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
 #pragma once
 
@@ -67,7 +67,7 @@ class Sampler {
     Sampler() = default;
     Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};
 
-    SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
+    SamplerOutput sample(const std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
     void set_seed(size_t new_seed) {
         rng_engine.seed(new_seed);
         seed = new_seed;
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 14ce87c6f1..b6bcc83530 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2023-2024 Intel Corporation
+// Copyright (C) 2023-2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
 #pragma once
@@ -292,8 +292,8 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
     }
 
     size_t num_finished_seqs() const {
-        return std::count_if(m_sequences.begin(), m_sequences.end(), [] (Sequence::CPtr seq) {
-            return seq->has_finished();
+        return std::count_if(m_sequences.begin(), m_sequences.end(), [this] (Sequence::CPtr seq) {
+            return seq->has_finished() || seq->out_of_memory() || handle_dropped();
         });
     }
diff --git a/tests/python_tests/test_llm_pipeline_static.py b/tests/python_tests/test_llm_pipeline_static.py
index 6ef6162043..d2d3673356 100644
--- a/tests/python_tests/test_llm_pipeline_static.py
+++ b/tests/python_tests/test_llm_pipeline_static.py
@@ -1,7 +1,9 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 import openvino_genai as ov_genai
+from openvino_genai import GenerationConfig
+
 import pytest
 import platform
 import sys
@@ -12,9 +14,19 @@ (
 )
 from common import get_default_properties
 
+from common import \
+    get_greedy, \
+    get_greedy_with_penalties, \
+    get_multinomial_temperature, \
+    get_multinomial_all_parameters, \
+    get_multinomial_temperature_and_presence_penalty, \
+    get_beam_search
+
+
 if sys.platform == 'darwin' or platform.machine() in ["aarch64", "arm64", "ARM64"]:
     pytest.skip("NPU plugin is available only on Linux and Windows x86_64", allow_module_level=True)
 
+
 # This test suite is designed specifically to validate the functionality and robustness of the StaticLLMPipeline on NPUW:CPU.
 common_config = {
     'NPU_USE_NPUW': 'YES',
@@ -33,17 +45,22 @@ def generate_chat_history(model_path, device, pipeline_config, questions):
     return chat_history
 
 
+generation_configs = [
+    get_greedy(),
+    get_greedy_with_penalties()
+]
 @pytest.mark.precommit
 @pytest.mark.nightly
-def test_generation_compare_with_stateful():
-    prompt = 'The Sun is yellow because'
+@pytest.mark.parametrize("generation_config", generation_configs)
+def test_generation_compare_with_stateful(generation_config):
+    prompt = 'What is OpenVINO?'
     model_path = read_model(get_models_list()[0])[1]
 
-    stateful_pipe = ov_genai.LLMPipeline(model_path, "CPU", **get_default_properties())
-    ref_out = stateful_pipe.generate(prompt, max_new_tokens=100)
+    stateful_pipe = ov_genai.LLMPipeline(model_path, "CPU")
+    ref_out = stateful_pipe.generate(prompt, generation_config)
 
     static_pipe = ov_genai.LLMPipeline(model_path, "NPU", **common_config)
-    actual_out = static_pipe.generate(prompt, max_new_tokens=100)
+    actual_out = static_pipe.generate(prompt, generation_config)
 
     if ref_out != actual_out:
         print(f'ref_out: {ref_out}\n')
@@ -51,6 +68,25 @@ def test_generation_compare_with_stateful():
     assert ref_out == actual_out
 
 
+generation_configs = [
+    get_multinomial_temperature_and_presence_penalty()
+]
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("generation_config", generation_configs)
+def test_multinomial_sampling(generation_config):
+    # Multinomial sampling is highly sensitive to the raw logits values. For a fair comparison,
+    # a reference implementation producing logits identical to those of StaticLLMPipeline would
+    # be required. However, the CPU plugin may apply different optimizations to the StatefulPipeline
+    # and StaticLLMPipeline topologies, leading to slight variations in the raw logits. Therefore,
+    # there is no reliable reference for validation, so this test only checks that generation
+    # completes without raising an exception.
+    prompt = 'What is OpenVINO?'
+    model_path = read_model(get_models_list()[0])[1]
+    static_pipe = ov_genai.LLMPipeline(model_path, "NPU", **common_config)
+    actual_out = static_pipe.generate(prompt, generation_config)
+
+
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_length_properties_set_no_exception():
@@ -100,23 +136,25 @@ def test_batch_raise_error():
 
 # TODO: For the further sampling support
 generation_configs = [
-    dict(num_beams=3),
-    dict(do_sample=True)
+    get_beam_search(),
+    # NB: Only num_return_sequences=1 is supported!
+    get_multinomial_all_parameters()
 ]
 @pytest.mark.parametrize("generation_config", generation_configs)
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_unsupported_sampling_raise_error(generation_config):
     model_path = read_model(get_models_list()[0])[1]
-    prompt = 'The Sun is yellow because'
+    prompt = 'What is OpenVINO?'
+
     pipe = ov_genai.LLMPipeline(model_path, "NPU", **common_config)
     with pytest.raises(RuntimeError):
-        pipe.generate(prompt, **generation_config)
+        pipe.generate(prompt, generation_config)
 
 
 @pytest.mark.precommit
 @pytest.mark.nightly
-def test_max_number_of_tokens():
+def test_terminate_by_max_number_of_tokens():
     model_path = read_model(get_models_list()[0])[1]
     prompt = 'The Sun is yellow because'
     num_tokens = 128
@@ -129,6 +167,47 @@
     assert len(encoded_results.tokens[0]) == num_tokens
 
 
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_terminate_by_out_of_memory():
+    model_path = read_model(get_models_list()[0])[1]
+    prompt = 'The Sun is yellow because'
+    pipeline_config = { "MAX_PROMPT_LEN": 64, "MIN_RESPONSE_LEN": 64 }
+    pipeline_config |= common_config
+    kv_cache_size = pipeline_config['MAX_PROMPT_LEN'] + pipeline_config['MIN_RESPONSE_LEN']
+
+    tokenizer = ov_genai.Tokenizer(model_path)
+    tokenized_input = tokenizer.encode(prompt)
+    input_len = tokenized_input.input_ids.get_shape()[1]
+
+    pipe = ov_genai.LLMPipeline(model_path, "NPU", **pipeline_config)
+    encoded_results = pipe.generate(tokenized_input, max_new_tokens=1000, ignore_eos=True)
+
+    assert len(encoded_results.tokens[0]) == (kv_cache_size - input_len + 1)
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_terminate_by_sampler():
+    model_path = read_model(get_models_list()[0])[1]
+    prompt = 'The Sun is yellow because'
+
+    current_iter = 0
+    num_iters = 10
+    def callback(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return current_iter == num_iters
+
+    tokenizer = ov_genai.Tokenizer(model_path)
+    tokenized_input = tokenizer.encode(prompt)
+
+    pipe = ov_genai.LLMPipeline(model_path, "NPU", **common_config)
+    encoded_results = pipe.generate(tokenized_input, max_new_tokens=1000, ignore_eos=True, streamer=callback)
+
+    assert len(encoded_results.tokens[0]) == num_iters
+
+
 # FIXME: Known problem, output differs from stateful pipeline starting from 3rd prompt!
 @pytest.mark.skip(reason="JIRA-144780: Output differs from stateful pipeline")
 @pytest.mark.precommit
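The parametrized tests above pull ready-made generation configs from the shared test helpers in common; their definitions are not part of this patch. A hedged sketch of what such helpers are assumed to return (field values are illustrative only; the real helpers live in the test suite's common module):

from openvino_genai import GenerationConfig

def get_greedy() -> GenerationConfig:
    # Greedy decoding: do_sample stays False, only the output length is bounded.
    config = GenerationConfig()
    config.max_new_tokens = 30
    return config

def get_multinomial_temperature_and_presence_penalty() -> GenerationConfig:
    # Multinomial sampling with a temperature and a presence penalty on repeated tokens.
    config = GenerationConfig()
    config.do_sample = True
    config.temperature = 0.8
    config.presence_penalty = 0.1
    config.max_new_tokens = 30
    return config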
diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py
index 4822b228ca..c768d427e7 100644
--- a/tools/llm_bench/task/text_generation.py
+++ b/tools/llm_bench/task/text_generation.py
@@ -198,7 +198,6 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
 
 def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption):
-    set_seed(args['seed'])
     input_text_list = [input_text] * args['batch_size']
     if args["output_dir"] is not None and num == 0:
         for bs_index, in_text in enumerate(input_text_list):
@@ -226,6 +225,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data
         log.info(out_str)
     gen_config = model.get_generation_config()
     gen_config.max_new_tokens = max_gen_tokens
+    gen_config.rng_seed = args["seed"]
     gen_config.num_beams = args["num_beams"]
     gen_config.do_sample = False
     if args.get('draft_model', ''):
@@ -353,7 +353,6 @@ def token_printer():
 
 def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption):
-    set_seed(args['seed'])
     input_text_list = [input_text] * args['batch_size']
     if args["output_dir"] is not None and num == 0:
         for bs_index, in_text in enumerate(input_text_list):
@@ -379,6 +378,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg
     max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
     streamer.reset()
     gen_config = model.get_generation_config()
+    gen_config.rng_seed = args["seed"]
     gen_config.max_new_tokens = max_gen_tokens
     gen_config.num_beams = args["num_beams"]
     gen_config.do_sample = False
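The llm_bench change above drops the global set_seed(args['seed']) call and instead forwards the seed through GenerationConfig.rng_seed, so sampling randomness is owned by the GenAI pipeline itself, including the new Sampler in the static NPU pipeline. A hedged sketch of the resulting usage pattern; the model directory, device, and seed value are placeholders:

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("./model_dir", "CPU")  # hypothetical model directory

gen_config = pipe.get_generation_config()
gen_config.max_new_tokens = 128
gen_config.rng_seed = 42       # seed handed to the pipeline's sampler instead of a global set_seed()
gen_config.do_sample = True    # rng_seed only influences sampling-based decoding
gen_config.num_beams = 1

print(pipe.generate("What is OpenVINO?", gen_config))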