diff --git a/.github/workflows/llm_bench-python.yml b/.github/workflows/genai-tools.yml similarity index 78% rename from .github/workflows/llm_bench-python.yml rename to .github/workflows/genai-tools.yml index 56145c080c..333bee3e11 100644 --- a/.github/workflows/llm_bench-python.yml +++ b/.github/workflows/genai-tools.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: llm_bench Python Test +name: GenAI tools on: workflow_dispatch: @@ -46,7 +46,8 @@ jobs: commit_packages_to_provide: wheels revision: latest_available_commit - build: + llm_bench: + name: 'LLM bench tests' defaults: run: shell: bash @@ -60,7 +61,6 @@ jobs: OV_INSTALL_DIR: ${{ github.workspace }}/ov SRC_DIR: ${{ github.workspace }} LLM_BENCH_PYPATH: ${{ github.workspace }}/tools/llm_bench - WWB_PATH: ${{ github.workspace }}/tools/who_what_benchmark steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -70,6 +70,12 @@ jobs: uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} + - name: Lint with flake8 + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest black + # stop the build if there are Python syntax errors or undefined names + python -m flake8 ${{ env.LLM_BENCH_PYPATH }} --config=${{ env.LLM_BENCH_PYPATH }}/setup.cfg - name: Download OpenVINO package uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: @@ -78,59 +84,42 @@ jobs: merge-multiple: true - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest black python -m pip install ${{ env.SRC_DIR }}/thirdparty/openvino_tokenizers -v ${{ needs.openvino_download.outputs.ov_wheel_source }} python -m pip install ${{ env.SRC_DIR }} -v ${{ needs.openvino_download.outputs.ov_wheel_source }} - GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.LLM_BENCH_PYPATH }}/requirements.txt ${{ needs.openvino_download.outputs.ov_wheel_source }} + python -m pip install -r ${{ env.LLM_BENCH_PYPATH }}/requirements.txt ${{ needs.openvino_download.outputs.ov_wheel_source }} working-directory: ${{ env.OV_INSTALL_DIR }} - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - python -m flake8 ${{ env.LLM_BENCH_PYPATH }} --config=${{ env.LLM_BENCH_PYPATH }}/setup.cfg - python -m flake8 ${{ env.WWB_PATH }} --config=${{ env.WWB_PATH }}/setup.cfg - - name: Create code style diff for samples - if: failure() - run: | - python -m black -l 160 -S ${{ env.LLM_BENCH_PYPATH }}/ - git diff > llm.bench_diff.diff - - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 - if: failure() - with: - name: llm.bench_diff - path: llm.bench_diff.diff - - name: Test native pytorch model on Linux + - name: Test native pytorch model run: | git clone --depth 1 https://huggingface.co/katuni4ka/tiny-random-qwen python ./tools/llm_bench/benchmark.py -m tiny-random-qwen -d cpu -n 1 -f pt -ic 20 rm -rf tiny-random-qwen env: GIT_LFS_SKIP_SMUDGE: 0 - - name: Test tiny-random-baichuan2 on Linux Optimum Intel + - name: Test tiny-random-baichuan2 Optimum Intel run: | optimum-cli export openvino --model katuni4ka/tiny-random-baichuan2 --trust-remote-code --weight-format fp16 ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16 python 
./tools/llm_bench/benchmark.py -m ./ov_models/tiny-random-baichuan2/pytorch/dldt/FP16/ -d cpu -n 1 --optimum -ic 10 rm -rf ./ov_models/tiny-random-baichuan2 - - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov on Linux Optimum Intel + - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov Optimum Intel run: | huggingface-cli download OpenVINO/LCM_Dreamshaper_v7-int8-ov --local-dir ov_models/lcm_dreamshaper_v7 python ./tools/llm_bench/benchmark.py -m ./ov_models/lcm_dreamshaper_v7/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --optimum --num_steps 4 - - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov on Linux with GenAI + - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov with GenAI run: | python ./tools/llm_bench/benchmark.py -m ./ov_models/lcm_dreamshaper_v7/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --num_steps 4 - - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov on Linux with GenAI and LoRA + - name: Test OpenVINO/LCM_Dreamshaper_v7-int8-ov with GenAI and LoRA run: | wget -O ./ov_models/soulcard.safetensors https://civitai.com/api/download/models/72591 python ./tools/llm_bench/benchmark.py -m ./ov_models/lcm_dreamshaper_v7/ -pf ./tools/llm_bench/prompts/stable-diffusion.jsonl -d cpu -n 1 --lora ./ov_models/soulcard.safetensors --lora_alphas 0.7 --num_steps 4 rm -rf ./ov_models/lcm_dreamshaper_v7/ - - name: Test TinyLlama-1.1B-Chat-v1.0 in Speculative Deconding mode on Linux + - name: Test TinyLlama-1.1B-Chat-v1.0 in Speculative Decoding via GenAI run: | optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format fp16 ov_models/TinyLlama-1.1B-Chat-v1.0/FP16 optimum-cli export openvino --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --trust-remote-code --weight-format int8 ov_models/TinyLlama-1.1B-Chat-v1.0/INT8 python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --assistant_confidence_threshold 0.4 -ic 20 python ./tools/llm_bench/benchmark.py -m ./ov_models/TinyLlama-1.1B-Chat-v1.0/FP16/ --draft_model ./ov_models/TinyLlama-1.1B-Chat-v1.0/INT8/ -p "Why is the Sun yellow?" -d cpu --draft_device cpu -n 1 --num_assistant_tokens 5 -ic 20 rm -rf ov_models/TinyLlama-1.1B-Chat-v1.0 - - name: Test whisper-tiny on Linux + - name: Test whisper-tiny via GenAI run: | GIT_LFS_SKIP_SMUDGE=1 git clone --depth 1 --branch main --single-branch https://huggingface.co/datasets/facebook/multilingual_librispeech cd multilingual_librispeech @@ -143,60 +132,64 @@ jobs: python ./tools/llm_bench/benchmark.py -m ./ov_models/whisper-tiny --media multilingual_librispeech/data/mls_polish/train/audio/3283_1447_000/3283_1447_000000.flac -d cpu -n 1 rm -rf ./ov_models/whisper-tiny rm -rf multilingual_librispeech - - name: Text InternVL2-1B on Linux + - name: Test InternVL2-1B via GenAI run: | optimum-cli export openvino --model OpenGVLab/InternVL2-1B ./ov_models/internvl2-1B --task image-text-to-text --trust-remote-code python ./tools/llm_bench/benchmark.py -m ./ov_models/internvl2-1B --media https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --prompt "What is unusual on this image?" -ic 20 python ./tools/llm_bench/benchmark.py -m ./ov_models/internvl2-1B --media https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11 --prompt "What is unusual on this image?"
-ic 20 --optimum rm -rf ./ov_models/internvl2-1B - - name: WWB Tests - run: | - pip install git+https://github.com/huggingface/optimum-intel.git - GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }} - python -m pytest -v ${{ env.WWB_PATH }}/tests - stateful: + + wwb: + name: 'WWB tests' defaults: run: shell: bash runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] needs: [ openvino_download ] env: OV_INSTALL_DIR: ${{ github.workspace }}/ov SRC_DIR: ${{ github.workspace }} - LLM_BENCH_PYPATH: ${{ github.workspace }}/tools/llm_bench WWB_PATH: ${{ github.workspace }}/tools/who_what_benchmark steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: recursive - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: "3.11" + python-version: ${{ matrix.python-version }} + - name: Lint with flake8 + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest black + # stop the build if there are Python syntax errors or undefined names + python -m flake8 ${{ env.WWB_PATH }} --config=${{ env.WWB_PATH }}/setup.cfg - name: Download OpenVINO package uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: ${{ needs.openvino_download.outputs.ov_artifact_name }} path: ${{ env.OV_INSTALL_DIR }} merge-multiple: true - - name: Test stateful + - name: Install dependencies run: | python -m pip install ${{ env.SRC_DIR }}/thirdparty/openvino_tokenizers -v ${{ needs.openvino_download.outputs.ov_wheel_source }} python -m pip install ${{ env.SRC_DIR }} -v ${{ needs.openvino_download.outputs.ov_wheel_source }} - GIT_CLONE_PROTECTION_ACTIVE=false python -m pip install -r ${{ env.LLM_BENCH_PYPATH }}/requirements.txt ${{ needs.openvino_download.outputs.ov_wheel_source }} - python ${{ env.LLM_BENCH_PYPATH }}/convert.py --model_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ${{ env.SRC_DIR }} --stateful - grep beam_idx ${{ env.SRC_DIR }}/pytorch/dldt/FP32/openvino_model.xml + python -m pip install -r ${{ env.WWB_PATH }}/requirements.txt ${{ needs.openvino_download.outputs.ov_wheel_source }} + python -m pip install git+https://github.com/huggingface/optimum-intel.git@main#egg=optimum-intel working-directory: ${{ env.OV_INSTALL_DIR }} - name: WWB Tests run: | - pip install pytest - pip install git+https://github.com/huggingface/optimum-intel.git - GIT_CLONE_PROTECTION_ACTIVE=false PIP_PRE=1 PIP_EXTRA_INDEX_URL=https://storage.openvinotoolkit.org/simple/wheels/nightly pip install ${{ env.WWB_PATH }} + python -m pip install -v ${{ env.WWB_PATH }} python -m pytest -v ${{ env.WWB_PATH }}/tests Overall_Status: name: ci/gha_overall_status_llm_bench - needs: [openvino_download, build, stateful] + needs: [openvino_download, llm_bench, wwb] if: ${{ always() }} runs-on: ubuntu-latest steps: diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 0125479f92..11efed8b32 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -20,13 +20,13 @@ namespace { /* * NPU reads some properties from the config file, but when LLMPipeline is initialized -* from the model_str and weights_tensor, there are not files. +* from the model_str and weights_tensor, there are no files. 
* In the later case ModelDesc is stored in properties. * This function pops ModelDescr from the the properties and returns a pair of updated properties and ModelDescr. */ -std::pair split_model_descr(const ov::AnyMap& properties) { +std::pair split_model_descr(const ov::AnyMap& properties) { ov::AnyMap main_properties = properties; - ov::genai::ModelConfigDesc model_descr; + ov::genai::static_llm::ModelConfigDesc model_descr; auto pop_property = [](ov::AnyMap& orig_propertis, const std::string& key, auto& value) { if (orig_propertis.find(key) != orig_propertis.end()) { @@ -105,7 +105,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [plugin_config, scheduler_config] = utils::split_scheduler_config(properties); m_pimpl = std::make_unique(models_path, tokenizer, scheduler_config, device, plugin_config); } else if (device == "NPU") { - m_pimpl = std::make_unique(models_path, tokenizer, device, properties); + m_pimpl = static_llm::LLMPipelineFactory::create(models_path, tokenizer, device, properties); } else { m_pimpl = std::make_unique(models_path, tokenizer, device, properties); } @@ -124,7 +124,7 @@ ov::genai::LLMPipeline::LLMPipeline( auto [device_properties, scheduler_config] = utils::split_scheduler_config(properties); m_pimpl = std::make_unique(models_path, scheduler_config, device, device_properties); } else if (device == "NPU") { - m_pimpl = std::make_unique(models_path, device, properties); + m_pimpl = static_llm::LLMPipelineFactory::create(models_path, device, properties); } else { m_pimpl = std::make_unique(models_path, device, properties); } @@ -162,7 +162,7 @@ ov::genai::LLMPipeline::LLMPipeline( // This will convert from AnyMap to ModelDesc. auto [filtered_properties, model_descr] = split_model_descr(properties); - m_pimpl = std::make_unique( + m_pimpl = static_llm::LLMPipelineFactory::create( utils::singleton_core().read_model(model_str, weights_tensor), model_descr, tokenizer, diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index de1038a716..94aa6e19fe 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -398,12 +398,12 @@ KVAxesPosition get_kv_axes(const std::string& model_type) { return axes; } -ov::genai::ModelConfigDesc get_modeldesc_from_json(const std::filesystem::path& filepath) { +ov::genai::static_llm::ModelConfigDesc get_modeldesc_from_json(const std::filesystem::path& filepath) { std::ifstream file(filepath); OPENVINO_ASSERT(file.is_open(), "Could not open file: ", filepath); nlohmann::json config_data = nlohmann::json::parse(file); - ov::genai::ModelConfigDesc desc; + ov::genai::static_llm::ModelConfigDesc desc; desc.type = config_data["model_type"].get(); // NB: In case _name_or_path field isn't presented in config.json if (config_data.contains("_name_or_path")) { @@ -588,6 +588,19 @@ std::optional pop_int_and_cast(ov::AnyMap& config, const std::string& return std::nullopt; } +void update_config(ov::AnyMap& config, const std::pair& pair) { + if (config.count(pair.first) == 0) { + config.insert(pair); + } +} + +void rename_key(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) { + if (config.count(old_key) != 0) { + auto opt_value = pop_option(config, old_key); + config[new_key] = opt_value.value(); + } +} + ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim, size_t start_pos, size_t end_pos) { ov::Shape start_shape(std::vector(tensor.get_shape().size(), 0u)); start_shape[dim] = start_pos; @@ -647,12 +660,269 @@ void stream_generated_tokens(std::shared_ptr 
streamer_p } } +enum StaticPipelineKind { + STATEFUL, + STATELESS +}; +StaticPipelineKind str_to_pipeline(const std::string& str) { + if (str == "STATEFUL") { + return StaticPipelineKind::STATEFUL; + } + if (str == "STATELESS") { + return StaticPipelineKind::STATELESS; + } + OPENVINO_THROW("Unsupported \"PIPELINE\" provided: ", + str, ". Please select either \"STATEFUL\" or \"STATELESS\"."); +} } // anonymous namespace namespace ov { namespace genai { +namespace static_llm { + +StatefulLLMPipeline::StatefulLLMPipeline( + const std::filesystem::path& models_path, + const ov::genai::Tokenizer& tokenizer, + const std::string&, + const ov::AnyMap& config +) : LLMPipelineImplBase(tokenizer, + utils::from_config_json_if_exists(models_path)) { + + auto model = genai::utils::singleton_core().read_model(models_path / "openvino_model.xml", {}, config); + ModelConfigDesc model_desc = get_modeldesc_from_json(models_path / "config.json"); + ov::AnyMap properties = config; + + auto compiled = setupAndCompileModel(model, model_desc, properties); + m_request = compiled->create_infer_request(); +} + + +StatefulLLMPipeline::StatefulLLMPipeline( + const std::shared_ptr& model, + const ModelConfigDesc& model_desc, + const ov::genai::Tokenizer& tokenizer, + const std::string&, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config +) : LLMPipelineImplBase(tokenizer, generation_config) { + ov::AnyMap properties_copy = properties; + auto compiled = setupAndCompileModel(model, model_desc, properties_copy); + m_request = compiled->create_infer_request(); +} + +std::shared_ptr StatefulLLMPipeline::setupAndCompileModel( + const std::shared_ptr& model, + const ModelConfigDesc& model_desc, + ov::AnyMap& pipeline_config) { + + const uint32_t kMaxPromptLen = pop_int_and_cast(pipeline_config, "MAX_PROMPT_LEN").value_or(1024u); + const uint32_t kMinResponseLen = pop_int_and_cast(pipeline_config, "MIN_RESPONSE_LEN").value_or(128u); + m_kvcache_total = kMaxPromptLen + kMinResponseLen; + std::string generate_hint = pop_or_default(pipeline_config, "GENERATE_HINT", "FAST_COMPILE"); + + update_config(pipeline_config, {"NPU_USE_NPUW", "YES"}); + update_config(pipeline_config, {"NPUW_LLM", "YES"}); + + KVAxesPosition axes = get_kv_axes(model_desc.type); + update_config(pipeline_config, {"NPUW_LLM_BATCH_DIM", axes.batch}); + update_config(pipeline_config, {"NPUW_LLM_SEQ_LEN_DIM", axes.seq_len}); + + update_config(pipeline_config, {"NPUW_LLM_MAX_PROMPT_LEN", kMaxPromptLen}); + update_config(pipeline_config, {"NPUW_LLM_MIN_RESPONSE_LEN", kMinResponseLen}); + update_config(pipeline_config, {"NPUW_LLM_GENERATE_HINT", generate_hint}); + + // NB: Try to apply opt transpose only for Llama-2-7b-chat-hf model + if ( model_desc.name_or_path == "meta-llama/Llama-2-7b-chat-hf" || + (model_desc.type == "llama" && model_desc.num_key_value_heads == 32)) { + update_config(pipeline_config, {"NPUW_LLM_OPTIMIZE_V_TENSORS", true}); + } + + rename_key(pipeline_config, "PREFILL_CONFIG", "NPUW_LLM_PREFILL_CONFIG"); + rename_key(pipeline_config, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG"); + + return std::make_shared(genai::utils::singleton_core().compile_model(model, "NPU", pipeline_config)); +} + +DecodedResults StatefulLLMPipeline::generate( + StringInputs inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer +) { + auto start_time = std::chrono::steady_clock::now(); + + GenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; + std::string prompt; + if (auto input_vector = std::get_if>(&inputs)) { + OPENVINO_ASSERT(input_vector->size() == 1u, "Currently only batch size=1 is supported"); + prompt = std::move(input_vector->front()); + } else { + OPENVINO_ASSERT(std::holds_alternative(inputs)); + prompt = std::get(inputs); + } + + ov::genai::TokenizedInputs tokenized_input; + if (m_is_chat_conversation) { + m_history.push_back({{"role", "user"}, {"content", prompt}}); + constexpr bool add_generation_prompt = true; + prompt = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF + tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); + } else { + tokenized_input = m_tokenizer.encode(prompt); + } + + auto encode_stop_time = std::chrono::steady_clock::now(); + auto encoded_results = generate(tokenized_input, config, streamer); + + auto decode_start_time = std::chrono::steady_clock::now(); + DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores}; + auto decode_stop_time = std::chrono::steady_clock::now(); + + if (m_is_chat_conversation) { + auto answer = decoded_results.texts[0]; + m_history.push_back({{"role", "assistant"}, {"content", answer}}); + } + + // generate_durations + decoded_results.perf_metrics = encoded_results.perf_metrics; + auto& raw_counters = decoded_results.perf_metrics.raw_metrics; + auto stop_time = std::chrono::steady_clock::now(); + raw_counters.generate_durations = std::vector(); + raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); + raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); + raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); + decoded_results.perf_metrics.m_evaluated = false; + decoded_results.perf_metrics.evaluate_statistics(start_time); + return decoded_results; +} + +EncodedResults StatefulLLMPipeline::generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer +) { + auto start_time = std::chrono::steady_clock::now(); + ov::Tensor input_ids; + ov::Tensor attention_mask; + + if (auto data = std::get_if(&inputs)) { + input_ids = *data; + attention_mask = ov::genai::utils::init_attention_mask(input_ids); + } else if (auto data = std::get_if(&inputs)) { + input_ids = data->input_ids; + attention_mask = data->attention_mask; + } + + OPENVINO_ASSERT(input_ids.get_shape().at(0) == 1u, "Currently only batch size=1 is supported"); + + GenerationConfig config = (generation_config.has_value()) ? 
*generation_config : m_generation_config; + // If eos_token_id was not provided, take value from default m_generation_config + if (config.eos_token_id == -1) + config.set_eos_token_id(m_generation_config.eos_token_id); + config.validate(); + + std::shared_ptr streamer_ptr; + if (auto streamer_obj = std::get_if(&streamer)) { + streamer_ptr = nullptr; + } else if (auto streamer_obj = std::get_if>(&streamer)) { + streamer_ptr = *streamer_obj; + } else if (auto callback = std::get_if>(&streamer)) { + streamer_ptr = std::make_shared(m_tokenizer, *callback); + } + + OPENVINO_ASSERT(config.is_greedy_decoding(), "Currently only greedy decoding is supported"); + + ov::Shape prompts_shape = input_ids.get_shape(); + const size_t batch_size = prompts_shape[0]; + ov::genai::EncodedResults results; + auto& raw_perf_counters = results.perf_metrics.raw_metrics; + // NB: Only batch=1 is supported now + results.scores.resize(1u); + results.scores[0] = 0u; + results.tokens.resize(1u); + + // TODO: Check if there is enough space in KV-cache to process input prompt + auto prompt_len = input_ids.get_size(); + + ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()}; + utils::initialize_position_ids(position_ids, attention_mask); + + m_request.set_tensor("input_ids", input_ids); + m_request.set_tensor("attention_mask", attention_mask); + m_request.set_tensor("position_ids", position_ids); + + m_request.infer(); + + int64_t last_token = utils::argmax(m_request.get_tensor("logits"), 0); + + results.tokens[0].push_back(last_token); + if (streamer_ptr && streamer_ptr->put(last_token)) { + return results; + } + + int64_t input_ids_data = -1; + int64_t position_ids_data = prompt_len - 1; + std::vector attention_mask_data(prompt_len - 1, 1); + m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&input_ids_data))); + m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&position_ids_data))); -StaticLLMPipeline::StaticLLMPipeline( + const size_t max_tokens = config.get_max_new_tokens(prompt_len); + for (int i = 0; i < max_tokens - 1; ++i) { + // KV Cache is full, no further generation is possible + if (position_ids_data + 1 == m_kvcache_total) { + break; + } + + // Just change the variables here, as pointers to them are already set to corresponding tensors + input_ids_data = last_token; + ++position_ids_data; + // However, attention_mask changes its shape on each iteration, it should be re-set explicitly + attention_mask_data.push_back(1); + m_request.set_tensor("attention_mask", ov::Tensor(ov::element::i64, ov::Shape{1,attention_mask_data.size()}, (void*)&attention_mask_data[0])); + + m_request.infer(); + + last_token = utils::argmax(m_request.get_tensor("logits"), 0); + results.tokens[0].push_back(last_token); + + raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now()); + raw_perf_counters.m_batch_sizes.emplace_back(batch_size); + if (streamer_ptr && streamer_ptr->put(last_token)) { + break; + } + + if (last_token == config.eos_token_id && !config.ignore_eos) { + break; + } + } + + if (streamer_ptr) { + streamer_ptr->end(); + } + + auto stop_time = std::chrono::steady_clock::now(); + // If is called without tokenization then that stat will not be reported. 
+ auto& metrics = results.perf_metrics; + metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1); + metrics.load_time = this->m_load_time_ms; + metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); + metrics.evaluate_statistics(start_time); + return results; +} + +void StatefulLLMPipeline::start_chat(const std::string& system_message) { + if (!system_message.empty()) { + m_history.push_back({{"role", "system"}, {"content", system_message}}); + } + m_is_chat_conversation = true; +}; + +void StatefulLLMPipeline::finish_chat() { + m_is_chat_conversation = false; + m_history.clear(); +}; + +StatelessLLMPipeline::StatelessLLMPipeline( const std::filesystem::path& models_path, const ov::genai::Tokenizer& tokenizer, const std::string& device, @@ -692,14 +962,14 @@ StaticLLMPipeline::StaticLLMPipeline( m_sampler.set_seed(m_generation_config.rng_seed); }; -StaticLLMPipeline::StaticLLMPipeline( +StatelessLLMPipeline::StatelessLLMPipeline( const std::filesystem::path& models_path, const std::string& device, const ov::AnyMap& properties -) : StaticLLMPipeline(models_path, Tokenizer(models_path), device, properties) { +) : StatelessLLMPipeline(models_path, Tokenizer(models_path), device, properties) { } -StaticLLMPipeline::StaticLLMPipeline( +StatelessLLMPipeline::StatelessLLMPipeline( const std::shared_ptr& model, const ModelConfigDesc& model_desc, const ov::genai::Tokenizer& tokenizer, @@ -729,7 +999,7 @@ StaticLLMPipeline::StaticLLMPipeline( m_sampler.set_seed(m_generation_config.rng_seed); } -void StaticLLMPipeline::setupAndCompileModels( +void StatelessLLMPipeline::setupAndCompileModels( const std::shared_ptr& model, const std::string& device, const ModelConfigDesc& model_desc, @@ -808,7 +1078,7 @@ void StaticLLMPipeline::setupAndCompileModels( ov::genai::utils::print_compiled_model_properties(prefill_compiled_model, "Static LLM prefill compiled model"); } -void StaticLLMPipeline::setupAndImportModels( +void StatelessLLMPipeline::setupAndImportModels( const std::filesystem::path& models_path, const std::string& device, ov::AnyMap& properties) { @@ -882,19 +1152,19 @@ void StaticLLMPipeline::setupAndImportModels( m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, 2u }; } -void StaticLLMPipeline::start_chat(const std::string& system_message) { +void StatelessLLMPipeline::start_chat(const std::string& system_message) { if (!system_message.empty()) { m_history.push_back({{"role", "system"}, {"content", system_message}}); } m_is_chat_conversation = true; }; -void StaticLLMPipeline::finish_chat() { +void StatelessLLMPipeline::finish_chat() { m_is_chat_conversation = false; m_history.clear(); }; -void StaticLLMPipeline::prepare_for_new_conversation() { +void StatelessLLMPipeline::prepare_for_new_conversation() { fill_tensor(m_prefill_request.get_tensor("input_ids"), m_tokenizer.get_pad_token_id()); fill_tensor(m_prefill_request.get_tensor("position_ids"), 0u); fill_tensor(m_prefill_request.get_tensor("attention_mask"), 0u); @@ -902,7 +1172,7 @@ void StaticLLMPipeline::prepare_for_new_conversation() { m_kvcache_desc.num_stored_tokens = 0u; } -DecodedResults StaticLLMPipeline::generate( +DecodedResults StatelessLLMPipeline::generate( StringInputs inputs, OptionalGenerationConfig generation_config, StreamerVariant streamer @@ -957,7 +1227,7 @@ DecodedResults StaticLLMPipeline::generate( return decoded_results; } -EncodedResults StaticLLMPipeline::generate( +EncodedResults StatelessLLMPipeline::generate( const 
EncodedInputs& inputs, OptionalGenerationConfig generation_config, StreamerVariant streamer @@ -1156,5 +1426,49 @@ EncodedResults StaticLLMPipeline::generate( return results; } +std::unique_ptr<LLMPipelineImplBase> +LLMPipelineFactory::create(const std::filesystem::path& models_path, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& config) { + auto properties = config; + const auto pipeline_mode = str_to_pipeline(pop_or_default(properties, "STATIC_PIPELINE", std::string("STATELESS"))); + if (pipeline_mode == StaticPipelineKind::STATEFUL) { + return std::make_unique<StatefulLLMPipeline>(models_path, tokenizer, device, properties); + } + return std::make_unique<StatelessLLMPipeline>(models_path, tokenizer, device, properties); +} + +std::unique_ptr<LLMPipelineImplBase> +LLMPipelineFactory::create(const std::filesystem::path& models_path, + const std::string& device, + const ov::AnyMap& config) { + return create(models_path, Tokenizer(models_path), device, config); +} + +std::unique_ptr<LLMPipelineImplBase> LLMPipelineFactory::create(const std::shared_ptr<ov::Model>& model, + const ModelConfigDesc& model_desc, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config) { + auto properties_copy = properties; + const auto pipeline_mode = str_to_pipeline(pop_or_default(properties_copy, "STATIC_PIPELINE", std::string("STATELESS"))); + if (pipeline_mode == StaticPipelineKind::STATEFUL) { + return std::make_unique<StatefulLLMPipeline>(model, + model_desc, + tokenizer, + device, + properties_copy, + generation_config); + } + return std::make_unique<StatelessLLMPipeline>(model, + model_desc, + tokenizer, + device, + properties_copy, + generation_config); +} +} // namespace static_llm } // namespace genai } // namespace ov diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp index 8dc7ef49a1..dd51c31b29 100644 --- a/src/cpp/src/llm_pipeline_static.hpp +++ b/src/cpp/src/llm_pipeline_static.hpp @@ -10,6 +10,7 @@ namespace ov { namespace genai { +namespace static_llm { struct ModelConfigDesc { std::string type; @@ -17,16 +18,34 @@ struct ModelConfigDesc { int num_key_value_heads; }; -class StaticLLMPipeline final : public LLMPipelineImplBase { +struct LLMPipelineFactory { + static std::unique_ptr<LLMPipelineImplBase> create(const std::filesystem::path& path, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& config); + + static std::unique_ptr<LLMPipelineImplBase> create(const std::filesystem::path& path, + const std::string& device, + const ov::AnyMap& config); + + static std::unique_ptr<LLMPipelineImplBase> create(const std::shared_ptr<ov::Model>& model, + const ModelConfigDesc& model_desc, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config = {}); +}; + +class StatefulLLMPipeline : public LLMPipelineImplBase { public: - StaticLLMPipeline( + StatefulLLMPipeline( const std::filesystem::path& path, const ov::genai::Tokenizer& tokenizer, const std::string& device, const ov::AnyMap& config ); - StaticLLMPipeline( + StatefulLLMPipeline( const std::shared_ptr<ov::Model>& model, const ModelConfigDesc& model_desc, const ov::genai::Tokenizer& tokenizer, @@ -35,12 +54,57 @@ class StaticLLMPipeline final : public LLMPipelineImplBase { const ov::genai::GenerationConfig& generation_config = {} ); - StaticLLMPipeline( + std::shared_ptr<ov::CompiledModel> setupAndCompileModel( + const std::shared_ptr<ov::Model>& model, + const ModelConfigDesc& model_desc, + ov::AnyMap& pipeline_config); + + DecodedResults generate( + StringInputs inputs, + OptionalGenerationConfig
generation_config, + StreamerVariant streamer + ) override; + + EncodedResults generate( + const EncodedInputs& inputs, + OptionalGenerationConfig generation_config, + StreamerVariant streamer + ) override; + + void start_chat(const std::string& system_message) override; + void finish_chat() override; + +private: + uint32_t m_kvcache_total = 0u; + ov::InferRequest m_request; + bool m_is_chat_conversation = false; + ChatHistory m_history; +}; + +class StatelessLLMPipeline final : public LLMPipelineImplBase { +public: + StatelessLLMPipeline( + const std::filesystem::path& path, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& config + ); + + StatelessLLMPipeline( const std::filesystem::path& path, const std::string& device, const ov::AnyMap& config ); + StatelessLLMPipeline( + const std::shared_ptr& model, + const ModelConfigDesc& model_desc, + const ov::genai::Tokenizer& tokenizer, + const std::string& device, + const ov::AnyMap& properties, + const ov::genai::GenerationConfig& generation_config = {} + ); + void setupAndCompileModels( const std::shared_ptr& model, const std::string& device, @@ -88,5 +152,6 @@ class StaticLLMPipeline final : public LLMPipelineImplBase { ChatHistory m_history; }; +} // namespace static_llm } // namespace genai } // namespace ov diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index bb34c1dcd4..aa4c537dd6 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -390,7 +390,7 @@ def compare_generation_results(prompts: List[str], hf_results: List[GenerationRe def get_hugging_face_models(model_id: str): hf_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - opt_model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, ov_config=get_default_properties()) + opt_model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False, load_in_8bit=False, trust_remote_code=True, ov_config=get_default_properties()) return opt_model, hf_tokenizer diff --git a/tests/python_tests/test_kv_cache_eviction.py b/tests/python_tests/test_kv_cache_eviction.py index 41281e9cab..3dbf9297ee 100644 --- a/tests/python_tests/test_kv_cache_eviction.py +++ b/tests/python_tests/test_kv_cache_eviction.py @@ -42,7 +42,7 @@ class ConvertedModel: @pytest.fixture(scope='module') def converted_model(tmp_path_factory): model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True) + model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True, load_in_8bit=False, compile=False) tokenizer = AutoTokenizer.from_pretrained(model_id) models_path = tmp_path_factory.mktemp("cacheopt_test_models") / model_id model.save_pretrained(models_path) diff --git a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py index 81c181bc54..62c1c27e3b 100644 --- a/tests/python_tests/test_vlm_pipeline.py +++ b/tests/python_tests/test_vlm_pipeline.py @@ -19,7 +19,7 @@ def get_ov_model(cache): ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(processor.tokenizer, with_detokenizer=True) openvino.save_model(ov_tokenizer, model_dir / "openvino_tokenizer.xml") openvino.save_model(ov_detokenizer, model_dir / "openvino_detokenizer.xml") - model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False, device="CPU", export=True, trust_remote_code=True) + model = OVModelForVisualCausalLM.from_pretrained(model_id, 
compile=False, device="CPU", export=True, load_in_8bit=False, trust_remote_code=True) processor.save_pretrained(model_dir) model.save_pretrained(model_dir) return model_dir diff --git a/tools/llm_bench/llm_bench_utils/pt_utils.py b/tools/llm_bench/llm_bench_utils/pt_utils.py index dc2c6d05f5..877c135a3c 100644 --- a/tools/llm_bench/llm_bench_utils/pt_utils.py +++ b/tools/llm_bench/llm_bench_utils/pt_utils.py @@ -62,11 +62,14 @@ def create_text_gen_model(model_path, device, **kwargs): model_class = PT_MODEL_CLASSES_MAPPING.get(model_type, PT_MODEL_CLASSES_MAPPING[default_model_type]) token_class = TOKENIZE_CLASSES_MAPPING.get(model_type, TOKENIZE_CLASSES_MAPPING[default_model_type]) start = time.perf_counter() - if model_type == 'chatglm': - model = model_class.from_pretrained(model_path, trust_remote_code=True).to('cpu', dtype=float) - else: - model = model_class.from_pretrained(model_path, trust_remote_code=True) - tokenizer = token_class.from_pretrained(model_path, trust_remote_code=True) + trust_remote_code = False + try: + model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code) + except Exception: + start = time.perf_counter() + trust_remote_code = True + model = model_class.from_pretrained(model_path, trust_remote_code=trust_remote_code) + tokenizer = token_class.from_pretrained(model_path, trust_remote_code=trust_remote_code) end = time.perf_counter() from_pretrain_time = end - start else: diff --git a/tools/llm_bench/requirements.txt b/tools/llm_bench/requirements.txt index f5f4a3fdeb..6bf8d8cddf 100644 --- a/tools/llm_bench/requirements.txt +++ b/tools/llm_bench/requirements.txt @@ -1,5 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cpu numpy +--extra-index-url https://storage.openvinotoolkit.org/simple/wheels/pre-release --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly openvino openvino-tokenizers diff --git a/tools/who_what_benchmark/requirements.txt b/tools/who_what_benchmark/requirements.txt index 9d151abbf3..d4b702de78 100644 --- a/tools/who_what_benchmark/requirements.txt +++ b/tools/who_what_benchmark/requirements.txt @@ -1,11 +1,10 @@ +accelerate>=0.26.0 transformers>=4.35.2 sentence-transformers>=2.2.2 -openvino -openvino-tokenizers openvino-genai -openvino-telemetry -optimum-intel>=1.19.0 +optimum-intel[nncf]>=1.19.0 pandas>=2.0.3 numpy>=1.23.5 tqdm>=4.66.1 diffusers +datasets<3.2.0 diff --git a/tools/who_what_benchmark/tests/test_cli_image.py b/tools/who_what_benchmark/tests/test_cli_image.py index fec9e96f4c..ccd6ee1cec 100644 --- a/tools/who_what_benchmark/tests/test_cli_image.py +++ b/tools/who_what_benchmark/tests/test_cli_image.py @@ -42,8 +42,8 @@ def teardown_module(): ("hf-internal-testing/tiny-stable-diffusion-torch", "text-to-image", "hf"), ("hf-internal-testing/tiny-stable-diffusion-torch", "text-to-image", "openvino"), ("hf-internal-testing/tiny-stable-diffusion-xl-pipe", "text-to-image", "hf"), - # ("hf-internal-testing/tiny-stable-diffusion-torch", "image-inpainting", "hf"), - # ("hf-internal-testing/tiny-stable-diffusion-xl-pipe", "image-inpainting", "hf"), + ("hf-internal-testing/tiny-stable-diffusion-torch", "image-inpainting", "hf"), + ("hf-internal-testing/tiny-stable-diffusion-xl-pipe", "image-inpainting", "hf"), ], ) def test_image_model_types(model_id, model_type, backend): @@ -90,7 +90,7 @@ def test_image_model_types(model_id, model_type, backend): list(itertools.product(OV_IMAGE_MODELS, ["image-to-image", "text-to-image", - # "image-inpainting" + "image-inpainting" ])), ) def 
test_image_model_genai(model_id, model_type):
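
Note on the NPU changes above: the following is a small, self-contained C++ sketch (not part of the patch) of the control flow that static_llm::LLMPipelineFactory::create and StatefulLLMPipeline::setupAndCompileModel implement. It shows how the "STATIC_PIPELINE" property selects between the stateful and stateless pipelines (defaulting to STATELESS) and how user-facing options such as MAX_PROMPT_LEN, MIN_RESPONSE_LEN, PREFILL_CONFIG and GENERATE_CONFIG are translated into NPUW properties without overriding anything the caller already set. Assumptions made to keep the example runnable: std::map<std::string, std::string> stands in for ov::AnyMap, all values are kept as strings, the file name in the build comment is arbitrary, and GENERATE_HINT, the KV-cache axes and the Llama-2 V-tensor optimization are omitted; the real code compiles the model for the NPU device instead of printing the resulting configuration.

// Illustrative sketch only: mirrors the STATIC_PIPELINE dispatch and the option mapping
// added in llm_pipeline_static.cpp. std::map<std::string, std::string> stands in for ov::AnyMap.
// Build (file name is arbitrary): g++ -std=c++17 -o npuw_options_sketch npuw_options_sketch.cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

using Properties = std::map<std::string, std::string>;

enum class StaticPipelineKind { STATEFUL, STATELESS };

// Same contract as str_to_pipeline() in the patch: anything else is an error.
StaticPipelineKind str_to_pipeline(const std::string& str) {
    if (str == "STATEFUL")  return StaticPipelineKind::STATEFUL;
    if (str == "STATELESS") return StaticPipelineKind::STATELESS;
    throw std::runtime_error("Unsupported \"PIPELINE\" provided: " + str +
                             ". Please select either \"STATEFUL\" or \"STATELESS\".");
}

// pop_or_default(): take and remove the key if present, otherwise use the default.
std::string pop_or_default(Properties& props, const std::string& key, const std::string& def) {
    auto it = props.find(key);
    if (it == props.end()) {
        return def;
    }
    std::string value = it->second;
    props.erase(it);
    return value;
}

// update_config(): only fill the key if the caller has not set it already.
void update_config(Properties& props, const std::pair<std::string, std::string>& kv) {
    if (props.count(kv.first) == 0) {
        props.insert(kv);
    }
}

// rename_key(): move a user-facing key to its NPUW_LLM_* counterpart, keeping the value.
void rename_key(Properties& props, const std::string& old_key, const std::string& new_key) {
    auto it = props.find(old_key);
    if (it != props.end()) {
        std::string value = it->second;
        props.erase(it);
        props[new_key] = value;
    }
}

int main() {
    // Properties a caller might pass to ov::genai::LLMPipeline(models_path, "NPU", ...).
    Properties props = {
        {"STATIC_PIPELINE", "STATEFUL"},
        {"MAX_PROMPT_LEN", "512"},
        {"GENERATE_CONFIG", "<opaque user config>"},
    };

    // LLMPipelineFactory::create(): pop STATIC_PIPELINE (default STATELESS) and dispatch.
    const auto kind = str_to_pipeline(pop_or_default(props, "STATIC_PIPELINE", "STATELESS"));
    std::cout << "Selected pipeline: "
              << (kind == StaticPipelineKind::STATEFUL ? "StatefulLLMPipeline" : "StatelessLLMPipeline")
              << "\n";

    // StatefulLLMPipeline::setupAndCompileModel(): derive NPUW settings from the
    // user-facing options, never overriding values the caller set explicitly.
    const std::string max_prompt_len = pop_or_default(props, "MAX_PROMPT_LEN", "1024");
    const std::string min_response_len = pop_or_default(props, "MIN_RESPONSE_LEN", "128");
    update_config(props, {"NPU_USE_NPUW", "YES"});
    update_config(props, {"NPUW_LLM", "YES"});
    update_config(props, {"NPUW_LLM_MAX_PROMPT_LEN", max_prompt_len});
    update_config(props, {"NPUW_LLM_MIN_RESPONSE_LEN", min_response_len});
    rename_key(props, "PREFILL_CONFIG", "NPUW_LLM_PREFILL_CONFIG");
    rename_key(props, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG");

    // In the patch these properties are handed to compile_model(model, "NPU", pipeline_config).
    for (const auto& [key, value] : props) {
        std::cout << key << " = " << value << "\n";
    }
    return 0;
}

The design point worth noting is that update_config() only fills in missing keys, so an explicit user value for any NPUW property (for example NPUW_LLM_MAX_PROMPT_LEN) always wins over the defaults derived above.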