
Commit 9639188

Merge branch 'releases/2024/3' into nm/revert_update_defaults_release
2 parents 9651e87 + 33667bf commit 9639188

34 files changed (+1721, -203 lines)

.github/workflows/causal_lm_cpp.yml

Lines changed: 3 additions & 3 deletions
@@ -13,9 +13,9 @@ concurrency:
   cancel-in-progress: true

 env:
-  l_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/linux/l_openvino_toolkit_ubuntu20_2024.3.0.dev20240711_x86_64.tgz
-  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240711_x86_64.tgz
-  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/windows/w_openvino_toolkit_windows_2024.3.0.dev20240711_x86_64.zip
+  l_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/linux/l_openvino_toolkit_ubuntu20_2024.3.0.dev20240719_x86_64.tgz
+  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240719_x86_64.tgz
+  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/windows/w_openvino_toolkit_windows_2024.3.0.dev20240719_x86_64.zip
 jobs:
   cpp-multinomial-greedy_causal_lm-ubuntu:
     runs-on: ubuntu-20.04-8-cores

.github/workflows/genai_package.yml

Lines changed: 4 additions & 3 deletions
@@ -5,9 +5,9 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref_name }}
   cancel-in-progress: true
 env:
-  l_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/linux/l_openvino_toolkit_ubuntu20_2024.3.0.dev20240711_x86_64.tgz
-  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240711_x86_64.tgz
-  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/windows/w_openvino_toolkit_windows_2024.3.0.dev20240711_x86_64.zip
+  l_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/linux/l_openvino_toolkit_ubuntu20_2024.3.0.dev20240719_x86_64.tgz
+  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240719_x86_64.tgz
+  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/windows/w_openvino_toolkit_windows_2024.3.0.dev20240719_x86_64.zip
 jobs:
   ubuntu_genai_package:
     strategy:

@@ -113,5 +113,6 @@ jobs:
           && cmake --install "samples build" --config ${{ matrix.build-type }} --component samples_bin --prefix samples_install
         if: ${{ 'Release' != matrix.build-type }}
       - run: call ov\setupvars.bat && "${{ github.workspace }}/samples_install/samples_bin/greedy_causal_lm" .\TinyLlama-1.1B-Chat-v1.0\ ""
+        if: ${{ 'Release' == matrix.build-type }} # Tokenizers don't work in debug
       - run: call ov\setupvars.bat && python .\ov\samples\python\multinomial_causal_lm\multinomial_causal_lm.py .\TinyLlama-1.1B-Chat-v1.0\ 0
         if: ${{ 'Release' == matrix.build-type }} # Python bindings can be built in Release only

.github/workflows/genai_python_lib.yml

Lines changed: 3 additions & 3 deletions
@@ -5,9 +5,9 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.ref_name }}
   cancel-in-progress: true
 env:
-  l_ov_centos_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/linux/l_openvino_toolkit_centos7_2024.3.0.dev20240711_x86_64.tgz
-  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240711_x86_64.tgz
-  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc1/windows/w_openvino_toolkit_windows_2024.3.0.dev20240711_x86_64.zip
+  l_ov_centos_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/linux/l_openvino_toolkit_centos7_2024.3.0.dev20240719_x86_64.tgz
+  m_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/macos/m_openvino_toolkit_macos_12_6_2024.3.0.dev20240719_x86_64.tgz
+  w_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/pre-release/2024.3.0rc2/windows/w_openvino_toolkit_windows_2024.3.0.dev20240719_x86_64.zip
 jobs:
   ubuntu_genai_python_lib:
     # A tokenizers' dependency fails to compile on ubuntu-20 and CentOS7 envs.

Dockerfile

Lines changed: 0 additions & 38 deletions
This file was deleted.

llm_bench/python/requirements.txt

Lines changed: 9 additions & 8 deletions
@@ -1,17 +1,18 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 numpy
 --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly
-openvino
-openvino-tokenizers
-openvino_genai
+openvino~=2024.3.0
+openvino-tokenizers~=2024.3.0
+openvino_genai~=2024.3.0
 auto-gptq>=0.5.1 # for gptq
 pillow
-torch
-transformers>=4.40.0
+torch<2.5.0
+torchvision<0.20.0
+transformers>=4.40.0,<4.43.0
 diffusers>=0.22.0
-#optimum is in dependency list of optimum-intel
-git+https://github.com/huggingface/optimum-intel.git@439d61f79cf55d5d0b28334f577b6ac3c5ced28f#egg=optimum-intel
-git+https://github.com/openvinotoolkit/nncf.git@develop#egg=nncf
+#optimum is in dependency list of optimum-intel
+git+https://github.com/huggingface/optimum-intel.git@6388aeb8738b63e28fc594af84df94590e77cb9a#egg=optimum-intel
+nncf~=2.12.0
 packaging
 psutil
 timm

samples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_subdirectory(cpp/greedy_causal_lm)
 add_subdirectory(cpp/multinomial_causal_lm)
 add_subdirectory(cpp/prompt_lookup_decoding_lm)
 add_subdirectory(cpp/speculative_decoding_lm)
+add_subdirectory(cpp/benchmark_genai)

 install(FILES requirements.txt DESTINATION samples
     COMPONENT cpp_samples_genai)
samples/cpp/benchmark_genai/CMakeLists.txt

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

find_package(OpenVINOGenAI REQUIRED PATHS
    "${CMAKE_BINARY_DIR}"  # Reuse the package from the build.
    ${OpenVINO_DIR}  # GenAI may be installed alongside OpenVINO.
)

FetchContent_Declare(cxxopts
    URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
    URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
FetchContent_MakeAvailable(cxxopts)

add_executable(benchmark_genai benchmark_genai.cpp)
target_link_libraries(benchmark_genai PRIVATE openvino::genai cxxopts::cxxopts)
set_target_properties(benchmark_genai PROPERTIES
    COMPILE_PDB_NAME benchmark_genai
    # Ensure out of box LC_RPATH on macOS with SIP
    INSTALL_RPATH_USE_LINK_PATH ON)
install(TARGETS benchmark_genai
    RUNTIME DESTINATION samples_bin/
    COMPONENT samples_bin
    EXCLUDE_FROM_ALL)

samples/cpp/benchmark_genai/README.md

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# LLMs benchmarking sample

This sample script demonstrates how to benchmark LLMs in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics.

## Download and convert the model and tokenizers

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.

```sh
pip install --upgrade-strategy eager -r ../../requirements.txt
optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
```

## Usage

```sh
benchmark_genai [OPTIONS]
```

### Options

- `-m, --model`: Path to the model and tokenizers base directory.
- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
- `-n, --num_iter` (default: `3`): Number of iterations.
- `-d, --device` (default: `"CPU"`): Device to run the model on.

### Output:

```
benchmark_genai -m TinyLlama-1.1B-Chat-v1.0 -n 10
```

```
Load time: 3405.69 ms
Generate time: 1430.77 ± 3.04 ms
Tokenization time: 0.51 ± 0.02 ms
Detokenization time: 0.37 ± 0.01 ms
TTFT: 81.60 ± 0.54 ms
TPOT: 71.52 ± 2.72 ms
Throughput: 13.98 ± 0.53 tokens/s
```

For more information on how performance metrics are calculated, please follow the [performance-metrics tutorial](../../../src/README.md#performance-metrics).
samples/cpp/benchmark_genai/benchmark_genai.cpp

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/genai/llm_pipeline.hpp"
#include <cxxopts.hpp>

int main(int argc, char* argv[]) try {
    cxxopts::Options options("benchmark_vanilla_genai", "Help command");

    options.add_options()
    ("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
    ("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
    ("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
    ("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(3)))
    ("mt,max_new_tokens", "Maximal number of new tokens", cxxopts::value<size_t>()->default_value(std::to_string(20)))
    ("d,device", "device", cxxopts::value<std::string>()->default_value("CPU"))
    ("h,help", "Print usage");

    cxxopts::ParseResult result;
    try {
        result = options.parse(argc, argv);
    } catch (const cxxopts::exceptions::exception& e) {
        std::cout << e.what() << "\n\n";
        std::cout << options.help() << std::endl;
        return EXIT_FAILURE;
    }

    if (result.count("help")) {
        std::cout << options.help() << std::endl;
        return EXIT_SUCCESS;
    }

    std::string prompt = result["prompt"].as<std::string>();
    const std::string model_path = result["model"].as<std::string>();
    std::string device = result["device"].as<std::string>();
    size_t num_warmup = result["num_warmup"].as<size_t>();
    size_t num_iter = result["num_iter"].as<size_t>();

    ov::genai::GenerationConfig config;
    config.max_new_tokens = result["max_new_tokens"].as<size_t>();

    ov::genai::LLMPipeline pipe(model_path, device);

    for (size_t i = 0; i < num_warmup; i++)
        pipe.generate(prompt, config);

    ov::genai::DecodedResults res = pipe.generate(prompt, config);
    ov::genai::PerfMetrics metrics = res.perf_metrics;
    for (size_t i = 0; i < num_iter - 1; i++) {
        res = pipe.generate(prompt, config);
        metrics = metrics + res.perf_metrics;
    }

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Load time: " << metrics.get_load_time() << " ms" << std::endl;
    std::cout << "Generate time: " << metrics.get_generate_duration().mean << " ± " << metrics.get_generate_duration().std << " ms" << std::endl;
    std::cout << "Tokenization time: " << metrics.get_tokenization_duration().mean << " ± " << metrics.get_tokenization_duration().std << " ms" << std::endl;
    std::cout << "Detokenization time: " << metrics.get_detokenization_duration().mean << " ± " << metrics.get_detokenization_duration().std << " ms" << std::endl;
    std::cout << "TTFT: " << metrics.get_ttft().mean << " ± " << metrics.get_ttft().std << " ms" << std::endl;
    std::cout << "TPOT: " << metrics.get_tpot().mean << " ± " << metrics.get_tpot().std << " ms/token " << std::endl;
    std::cout << "Throughput: " << metrics.get_throughput().mean << " ± " << metrics.get_throughput().std << " tokens/s" << std::endl;

    return 0;
} catch (const std::exception& error) {
    std::cerr << error.what() << '\n';
    return EXIT_FAILURE;
} catch (...) {
    std::cerr << "Non-exception object thrown\n";
    return EXIT_FAILURE;
}
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# LLMs benchmarking sample

This sample script demonstrates how to benchmark LLMs in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics.

## Download and convert the model and tokenizers

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.

```sh
pip install --upgrade-strategy eager -r ../../requirements.txt
optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
```

## Usage

```sh
python benchmark_vanilla_genai.py [OPTIONS]
```

### Options

- `-m, --model`: Path to the model and tokenizers base directory.
- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
- `-n, --num_iter` (default: `3`): Number of iterations.
- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
- `-d, --device` (default: `"CPU"`): Device to run the model on.

### Output:

```
python benchmark_vanilla_genai.py -m TinyLlama-1.1B-Chat-v1.0 -n 10
```

```
Load time: 3405.69 ms
Generate time: 1430.77 ± 3.04 ms
Tokenization time: 0.51 ± 0.02 ms
Detokenization time: 0.37 ± 0.01 ms
TTFT: 81.60 ± 0.54 ms
TPOT: 71.52 ± 2.72 ms
Throughput: 13.98 ± 0.53 tokens/s
```

For more information on how performance metrics are calculated, see the [performance metrics readme](../../../src/README.md#performance-metrics).
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import openvino_genai as ov_genai

def main():
    parser = argparse.ArgumentParser(description="Help command")
    parser.add_argument("-m", "--model", type=str, help="Path to model and tokenizers base directory")
    parser.add_argument("-p", "--prompt", type=str, default="The Sky is blue because", help="Prompt")
    parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations")
    parser.add_argument("-n", "--num_iter", type=int, default=2, help="Number of iterations")
    parser.add_argument("-mt", "--max_new_tokens", type=int, default=20, help="Maximal number of new tokens")
    parser.add_argument("-d", "--device", type=str, default="CPU", help="Device")

    args = parser.parse_args()

    # Perf metrics are stored in DecodedResults.
    # In order to get DecodedResults instead of a string, the input should be a list.
    prompt = [args.prompt]
    model_path = args.model
    device = args.device
    num_warmup = args.num_warmup
    num_iter = args.num_iter

    config = ov_genai.GenerationConfig()
    config.max_new_tokens = args.max_new_tokens

    pipe = ov_genai.LLMPipeline(model_path, device)

    for _ in range(num_warmup):
        pipe.generate(prompt, config)

    res = pipe.generate(prompt, config)
    perf_metrics = res.perf_metrics
    for _ in range(num_iter - 1):
        res = pipe.generate(prompt, config)
        perf_metrics += res.perf_metrics

    print(f"Load time: {perf_metrics.get_load_time():.2f} ms")
    print(f"Generate time: {perf_metrics.get_generate_duration().mean:.2f} ± {perf_metrics.get_generate_duration().std:.2f} ms")
    print(f"Tokenization time: {perf_metrics.get_tokenization_duration().mean:.2f} ± {perf_metrics.get_tokenization_duration().std:.2f} ms")
    print(f"Detokenization time: {perf_metrics.get_detokenization_duration().mean:.2f} ± {perf_metrics.get_detokenization_duration().std:.2f} ms")
    print(f"TTFT: {perf_metrics.get_ttft().mean:.2f} ± {perf_metrics.get_ttft().std:.2f} ms")
    print(f"TPOT: {perf_metrics.get_tpot().mean:.2f} ± {perf_metrics.get_tpot().std:.2f} ms")
    print(f"Throughput: {perf_metrics.get_throughput().mean:.2f} ± {perf_metrics.get_throughput().std:.2f} tokens/s")

if __name__ == "__main__":
    main()

samples/python/multinomial_causal_lm/README.md

Lines changed: 8 additions & 0 deletions
@@ -2,6 +2,8 @@

 This example showcases inference of text-generation Large Language Models (LLMs): `chatglm`, `LLaMA`, `Qwen` and other models with the same signature. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::LLMPipeline` and configures it to run the random sampling algorithm. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/llm-chatbot) which provides an example of an LLM-powered chatbot in Python.

+This sample also contains an example implementation of an iterable streamer with bufferization.
+
 ## Download and convert the model and tokenizers

 The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
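The introduction in the hunk above mentions that the sample configures `ov::genai::LLMPipeline` to run the random sampling algorithm. Purely as an illustration (this snippet is not part of the commit), a minimal Python configuration along those lines might look like the following; the parameter values are assumptions, not the sample's actual defaults:

```python
import openvino_genai as ov_genai

# Illustrative values only; the real sample chooses its own defaults.
pipe = ov_genai.LLMPipeline("TinyLlama-1.1B-Chat-v1.0", "CPU")

config = ov_genai.GenerationConfig()
config.max_new_tokens = 100
config.do_sample = True    # enable random (multinomial) sampling instead of greedy decoding
config.top_k = 30          # keep only the 30 most likely tokens at each step
config.top_p = 0.9         # nucleus sampling cutoff
config.temperature = 1.0   # flatten or sharpen the token distribution

print(pipe.generate("The Sky is blue because", config))
```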
@@ -22,6 +24,12 @@ Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is

 See https://github.com/openvinotoolkit/openvino.genai/blob/master/src/README.md#supported-models for the list of supported models.

+## Streaming
+
+This Python example demonstrates custom detokenization with bufferization. The streamer receives integer tokens corresponding to each word or subword, one by one. If tokens are decoded individually, the resulting text misses necessary spaces, because detokenize(tokenize(" a")) == "a".
+
+To address this, the detokenizer needs a larger context. We accumulate tokens in a `tokens_cache` buffer and decode multiple tokens together, adding the text to the streaming queue only when a complete decoded chunk is ready. We run a separate thread to print all new elements arriving in this queue from the generation pipeline. Each generated chunk of text is put into a synchronized queue, ensuring that all put and get operations are thread-safe and blocked until they can proceed.
+
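For illustration only (this code is not part of the diff), a minimal sketch of such a buffered streamer might look like this; it assumes the `openvino_genai.StreamerBase` interface with `put()`/`end()` callbacks and uses a deliberately simplified flush condition:

```python
import queue

import openvino_genai


class IterableStreamer(openvino_genai.StreamerBase):
    """Buffers token ids and exposes decoded text chunks through a thread-safe queue."""

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.tokens_cache = []            # token ids accumulated but not yet flushed
        self.text_queue = queue.Queue()   # synchronized queue of decoded chunks
        self.print_len = 0                # length of text already handed to consumers

    def __iter__(self):
        return self

    def __next__(self):
        chunk = self.text_queue.get()     # blocks until a chunk (or the None sentinel) arrives
        if chunk is None:
            raise StopIteration
        return chunk

    def put(self, token_id: int) -> bool:
        # Decode the whole cache so spaces between subwords are reconstructed.
        self.tokens_cache.append(token_id)
        text = self.tokenizer.decode(self.tokens_cache)
        if text.endswith((' ', '\n')):    # simplified "complete chunk is ready" check
            self.text_queue.put(text[self.print_len:])
            self.print_len = len(text)
        return False                      # False tells the pipeline to keep generating

    def end(self):
        # Flush the remaining text and signal the consumer to stop iterating.
        text = self.tokenizer.decode(self.tokens_cache)
        self.text_queue.put(text[self.print_len:])
        self.text_queue.put(None)
```

A consumer thread (for example a `threading.Thread` that iterates over the streamer and prints each chunk) can then run while `pipe.generate(prompt, streamer=streamer, max_new_tokens=...)` executes in the main thread.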
 ### Troubleshooting

 #### Unicode characters encoding error on Windows

0 commit comments
