diff --git a/src/cpp/src/cache_eviction.cpp b/src/cpp/src/cache_eviction.cpp index 94f3110cce..624273e92f 100644 --- a/src/cpp/src/cache_eviction.cpp +++ b/src/cpp/src/cache_eviction.cpp @@ -267,4 +267,74 @@ namespace ov::genai { m_scores[decoder_layer_idx] = new_scores; m_cache_counter[decoder_layer_idx] = new_counter; } + + CacheRotationCalculator::CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta) : m_block_size(block_size), m_head_size(kv_head_size) { + // Frequencies follow the original recipe from RoFormer: + // https://arxiv.org/pdf/2104.09864v5 + // + // However, the way the rotation coefficients are ultimately applied in Llama and related models from huggingface is very different + // from the RoFormer - the embedding-dimension coefficients are not treated as consecutive x-y coordinate pairs, but are rather + // divided into contiguous x-like and y-like halves - see `rotate_half` function in HF transformers. It can be shown that this form + // still preserves the relative positioning property from the RoFormer article. + OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive"); + size_t max_position_angle_multiplier = max_context_length; + size_t num_freqs = kv_head_size / 2; + m_rope_sin_lut.resize(max_position_angle_multiplier); + m_rope_cos_lut.resize(max_position_angle_multiplier); + + for (size_t i = 0; i < max_position_angle_multiplier; i++) { + m_rope_sin_lut[i].reserve(num_freqs); + m_rope_cos_lut[i].reserve(num_freqs); + for (size_t j = 0; j < num_freqs; j++) { + double exponent = - static_cast(2 * j) / kv_head_size; + double base_angle = std::pow(rope_theta, exponent); + m_rope_sin_lut[i].push_back(-std::sin(i * base_angle)); // minus since we will be rotating by an inverse angle + m_rope_cos_lut[i].push_back(std::cos(i * base_angle)); + } + } + } + + std::vector CacheRotationCalculator::get_rotation_coefficients(const std::set& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction) { + OPENVINO_ASSERT(num_logical_blocks_before_eviction * m_block_size < m_rope_sin_lut.size(), + "num_logical_blocks_before_eviction may not correspond to less tokens than max_context_length"); + + std::vector retval; + if (evicted_block_logical_indices.empty()) { + return retval; + } + + for (auto idx : evicted_block_logical_indices) { + OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction); + } + + // num_logical_blocks_before_eviction > evicted_block_logical_indices.size() is automatically guaranteed by the + // set property and the previous assertion + retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size()); + + ptrdiff_t current_rotation_delta_in_blocks = 0; + std::vector logical_block_space(num_logical_blocks_before_eviction); + std::iota(logical_block_space.begin(), logical_block_space.end(), 0); + + for (size_t logical_block_idx : logical_block_space) { + if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) { + current_rotation_delta_in_blocks += 1; + } + else { + if (current_rotation_delta_in_blocks != 0) { + BlockRotationData block_rotation_data; + block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks; + block_rotation_data.cosines.reserve(m_block_size); + block_rotation_data.sines.reserve(m_block_size); + for (size_t i = 0; i < m_block_size; i++) { + block_rotation_data.cosines.push_back(m_rope_cos_lut[current_rotation_delta_in_blocks * m_block_size]); + block_rotation_data.sines.push_back(m_rope_sin_lut[current_rotation_delta_in_blocks * m_block_size]); + } + + retval.push_back(block_rotation_data); + } + } + } + + return retval; + } } diff --git a/src/cpp/src/cache_eviction.hpp b/src/cpp/src/cache_eviction.hpp index c049156443..3829b764ea 100644 --- a/src/cpp/src/cache_eviction.hpp +++ b/src/cpp/src/cache_eviction.hpp @@ -117,88 +117,59 @@ class CacheEvictionAlgorithm { std::vector> m_cache_counter; }; +/** + * @brief Computes, based on the logical indices of the blocks to be evicted, the rotation coefficients for the + * remaining cache blocks. + * + * The rotation assumes that the executed model applies rotary positional embedding (RoPE) during the execution of + * the attention operation. Each cache block therefore has the RoPE values already "baked in", with positions equivalent + * to the point in time when the cache block values were originally computed in one of the previous attention operations. + * When blocks are evicted, the logical index space of the remaining blocks is in general no longer contiguous with respect to + * the effective positions of tokens in the blocks. Cache rotation allows to remedy this by effectively adjusting the RoPE positions + * of certain blocks in the cache after eviction, by additionally "rotating" them (in the same sense as in RoPE) by such angles that + * the cache blocks in the logical index space are again contiguous in terms of the RoPE positions. This is supposed to make the + * eviction process less impactful on the accuracy of the generation. + * + * Currently only the basic RoPE method is supported (as applied in the Llama original models). Each model in general may have + * its own RoPE method (e.g. non-linear/NTK frequency scaling), and ideally the cache rotation calculator should be adjusted based on + * the specifics of the RoPE defined by the LLM. + */ class CacheRotationCalculator { public: - CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta = 10000.0f) : m_block_size(block_size), m_head_size(kv_head_size) { - // Frequencies follow the original recipe from RoFormer: - // https://arxiv.org/pdf/2104.09864v5 - // - // However, the way the rotation coefficients are ultimately applied in Llama and related models from huggingface is very different - // from the RoFormer - the embedding-dimension coefficients are not treated as consecutive x-y coordinate pairs, but are rather - // divided into contiguous x-like and y-like halves - see `rotate_half` function in HF transformers. It can be shown that this form - // still preserves the relative positioning property from the RoFormer article. - OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive"); - size_t max_position_angle_multiplier = max_context_length; - size_t num_freqs = kv_head_size / 2; - m_rope_sin_lut.resize(max_position_angle_multiplier); - m_rope_cos_lut.resize(max_position_angle_multiplier); - - for (size_t i = 0; i < max_position_angle_multiplier; i++) { - m_rope_sin_lut[i].reserve(num_freqs); - m_rope_cos_lut[i].reserve(num_freqs); - for (size_t j = 0; j < num_freqs; j++) { - double exponent = - static_cast(2 * j) / kv_head_size; - double base_angle = std::pow(rope_theta, exponent); - m_rope_sin_lut[i].push_back(-std::sin(i * base_angle)); // minus since we will be rotating by an inverse angle - m_rope_cos_lut[i].push_back(std::cos(i * base_angle)); - } - } - }; + /** + * Constructs a CacheRotationCalculator. + * @param block_size Block size of the KV cache to evict from. + * @param max_context_length Maximum length possible for a sequence in the current pipeline. + * @param kv_head_size The size (in elements) of the embedding dimension in the attention operation. + * @param rope_theta The base RoPE angle used in the original LLM. + */ + CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta = 10000.0f); using RotationCoefficientsPerToken = std::vector>; // dimensions: [BLOCK_SIZE, head_size / 2] + + /** + * Basic output structure for the calculator. + */ struct BlockRotationData { bool operator==(const BlockRotationData& rhs) const { return (logical_block_idx == rhs.logical_block_idx) && (sines == rhs.sines) && (cosines == rhs.cosines); } - size_t logical_block_idx; // **NOTE**: corresponds to logical index AFTER eviction - RotationCoefficientsPerToken sines; - RotationCoefficientsPerToken cosines; + size_t logical_block_idx; /** Logical index of the block AFTER eviction to which the sine and cosine coefficients should be applied */ + RotationCoefficientsPerToken sines; /** The sine coefficients to be applied to this block's contents for rotation, in order of the block's elements */ + RotationCoefficientsPerToken cosines; /** The cosine coefficients to be applied to this block's contents for rotation, in order of the block's elements */ }; - std::vector get_rotation_multipliers(const std::set& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction) { - OPENVINO_ASSERT(num_logical_blocks_before_eviction * m_block_size < m_rope_sin_lut.size(), - "num_logical_blocks_before_eviction may not correspond to less tokens than max_context_length"); - - std::vector retval; - if (evicted_block_logical_indices.empty()) { - return retval; - } - - for (auto idx : evicted_block_logical_indices) { - OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction); - } - - // num_logical_blocks_before_eviction > evicted_block_logical_indices.size() is automatically guaranteed by the - // set property and the previous assertion - retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size()); - - ptrdiff_t current_rotation_delta_in_blocks = 0; - std::vector logical_block_space(num_logical_blocks_before_eviction); - std::iota(logical_block_space.begin(), logical_block_space.end(), 0); - - for (size_t logical_block_idx : logical_block_space) { - if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) { - current_rotation_delta_in_blocks += 1; - } - else { - if (current_rotation_delta_in_blocks != 0) { - BlockRotationData block_rotation_data; - block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks; - block_rotation_data.cosines.reserve(m_block_size); - block_rotation_data.sines.reserve(m_block_size); - for (size_t i = 0; i < m_block_size; i++) { - block_rotation_data.cosines.push_back(m_rope_cos_lut[current_rotation_delta_in_blocks * m_block_size]); - block_rotation_data.sines.push_back(m_rope_sin_lut[current_rotation_delta_in_blocks * m_block_size]); - } - - retval.push_back(block_rotation_data); - } - } - } - - return retval; - } + /** + * Computes the rotation coefficients for the given state of the logical block space when eviction is about to take place. + * @param evicted_block_logical_indices The logical block indices that the prior cache eviction algorithm step determined to be necessary to evict. + * @param num_logical_blocks_before_eviction Number of logical blocks that the evicted-from sequence occupied before the eviction step. + * @return A vector of per-block rotation data, including the indices of blocks after eviction that should be rotated, and the pre-computed trigonometric coefficients necessary for rotation. + */ + std::vector get_rotation_coefficients(const std::set& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction); + /** + * @return The size of the embedding dimension that this CacheRotationCalculator was initialized with. + */ size_t get_head_size() const { return m_head_size; } diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 3f2b4510aa..c5fef4ea4a 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -74,10 +74,11 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init( // and finally create model runner bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction; if (is_use_cache_eviction) { - m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers(), - /* m_collect_attention_scores = */ true); + m_model_runner = std::make_shared(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), + /* collect_attention_scores = */ true, + /* is_use_per_layer_cache_control = */ true); m_rotation_coefficient_stores.reserve(device_config.get_num_layers()); - ov::Shape rotation_coefficient_store_shape{ device_config.get_head_size() * (scheduler_config.block_size * scheduler_config.num_kv_blocks) }; + ov::Shape rotation_coefficient_store_shape{ device_config.get_head_size() * (m_scheduler->get_block_size() * scheduler_config.num_kv_blocks) }; for (size_t i = 0; i < device_config.get_num_layers(); i++) { ov::Tensor store(ov::element::f32, rotation_coefficient_store_shape); std::memset(store.data(), 0, store.get_byte_size()); @@ -85,16 +86,15 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init( } m_next_step_rotation_coefficients.resize(device_config.get_num_layers()); m_next_step_rotated_block_logical_indices_per_sequence.resize(device_config.get_num_layers()); - m_cache_rotation_calculator = std::make_shared(scheduler_config.block_size, + m_cache_rotation_calculator = std::make_shared(m_scheduler->get_block_size(), // TODO (vshampor): LUT size equal to max cache size in tokens // is overkill - find a way to pass the max sequence length instead - scheduler_config.block_size * scheduler_config.num_kv_blocks, + m_scheduler->get_block_size() * scheduler_config.num_kv_blocks, device_config.get_head_size()); } else { - m_model_runner = std::make_shared(infer_request, updated_config, device_config.get_num_layers()); + m_model_runner = std::make_shared(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers()); } - m_model_runner = std::make_shared(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction); m_sampler = std::make_shared(m_tokenizer); m_sampler->set_seed(m_generation_config.rng_seed); @@ -218,7 +218,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { // evict unimportant blocks from KV cache, if requested if (sched_config.use_cache_eviction) { maybe_evict_cache_blocks(sched_config); - m_model_runner->set_cache_rotation_data(m_next_step_rotation_coefficients, m_next_step_rotated_block_logical_indices_per_sequence); + m_model_runner->set_cache_rotation_data(std::move(m_next_step_rotation_coefficients), + std::move(m_next_step_rotated_block_logical_indices_per_sequence)); } #ifdef DEBUG_CACHE_STATE_DUMP @@ -407,6 +408,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block std::vector num_blocks_to_rotate_for_each_layer(num_decoder_layers, 0); size_t head_size = m_cache_rotation_calculator->get_head_size(); + // necessary since we move from these members during previous steps m_next_step_rotation_coefficients.clear(); m_next_step_rotated_block_logical_indices_per_sequence.clear(); m_next_step_rotated_block_logical_indices_per_sequence.resize(num_decoder_layers); @@ -429,19 +431,19 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block } size_t num_blocks_before_eviction = m_scheduler->get_block_tables(seq_id)[layer_idx].size(); auto rotation_multipliers = - m_cache_rotation_calculator->get_rotation_multipliers(logical_blocks_to_evict[layer_idx], + m_cache_rotation_calculator->get_rotation_coefficients(logical_blocks_to_evict[layer_idx], num_blocks_before_eviction); for (size_t i = 0; i < rotation_multipliers.size(); i++) { const auto& block_rotation_data = rotation_multipliers[i]; const auto& rotation_multipliers_cos = block_rotation_data.cosines; const auto& rotation_multipliers_sin = block_rotation_data.sines; OPENVINO_ASSERT(rotation_multipliers_cos.size() == rotation_multipliers_sin.size()); - OPENVINO_ASSERT(rotation_multipliers_cos.size() == sched_config.block_size); + OPENVINO_ASSERT(rotation_multipliers_cos.size() == m_scheduler->get_block_size()); m_next_step_rotated_block_logical_indices_per_sequence[layer_idx][seq_id].push_back(block_rotation_data.logical_block_idx); // Fill the store tensor with rotation coefficient data - cos and sin coefficients are each contiguous, cos goes first - size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx] * sched_config.block_size * head_size; + size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx] * m_scheduler->get_block_size() * head_size; auto rotation_multipliers_tensor_data = m_rotation_coefficient_stores[layer_idx].data() + block_offset; for (size_t tok_idx = 0; tok_idx < rotation_multipliers_cos.size(); tok_idx++) { size_t position_offset = head_size * tok_idx; @@ -471,7 +473,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block // Select the previously filled rotation coefficients from the store tensor for (size_t i = 0; i < num_decoder_layers; i++) { - m_next_step_rotation_coefficients.emplace_back(m_rotation_coefficient_stores[i], ov::Coordinate{0}, ov::Coordinate{num_blocks_to_rotate_for_each_layer[i] * sched_config.block_size * head_size}); + m_next_step_rotation_coefficients.emplace_back(m_rotation_coefficient_stores[i], ov::Coordinate{0}, ov::Coordinate{num_blocks_to_rotate_for_each_layer[i] * m_scheduler->get_block_size() * head_size}); } for (const auto& seq_group_ptr_and_num_blocks_evicted : seq_group_to_num_blocks_evicted_map) { diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp index db7229b685..cafcefc559 100644 --- a/src/cpp/src/model_runner.hpp +++ b/src/cpp/src/model_runner.hpp @@ -32,6 +32,7 @@ class ModelRunner { AttentionScoresForEachSubsequence m_last_attention_scores; size_t m_num_decoder_layers, m_block_size; bool m_collect_attention_scores; + bool m_is_use_per_layer_cache_control; std::vector m_cache_rotation_coefficients; std::vector>> m_rotated_block_logical_indices_per_sequence_for_each_layer; @@ -39,16 +40,17 @@ class ModelRunner { /** * Constructs the ModelRunner. * @param request The ov::InferRequest for the LLM to be inferred in the continuous batching mode. - * @param scheduler_config Configuration struct for the scheduler that is to be used with this ModelRunner. * @param num_decoder_layers Number of decoder attention layers in the LLM corresponding to the request. * @param collect_attention_scores If true, then after each `forward` call the ModelRunner will collect and make available the per-token attention * scores for each decoder layer, so that these can be used in per-step cache optimizations (such as cache eviction algorithm). + * @param is_use_per_layer_cache_control If true, then the runner will pass cache control input tensors to the model on a per-attention layer basis. */ - ModelRunner(ov::InferRequest request, size_t block_size, size_t num_decoder_layers = 1, bool collect_attention_scores = false) : + ModelRunner(ov::InferRequest request, size_t block_size, size_t num_decoder_layers = 1, bool collect_attention_scores = false, bool is_use_per_layer_cache_control = false) : m_request(std::move(request)), m_block_size(block_size), m_num_decoder_layers(num_decoder_layers), - m_collect_attention_scores(collect_attention_scores) { + m_collect_attention_scores(collect_attention_scores), + m_is_use_per_layer_cache_control(is_use_per_layer_cache_control) { OPENVINO_ASSERT(m_num_decoder_layers != 0, "num_decoder_layers must be non-zero"); } @@ -68,10 +70,9 @@ class ModelRunner { return m_last_attention_scores; } - void set_cache_rotation_data(const std::vector& cache_rotation_coefficients_for_each_layer, const std::vector>>& rotated_logical_block_indices_per_sequence_for_each_layer) { - // TODO (vshampor): avoid vector copy - m_cache_rotation_coefficients = cache_rotation_coefficients_for_each_layer; - m_rotated_block_logical_indices_per_sequence_for_each_layer = rotated_logical_block_indices_per_sequence_for_each_layer; + void set_cache_rotation_data(std::vector&& cache_rotation_coefficients_for_each_layer, const std::vector>>&& rotated_logical_block_indices_per_sequence_for_each_layer) { + m_cache_rotation_coefficients = std::move(cache_rotation_coefficients_for_each_layer); + m_rotated_block_logical_indices_per_sequence_for_each_layer = std::move(rotated_logical_block_indices_per_sequence_for_each_layer); } /** @@ -264,7 +265,7 @@ class ModelRunner { size_t total_num_blocks) { std::vector tensor_names = {"block_indices"}; - if (m_scheduler_config.use_cache_eviction) { + if (m_is_use_per_layer_cache_control) { tensor_names.resize(m_num_decoder_layers); for (size_t i = 0; i < tensor_names.size(); i++) { tensor_names[i] = std::string("block_indices.") + std::to_string(i); diff --git a/src/cpp/src/scheduler.hpp b/src/cpp/src/scheduler.hpp index 990b994a83..e7d0f2f060 100644 --- a/src/cpp/src/scheduler.hpp +++ b/src/cpp/src/scheduler.hpp @@ -81,6 +81,7 @@ class Scheduler { const size_t get_block_size() const { return m_block_manager.get_block_size(); + } const std::vector& get_block_tables(size_t seq_id) const { return m_block_manager.get_block_tables(seq_id); diff --git a/tests/cpp/cache_eviction.cpp b/tests/cpp/cache_eviction.cpp index 247a8e806b..f4558001b7 100644 --- a/tests/cpp/cache_eviction.cpp +++ b/tests/cpp/cache_eviction.cpp @@ -472,7 +472,7 @@ TEST_P(CacheRotationCalculatorInvalidInputParameterizedTest, ThrowsForInvalidEvi auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, init_params.max_context_length, init_params.kv_head_size, init_params.rope_theta); - EXPECT_THROW(calc.get_rotation_multipliers(test_struct.evicted_block_logical_indices, test_struct.num_logical_blocks_before_eviction), + EXPECT_THROW(calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, test_struct.num_logical_blocks_before_eviction), ov::Exception); } @@ -516,7 +516,7 @@ TEST_P(CacheRotationCalculatorNumCoefficientsParameterizedTest, GivesCorrectNumb auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, init_params.max_context_length, init_params.kv_head_size, init_params.rope_theta); - const auto rotation_multipliers = calc.get_rotation_multipliers(test_struct.evicted_block_logical_indices, + const auto rotation_multipliers = calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, test_struct.num_logical_blocks_before_eviction); ASSERT_EQ(rotation_multipliers.size(), test_struct.expected_num_rotated_blocks); @@ -675,7 +675,7 @@ TEST_P(CacheRotationCalculatorRefCoefficientsParameterizedTest, CalculatedCoeffi auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, init_params.max_context_length, init_params.kv_head_size, init_params.rope_theta); - const auto rotation_multipliers = calc.get_rotation_multipliers(test_struct.evicted_block_logical_indices, + const auto rotation_multipliers = calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, test_struct.num_logical_blocks_before_eviction); compare_rotation_data(rotation_multipliers, test_struct.expected_rotation_data); @@ -733,7 +733,7 @@ TEST(CacheRotationCalculatorPOCRefCoefficientsTest, CalculatedCoefficientsAreSim } } - auto test_data = calc.get_rotation_multipliers(ref_evicted_logical_block_indices, num_blocks_before_eviction); + auto test_data = calc.get_rotation_coefficients(ref_evicted_logical_block_indices, num_blocks_before_eviction); compare_rotation_data(test_data, ref_data, 1e-2); // the dump values were originally calculated in FP16 precision }