[CPU] Change kvcache default type of PagedAttention to u8 for CPU plugin #1206

Open · wants to merge 8 commits into master
28 changes: 9 additions & 19 deletions src/cpp/src/device_config.hpp
@@ -37,30 +37,20 @@ class DeviceConfig {
m_block_size = get_block_size_by_device(device);

if (m_device == "CPU") {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
m_kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16;

// if user sets precision hint, kv cache type should be changed
const auto inference_precision_it = plugin_config.find(ov::hint::inference_precision.name());
if (inference_precision_it != plugin_config.end()) {
const auto inference_precision = inference_precision_it->second.as<ov::element::Type>();
if (inference_precision == ov::element::f32) {
m_kv_cache_type = ov::element::f32;
} else if (inference_precision == ov::element::f16) {
m_kv_cache_type = ov::element::f16;
} else if (inference_precision == ov::element::bf16) {
m_kv_cache_type = ov::element::bf16;
} else {
// use default f32
m_kv_cache_type = ov::element::f32;
}
}

// if the user sets the ov::hint::kv_cache_precision property
const auto kv_cache_precision_it = plugin_config.find(ov::hint::kv_cache_precision.name());
if (kv_cache_precision_it != plugin_config.end()) {
const auto kv_cache_precision = kv_cache_precision_it->second.as<ov::element::Type>();
m_kv_cache_type = kv_cache_precision;
} else {
// ACCURACY execution mode will use an f32 kv cache
const auto execution_mode_it = plugin_config.find(ov::hint::execution_mode.name());
if (execution_mode_it != plugin_config.end() && execution_mode_it->second.as<ov::hint::ExecutionMode>() == ov::hint::ExecutionMode::ACCURACY) {
m_kv_cache_type = ov::element::f32;
} else {
// x86 and ARM have different default kv cache types
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
}
} else if (m_device.find("GPU") != std::string::npos) {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
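For context outside the diff, a minimal usage sketch of the precedence above (the DeviceConfig constructor overloads mirror the tests below; the SchedulerConfig setup is assumed and omitted):

ov::Core core;
ov::genai::SchedulerConfig scheduler_config;

// No hints: the kv cache type falls back to the CPU plugin's
// ov::hint::kv_cache_precision default (u8 after this change).
ov::genai::DeviceConfig default_cfg(core, scheduler_config, "CPU");

// An explicit ov::hint::kv_cache_precision always wins.
ov::genai::DeviceConfig f16_cfg(core, scheduler_config, "CPU",
                                { ov::hint::kv_cache_precision(ov::element::f16) });

// ACCURACY execution mode without an explicit kv cache precision falls back to f32.
ov::genai::DeviceConfig accuracy_cfg(core, scheduler_config, "CPU",
                                     { ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY) });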
41 changes: 13 additions & 28 deletions tests/cpp/cache_manager.cpp
@@ -7,34 +7,10 @@
#include "scheduler.hpp"
#include "device_config.hpp"
#include "cache_manager.hpp"
#include "openvino/op/concat.hpp"
#include "helper.hpp"

using namespace ov::genai;

std::shared_ptr<ov::Model> get_dummy_model(ov::Core core, size_t num_layers) {
ov::NodeVector keys;
ov::NodeVector values;
ov::ParameterVector params;
ov::element::Type inference_precision = core.get_property("CPU", ov::hint::inference_precision);
ov::element::Type kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16;

auto shape = ov::PartialShape({ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()});
for (size_t i = 0; i < num_layers; i++) {
auto key = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
auto value = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
key->get_output_tensor(0).set_names({"key_cache." + std::to_string(i)});
value->get_output_tensor(0).set_names({"value_cache." + std::to_string(i)});
keys.push_back(key);
values.push_back(value);
params.push_back(key);
params.push_back(value);
}
const auto& concat1 = std::make_shared<ov::op::v0::Concat>(keys, 1);
const auto& concat2 = std::make_shared<ov::op::v0::Concat>(values, 1);
auto model = std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
return std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
}

size_t get_total_allocated_bytes(std::shared_ptr<ov::genai::CacheManager> cache_manager, size_t num_decoder_layers) {
size_t allocated_bytes = 0;
for (size_t i = 0; i < num_decoder_layers; i++) {
@@ -58,14 +34,23 @@ TEST(TestCacheManager, test_cache_size_param) {
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");
size_t num_decoder_layers = 12;
std::vector<size_t> num_kv_heads(12, 12);
device_config.set_model_params(num_kv_heads, 64, num_decoder_layers);
size_t head_size = 64;
device_config.set_model_params(num_kv_heads, head_size, num_decoder_layers);

ov::InferRequest request = core.compile_model(get_dummy_model(core, num_decoder_layers)).create_infer_request();
auto cache_manager = std::make_shared<ov::genai::CacheManager>(device_config, request, core);
auto block_manager = BlockManager(device_config.get_num_kv_blocks(), false, device_config.get_block_size(), device_config.get_num_layers());
cache_manager->allocate_cache_if_needed(block_manager.get_total_number_of_kv_blocks());

ASSERT_EQ(get_total_allocated_bytes(cache_manager, num_decoder_layers), 2146959360);

const size_t kv_cache_total_size = scheduler_config.cache_size * 1024 * 1024 * 1024;
const size_t cpu_block_size = 32;
// For the u8 kv cache, the scale, zero point and quantized data are stored together.
// The per-token, per-head layout is:
// |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
// so head_size has to be extended by 2 * sizeof(float).
const size_t cpu_block_size_total = num_decoder_layers * (num_kv_heads[0] + num_kv_heads[1]) * cpu_block_size * (head_size + 2 * sizeof(float)) * sizeof(uint8_t);
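// Illustration with this test's values: 12 layers * (12 + 12) * 32 tokens * (64 + 8) bytes = 663552 bytes,
// and expected_size below rounds the configured cache size down to a whole multiple of that block total.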
size_t expected_size = kv_cache_total_size / cpu_block_size_total * cpu_block_size_total;
ASSERT_EQ(get_total_allocated_bytes(cache_manager, num_decoder_layers), expected_size);
}


8 changes: 4 additions & 4 deletions tests/cpp/device_config.cpp
@@ -20,12 +20,12 @@ TEST(TestDeviceConfig, kv_cache_precision_u8) {
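// head_size_u8 adds 2 * sizeof(float) = 8 bytes per head for the f32 scale and zero point that
// the u8 kv cache stores next to the quantized data (see the layout note in cache_manager.cpp).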
size_t head_size = 64, head_size_u8 = head_size + 8;
std::vector<size_t> num_kv_heads(12, 12);

ov::genai::DeviceConfig device_config_default(core, scheduler_config, "CPU");
device_config_default.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);
ov::genai::DeviceConfig device_config_f16(core, scheduler_config, "CPU", { ov::hint::kv_cache_precision(ov::element::f16) });
device_config_f16.set_model_params(num_kv_heads, head_size_u8, num_decoder_layers);

ov::genai::DeviceConfig device_config_u8(core, scheduler_config, "CPU", { ov::hint::kv_cache_precision(ov::element::u8) });
ov::genai::DeviceConfig device_config_u8(core, scheduler_config, "CPU");
device_config_u8.set_model_params(num_kv_heads, head_size, num_decoder_layers);

const auto ratio = ov::element::f16.size() / ov::element::u8.size();
ASSERT_EQ(device_config_default.get_num_kv_blocks() * ratio, device_config_u8.get_num_kv_blocks());
ASSERT_EQ(device_config_f16.get_num_kv_blocks() * ratio, device_config_u8.get_num_kv_blocks());
}
29 changes: 29 additions & 0 deletions tests/cpp/helper.cpp
@@ -0,0 +1,29 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "helper.hpp"
#include "openvino/op/concat.hpp"

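// Builds a small dummy model whose parameters mimic PagedAttention KV-cache inputs
// ("key_cache.<i>" / "value_cache.<i>"), so cache-manager and scheduler tests can create an
// infer request without loading a real LLM. The element type follows the CPU plugin's
// reported ov::hint::kv_cache_precision.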
std::shared_ptr<ov::Model> get_dummy_model(ov::Core core, size_t num_layers) {
ov::NodeVector keys;
ov::NodeVector values;
ov::ParameterVector params;
ov::element::Type inference_precision = core.get_property("CPU", ov::hint::inference_precision);
ov::element::Type kv_cache_type = core.get_property("CPU", ov::hint::kv_cache_precision);

auto shape = ov::PartialShape({ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()});
for (size_t i = 0; i < num_layers; i++) {
auto key = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
auto value = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
key->get_output_tensor(0).set_names({"key_cache." + std::to_string(i)});
value->get_output_tensor(0).set_names({"value_cache." + std::to_string(i)});
keys.push_back(key);
values.push_back(value);
params.push_back(key);
params.push_back(value);
}
const auto& concat1 = std::make_shared<ov::op::v0::Concat>(keys, 1);
const auto& concat2 = std::make_shared<ov::op::v0::Concat>(values, 1);
auto model = std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
return std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
}
8 changes: 8 additions & 0 deletions tests/cpp/helper.hpp
@@ -0,0 +1,8 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/runtime/core.hpp"

std::shared_ptr<ov::Model> get_dummy_model(ov::Core core, size_t num_layers);
26 changes: 2 additions & 24 deletions tests/cpp/scheduler.cpp
@@ -9,6 +9,7 @@
#include "openvino/genai/generation_config.hpp"
#include "sequence_group.hpp"
#include "scheduler.hpp"
#include "helper.hpp"

using namespace ov::genai;

@@ -18,34 +19,11 @@ void clear_finished_sequences(std::vector<SequenceGroup::Ptr>& requests) {
});
requests.erase(new_end, requests.end());
}
std::shared_ptr<ov::Model> get_model(ov::Core core, size_t num_layers) {
ov::NodeVector keys;
ov::NodeVector values;
ov::ParameterVector params;
ov::element::Type inference_precision = core.get_property("CPU", ov::hint::inference_precision);
ov::element::Type kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16;

auto shape = ov::PartialShape({ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic(), ov::Dimension::dynamic()});
for (size_t i = 0; i < num_layers; i++) {
auto key = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
auto value = std::make_shared<ov::op::v0::Parameter>(kv_cache_type, shape);
key->get_output_tensor(0).set_names({"key_cache." + std::to_string(i)});
value->get_output_tensor(0).set_names({"value_cache." + std::to_string(i)});
keys.push_back(key);
values.push_back(value);
params.push_back(key);
params.push_back(value);
}
const auto& concat1 = std::make_shared<ov::op::v0::Concat>(keys, 1);
const auto& concat2 = std::make_shared<ov::op::v0::Concat>(values, 1);
auto model = std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
return std::make_shared<ov::Model>(ov::NodeVector{concat1, concat2}, params);
}

std::shared_ptr<CacheManager> init_cache_manager(SchedulerConfig scheduler_config) {
ov::Core core = ov::Core();
size_t num_decoder_layers = 12;
ov::InferRequest request = core.compile_model(get_model(core, num_decoder_layers)).create_infer_request();
ov::InferRequest request = core.compile_model(get_dummy_model(core, num_decoder_layers)).create_infer_request();
size_t head_size = 64, head_size_u8 = head_size + 8;
std::vector<size_t> num_kv_heads(12, 12);
ov::genai::DeviceConfig device_config(core, scheduler_config, "CPU");