diff --git a/src/cpp/src/cache_eviction.cpp b/src/cpp/src/cache_eviction.cpp
index 94f3110cce..d79f12dd18 100644
--- a/src/cpp/src/cache_eviction.cpp
+++ b/src/cpp/src/cache_eviction.cpp
@@ -267,4 +267,84 @@ namespace ov::genai {
     m_scores[decoder_layer_idx] = new_scores;
     m_cache_counter[decoder_layer_idx] = new_counter;
 }
+
+CacheRotationCalculator::CacheRotationCalculator(size_t block_size,
+                                                 size_t max_context_length,
+                                                 size_t kv_head_size,
+                                                 double rope_theta)
+    : m_block_size(block_size),
+      m_head_size(kv_head_size) {
+    // Frequencies follow the original recipe from RoFormer:
+    // https://arxiv.org/pdf/2104.09864v5
+    //
+    // However, the way the rotation coefficients are ultimately applied in Llama and related models from
+    // HuggingFace is very different from the RoFormer one - the embedding-dimension coefficients are not treated as
+    // consecutive x-y coordinate pairs, but are rather divided into contiguous x-like and y-like halves - see
+    // the `rotate_half` function in HF transformers. It can be shown that this form still preserves the relative
+    // positioning property from the RoFormer article.
+    OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive");
+    size_t max_position_angle_multiplier = max_context_length;
+    size_t num_freqs = kv_head_size / 2;
+    m_rope_sin_lut.resize(max_position_angle_multiplier);
+    m_rope_cos_lut.resize(max_position_angle_multiplier);
+
+    for (size_t i = 0; i < max_position_angle_multiplier; i++) {
+        m_rope_sin_lut[i].reserve(num_freqs);
+        m_rope_cos_lut[i].reserve(num_freqs);
+        for (size_t j = 0; j < num_freqs; j++) {
+            double exponent = -static_cast<double>(2 * j) / kv_head_size;
+            double base_angle = std::pow(rope_theta, exponent);
+            m_rope_sin_lut[i].push_back(
+                -std::sin(i * base_angle));  // minus since we will be rotating by an inverse angle
+            m_rope_cos_lut[i].push_back(std::cos(i * base_angle));
+        }
+    }
+}
+
+std::vector<CacheRotationCalculator::BlockRotationData> CacheRotationCalculator::get_rotation_coefficients(
+    const std::set<size_t>& evicted_block_logical_indices,
+    size_t num_logical_blocks_before_eviction) {
+    OPENVINO_ASSERT(num_logical_blocks_before_eviction * m_block_size < m_rope_sin_lut.size(),
+                    "num_logical_blocks_before_eviction must not correspond to more tokens than max_context_length");
+
+    std::vector<BlockRotationData> retval;
+    if (evicted_block_logical_indices.empty()) {
+        return retval;
+    }
+
+    for (auto idx : evicted_block_logical_indices) {
+        OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction);
+    }
+
+    // num_logical_blocks_before_eviction > evicted_block_logical_indices.size() is automatically guaranteed by the
+    // set property and the previous assertion
+    retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size());
+
+    ptrdiff_t current_rotation_delta_in_blocks = 0;
+    std::vector<size_t> logical_block_space(num_logical_blocks_before_eviction);
+    std::iota(logical_block_space.begin(), logical_block_space.end(), 0);
+
+    for (size_t logical_block_idx : logical_block_space) {
+        if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) {
+            current_rotation_delta_in_blocks += 1;
+        } else {
+            if (current_rotation_delta_in_blocks != 0) {
+                BlockRotationData block_rotation_data;
+                block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks;
+                block_rotation_data.cosines.reserve(m_block_size);
+                block_rotation_data.sines.reserve(m_block_size);
+                for (size_t i = 0; i < m_block_size; i++) {
+                    block_rotation_data.cosines.push_back(
+                        m_rope_cos_lut[current_rotation_delta_in_blocks * m_block_size]);
+                    block_rotation_data.sines.push_back(
+                        m_rope_sin_lut[current_rotation_delta_in_blocks * m_block_size]);
+                }
+
+                retval.push_back(block_rotation_data);
+            }
+        }
+    }
+
+    return retval;
+}
 }
diff --git a/src/cpp/src/cache_eviction.hpp b/src/cpp/src/cache_eviction.hpp
index a32eb1ad0a..ee0d0a76b1 100644
--- a/src/cpp/src/cache_eviction.hpp
+++ b/src/cpp/src/cache_eviction.hpp
@@ -117,4 +117,79 @@ class CacheEvictionAlgorithm {
     std::vector<std::vector<size_t>> m_cache_counter;
 };
+/**
+ * @brief Computes, based on the logical indices of the blocks to be evicted, the rotation coefficients for the
+ * remaining cache blocks.
+ *
+ * The rotation assumes that the executed model applies rotary positional embedding (RoPE) during the execution of
+ * the attention operation. Each cache block therefore has the RoPE values already "baked in", with positions
+ * equivalent to the point in time when the cache block values were originally computed in one of the previous
+ * attention operations. When blocks are evicted, the logical index space of the remaining blocks is in general no
+ * longer contiguous with respect to the effective positions of tokens in the blocks. Cache rotation makes it possible
+ * to remedy this by effectively adjusting the RoPE positions of certain blocks in the cache after eviction,
+ * additionally "rotating" them (in the same sense as in RoPE) by such angles that the cache blocks in the logical
+ * index space are again contiguous in terms of the RoPE positions. This is intended to make the eviction process less
+ * impactful on generation accuracy.
+ *
+ * Currently only the basic RoPE method is supported (as applied in the original Llama models). Each model in general
+ * may have its own RoPE method (e.g. non-linear/NTK frequency scaling), and ideally the cache rotation calculator
+ * should be adjusted based on the specifics of the RoPE defined by the LLM.
+ */
+class CacheRotationCalculator {
+public:
+    /**
+     * Constructs a CacheRotationCalculator.
+     * @param block_size Block size of the KV cache to evict from.
+     * @param max_context_length Maximum length possible for a sequence in the current pipeline.
+     * @param kv_head_size The size (in elements) of the embedding dimension in the attention operation.
+     * @param rope_theta The base RoPE angle used in the original LLM.
+     */
+    CacheRotationCalculator(size_t block_size,
+                            size_t max_context_length,
+                            size_t kv_head_size,
+                            double rope_theta = 10000.0);
+
+    using RotationCoefficientsPerToken = std::vector<std::vector<double>>;  // dimensions: [BLOCK_SIZE, head_size / 2]
+
+    /**
+     * Basic output structure for the calculator.
+     */
+    struct BlockRotationData {
+        bool operator==(const BlockRotationData& rhs) const {
+            return (logical_block_idx == rhs.logical_block_idx) && (sines == rhs.sines) && (cosines == rhs.cosines);
+        }
+        size_t logical_block_idx;             /** Logical index of the block AFTER eviction to which the sine and
+                                                  cosine coefficients should be applied */
+        RotationCoefficientsPerToken sines;   /** The sine coefficients to be applied to this block's contents for
+                                                  rotation, in order of the block's elements */
+        RotationCoefficientsPerToken cosines; /** The cosine coefficients to be applied to this block's contents for
+                                                  rotation, in order of the block's elements */
+    };
+
+    /**
+     * Computes the rotation coefficients for the given state of the logical block space when eviction is about to
+     * take place.
+     * @param evicted_block_logical_indices The logical block indices that the prior cache eviction algorithm step
+     * determined to be necessary to evict.
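+     * For illustration (this mirrors reference test case 0 in tests/cpp/cache_eviction.cpp): with block_size = 2,
+     * rope_theta = 1.0 and 4 logical blocks before eviction, evicting blocks {1, 2} leaves pre-eviction block 3 at
+     * post-eviction logical index 1, shifted back by 2 blocks (4 tokens), so every token of that block gets the
+     * coefficients cos(4) and -sin(4) for each embedding pair.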
+ * @param num_logical_blocks_before_eviction Number of logical blocks that the evicted-from sequence occupied before + * the eviction step. + * @return A vector of per-block rotation data, including the indices of blocks after eviction that should be + * rotated, and the pre-computed trigonometric coefficients necessary for rotation. + */ + std::vector get_rotation_coefficients(const std::set& evicted_block_logical_indices, + size_t num_logical_blocks_before_eviction); + + /** + * @return The size of the embedding dimension that this CacheRotationCalculator was initialized with. + */ + size_t get_head_size() const { + return m_head_size; + } + +private: + size_t m_block_size; + size_t m_head_size; + std::vector> m_rope_sin_lut; // dimensions: [ max_context_length, head_size / 2] + std::vector> m_rope_cos_lut; // dimensions: [ max_context_length, head_size / 2] +}; } diff --git a/src/cpp/src/cache_manager.hpp b/src/cpp/src/cache_manager.hpp index a7444555ab..5d93e69895 100644 --- a/src/cpp/src/cache_manager.hpp +++ b/src/cpp/src/cache_manager.hpp @@ -25,6 +25,7 @@ class CacheManager { m_value_cache.reserve(m_device_config.get_num_layers()); const std::string device_name = device_config.get_device(); + std::cout << "VSHAMPOR: cache precision is " << device_config.get_cache_precision() << std::endl; if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) { ov::Tensor key_cache(device_config.get_cache_precision(), device_config.get_key_cache_shape()); diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index 73bf4ec083..fc629b39d4 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -70,9 +70,36 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init( } m_scheduler = std::make_shared(device_config.get_block_size(), updated_config, device_config.get_num_layers(), can_use_partial_preemption); + // and finally create model runner bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction; - m_model_runner = std::make_shared(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction); + if (is_use_cache_eviction) { + m_model_runner = std::make_shared(infer_request, + m_scheduler->get_block_size(), + device_config.get_num_layers(), + /* collect_attention_scores = */ true, + /* is_use_per_layer_cache_control = */ true); + m_rotation_coefficient_stores.reserve(device_config.get_num_layers()); + ov::Shape rotation_coefficient_store_shape{device_config.get_head_size() * + (m_scheduler->get_block_size() * scheduler_config.num_kv_blocks)}; + for (size_t i = 0; i < device_config.get_num_layers(); i++) { + ov::Tensor store(ov::element::f32, rotation_coefficient_store_shape); + std::memset(store.data(), 0, store.get_byte_size()); + m_rotation_coefficient_stores.push_back(store); + } + m_next_step_rotation_coefficients.resize(device_config.get_num_layers()); + m_next_step_rotated_block_logical_indices_per_sequence.resize(device_config.get_num_layers()); + m_cache_rotation_calculator = std::make_shared( + m_scheduler->get_block_size(), + // TODO (vshampor): LUT size equal to max cache size in tokens + // is overkill - find a way to pass the max sequence length instead + m_scheduler->get_block_size() * scheduler_config.num_kv_blocks, + device_config.get_head_size()); + } else { + m_model_runner = + std::make_shared(infer_request, 
m_scheduler->get_block_size(), device_config.get_num_layers()); + } + m_sampler = std::make_shared(m_tokenizer); m_sampler->set_seed(m_generation_config.rng_seed); @@ -196,6 +223,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { // evict unimportant blocks from KV cache, if requested if (sched_config.use_cache_eviction) { maybe_evict_cache_blocks(sched_config); + m_model_runner->set_cache_rotation_data(std::move(m_next_step_rotation_coefficients), + std::move(m_next_step_rotated_block_logical_indices_per_sequence)); } #ifdef DEBUG_CACHE_STATE_DUMP @@ -378,12 +407,21 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_a void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) { std::unordered_map seq_group_to_num_blocks_evicted_map; auto sequence_attention_scores = m_model_runner->get_last_attention_scores(); + + OPENVINO_ASSERT(!sequence_attention_scores.empty()); + size_t num_decoder_layers = sequence_attention_scores.begin()->second.size(); + std::vector num_blocks_to_rotate_for_each_layer(num_decoder_layers, 0); + size_t head_size = m_cache_rotation_calculator->get_head_size(); + + // necessary since we move from these members during previous steps + m_next_step_rotation_coefficients.clear(); + m_next_step_rotated_block_logical_indices_per_sequence.clear(); + m_next_step_rotated_block_logical_indices_per_sequence.resize(num_decoder_layers); + for (auto& seq_id_and_attention_scores : sequence_attention_scores) { auto seq_id = seq_id_and_attention_scores.first; const auto& attention_scores_for_all_decoder_layers = seq_id_and_attention_scores.second; if (m_seq_group_id_to_cache_eviction_algo_map.find(seq_id) == m_seq_group_id_to_cache_eviction_algo_map.end()) { - auto num_decoder_layers = attention_scores_for_all_decoder_layers.size(); - m_seq_group_id_to_cache_eviction_algo_map[seq_id] = CacheEvictionAlgorithm(sched_config.cache_eviction_config, m_scheduler->get_block_size(), num_decoder_layers); } auto& cache_eviction_algo = m_seq_group_id_to_cache_eviction_algo_map[seq_id]; @@ -391,6 +429,43 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block cache_eviction_algo.register_new_token_scores(attention_scores_for_all_decoder_layers); auto logical_blocks_to_evict = cache_eviction_algo.evict_logical_blocks(); + for (size_t layer_idx = 0; layer_idx < logical_blocks_to_evict.size(); layer_idx++) { + if (logical_blocks_to_evict[layer_idx].empty()) { + continue; + } + size_t num_blocks_before_eviction = m_scheduler->get_block_tables(seq_id)[layer_idx].size(); + auto rotation_multipliers = + m_cache_rotation_calculator->get_rotation_coefficients(logical_blocks_to_evict[layer_idx], + num_blocks_before_eviction); + for (size_t i = 0; i < rotation_multipliers.size(); i++) { + const auto& block_rotation_data = rotation_multipliers[i]; + const auto& rotation_multipliers_cos = block_rotation_data.cosines; + const auto& rotation_multipliers_sin = block_rotation_data.sines; + OPENVINO_ASSERT(rotation_multipliers_cos.size() == rotation_multipliers_sin.size()); + OPENVINO_ASSERT(rotation_multipliers_cos.size() == m_scheduler->get_block_size()); + + m_next_step_rotated_block_logical_indices_per_sequence[layer_idx][seq_id].push_back( + block_rotation_data.logical_block_idx); + + // Fill the store tensor with rotation coefficient data - cos and sin coefficients are each contiguous, + // cos goes first + size_t block_offset = + 
num_blocks_to_rotate_for_each_layer[layer_idx] * m_scheduler->get_block_size() * head_size; + auto rotation_multipliers_tensor_data = + m_rotation_coefficient_stores[layer_idx].data() + block_offset; + for (size_t tok_idx = 0; tok_idx < rotation_multipliers_cos.size(); tok_idx++) { + size_t position_offset = head_size * tok_idx; + for (size_t embedding_pair_idx = 0; embedding_pair_idx < head_size / 2; embedding_pair_idx++) { + rotation_multipliers_tensor_data[position_offset + embedding_pair_idx] = + rotation_multipliers_cos[tok_idx][embedding_pair_idx]; + rotation_multipliers_tensor_data[position_offset + embedding_pair_idx + head_size / 2] = + rotation_multipliers_sin[tok_idx][embedding_pair_idx]; + } + } + num_blocks_to_rotate_for_each_layer[layer_idx] += 1; + } + } + m_scheduler->free_blocks_from_sequence(seq_id, logical_blocks_to_evict); auto seq_group_ptr_it = std::find_if(m_requests.begin(), m_requests.end(), [seq_id](const SequenceGroup::Ptr& val) { return val->has_sequence_with_id(seq_id); }); @@ -405,6 +480,15 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block } } + + // Select the previously filled rotation coefficients from the store tensor + for (size_t i = 0; i < num_decoder_layers; i++) { + m_next_step_rotation_coefficients.emplace_back( + m_rotation_coefficient_stores[i], + ov::Coordinate{0}, + ov::Coordinate{num_blocks_to_rotate_for_each_layer[i] * m_scheduler->get_block_size() * head_size}); + } + for (const auto& seq_group_ptr_and_num_blocks_evicted : seq_group_to_num_blocks_evicted_map) { // Assuming that the evicted blocks are always full (since they by design are only selected from intermediate-age blocks) auto seq_group_ptr = seq_group_ptr_and_num_blocks_evicted.first; diff --git a/src/cpp/src/continuous_batching_impl.hpp b/src/cpp/src/continuous_batching_impl.hpp index 8276edb36b..69a3a1fae2 100644 --- a/src/cpp/src/continuous_batching_impl.hpp +++ b/src/cpp/src/continuous_batching_impl.hpp @@ -26,10 +26,22 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc static const size_t AVG_CACHE_USAGE_WINDOW_SIZE_IN_STEPS = 1000; std::deque m_previous_step_cache_usages; - + // flag to enable validation mode for sampler bool m_is_validation_mode_enabled = false; + // Pre-allocated per-layer storages for the per-token cache re-rotation coefficients used in cache eviction case + std::vector m_rotation_coefficient_stores; + + // Per-layer ROI tensors, reusing storage from the pre-allocated tensors above, that actually represent the + // re-rotation coefficients to be sent to the proper model inputs at the *next* pipeline step. 
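+    // Layout sketch for these per-layer tensors (illustrative only, the index names below are not part of this
+    // change): the b-th rotated block occupies block_size * head_size consecutive floats, and for token t of that
+    // block the cosine for embedding pair p sits at (b * block_size + t) * head_size + p, with the matching sine at
+    // the same index plus head_size / 2 - cf. the fill loop in maybe_evict_cache_blocks() above.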
+ std::vector m_next_step_rotation_coefficients; + + using SeqIdToRotatedLogicalBlocksMap = std::map>; + std::vector m_next_step_rotated_block_logical_indices_per_sequence; + + std::shared_ptr m_cache_rotation_calculator; + #ifdef DEBUG_CACHE_STATE_DUMP size_t step_count = 0; #endif @@ -86,4 +98,4 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc const std::vector& sampling_params, const StreamerVariant& streamer) override; }; -} \ No newline at end of file +} // namespace ov::genai diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index 2af4559ef1..1245fce8f3 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -136,6 +136,10 @@ class DeviceConfig { return m_num_decoder_layers; } + size_t get_head_size() const { + return m_head_size; + } + ov::Shape get_key_cache_shape() const { OPENVINO_ASSERT(!m_key_cache_shape.empty()); return m_key_cache_shape; diff --git a/src/cpp/src/model_runner.hpp b/src/cpp/src/model_runner.hpp index 1b96cdc505..c147440748 100644 --- a/src/cpp/src/model_runner.hpp +++ b/src/cpp/src/model_runner.hpp @@ -32,20 +32,31 @@ class ModelRunner { AttentionScoresForEachSubsequence m_last_attention_scores; size_t m_num_decoder_layers, m_block_size; bool m_collect_attention_scores; + bool m_is_use_per_layer_cache_control; + std::vector m_cache_rotation_coefficients; + std::vector>> m_rotated_block_logical_indices_per_sequence_for_each_layer; + public: /** * Constructs the ModelRunner. * @param request The ov::InferRequest for the LLM to be inferred in the continuous batching mode. - * @param scheduler_config Configuration struct for the scheduler that is to be used with this ModelRunner. * @param num_decoder_layers Number of decoder attention layers in the LLM corresponding to the request. - * @param collect_attention_scores If true, then after each `forward` call the ModelRunner will collect and make available the per-token attention - * scores for each decoder layer, so that these can be used in per-step cache optimizations (such as cache eviction algorithm). + * @param collect_attention_scores If true, then after each `forward` call the ModelRunner will collect and make + * available the per-token attention scores for each decoder layer, so that these can be used in per-step cache + * optimizations (such as cache eviction algorithm). + * @param is_use_per_layer_cache_control If true, then the runner will pass cache control input tensors to the model + * on a per-attention layer basis. 
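+     * (For illustration: with per-layer cache control enabled, the runner fills per-layer inputs named
+     * "block_indices.<i>", and, when rotation data was supplied via set_cache_rotation_data(), also
+     * "rotation_coefficients.<i>" and "rotated_block_indices.<i>" for each decoder layer i - see
+     * _set_block_indices() and _set_cache_rotation_coefficients() below.)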
*/ - ModelRunner(ov::InferRequest request, size_t block_size, size_t num_decoder_layers = 1, bool collect_attention_scores = false) : - m_request(std::move(request)), - m_block_size(block_size), - m_num_decoder_layers(num_decoder_layers), - m_collect_attention_scores(collect_attention_scores) { + ModelRunner(ov::InferRequest request, + size_t block_size, + size_t num_decoder_layers = 1, + bool collect_attention_scores = false, + bool is_use_per_layer_cache_control = false) + : m_request(std::move(request)), + m_block_size(block_size), + m_num_decoder_layers(num_decoder_layers), + m_collect_attention_scores(collect_attention_scores), + m_is_use_per_layer_cache_control(is_use_per_layer_cache_control) { OPENVINO_ASSERT(m_num_decoder_layers != 0, "num_decoder_layers must be non-zero"); } @@ -65,6 +76,14 @@ class ModelRunner { return m_last_attention_scores; } + void set_cache_rotation_data(std::vector&& cache_rotation_coefficients_for_each_layer, + const std::vector>>&& + rotated_logical_block_indices_per_sequence_for_each_layer) { + m_cache_rotation_coefficients = std::move(cache_rotation_coefficients_for_each_layer); + m_rotated_block_logical_indices_per_sequence_for_each_layer = + std::move(rotated_logical_block_indices_per_sequence_for_each_layer); + } + /** * Runs the forward inference call on the underlying LLM's ov::InferRequest, scheduling for inferencing tokens for given sequences * taking into account the supplied scheduler output struct. @@ -163,7 +182,11 @@ class ModelRunner { m_request.set_tensor("past_lens", past_lens); m_request.set_tensor("subsequence_begins", subsequence_begins); - _set_block_indices(m_request, sequence_groups, scheduler_output, total_num_blocks); + _set_block_indices(sequence_groups, scheduler_output, total_num_blocks); + + if (!m_cache_rotation_coefficients.empty()) { + _set_cache_rotation_coefficients(sequence_groups, scheduler_output); + } m_request.set_tensor("block_indices_begins", block_indices_begins); m_request.set_tensor("max_context_len", max_context_len); @@ -188,17 +211,75 @@ class ModelRunner { _collect_attention_scores(sequence_groups, scheduler_output); } + m_cache_rotation_coefficients.clear(); + // return logits return m_request.get_tensor("logits"); } private: - void _set_block_indices(ov::InferRequest& infer_request, const std::vector & sequence_groups, const Scheduler::Output& scheduler_output, - size_t total_num_blocks) { + void _fill_indices_from_block_tables( + const std::vector& dst_tensor_names, + const std::vector& sequence_groups, + const Scheduler::Output& scheduler_output, + const std::vector>>& seq_id_to_select_logical_idx_maps) { + OPENVINO_ASSERT(seq_id_to_select_logical_idx_maps.size() == dst_tensor_names.size() || + seq_id_to_select_logical_idx_maps.empty()); + bool is_fill_all = seq_id_to_select_logical_idx_maps.empty(); size_t num_sequence_groups = scheduler_output.m_scheduled_sequence_groups_ids.size(); + std::vector block_offsets_per_layer(dst_tensor_names.size(), 0); + + for (size_t i = 0; i < num_sequence_groups; ++i) { + size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; + SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id]; + std::vector running_sequences = sequence_group->get_running_sequences(); + size_t num_running_sequences = running_sequences.size(); + + for (size_t i = 0; i < num_running_sequences; ++i) { + Sequence::CPtr sequence = running_sequences[i]; + size_t seq_id = sequence->get_id(); + + size_t num_blocks = + (sequence_group->get_context_len() - 
sequence_group->get_num_evicted_tokens() + m_block_size - 1) / + m_block_size; + const auto& kv_blocks = scheduler_output.m_block_tables.at(sequence->get_id()); + + for (size_t layer_idx = 0; layer_idx < dst_tensor_names.size(); layer_idx++) { + auto input_tensor = m_request.get_tensor(dst_tensor_names[layer_idx]); + auto block_indices_data = input_tensor.data() + block_offsets_per_layer[layer_idx]; + + if (is_fill_all) { + for (size_t block_id = 0; block_id < num_blocks; ++block_id) { + // In case no cache eviction is requested, all per-layer block tables are expected to be + // identical at all times + block_indices_data[block_id] = kv_blocks[layer_idx][block_id]->get_index(); + } + block_offsets_per_layer[layer_idx] += num_blocks; + } else { + auto seq_id_to_select_logical_idx_map = seq_id_to_select_logical_idx_maps[layer_idx]; + auto it = seq_id_to_select_logical_idx_map.find(seq_id); + if (it == seq_id_to_select_logical_idx_map.end()) { + continue; + } + auto select_logical_idxs = it->second; + for (size_t block_id = 0; block_id < select_logical_idxs.size(); ++block_id) { + size_t logical_block_idx = select_logical_idxs[block_id]; + OPENVINO_ASSERT(logical_block_idx < num_blocks); + block_indices_data[block_id] = kv_blocks[layer_idx][logical_block_idx]->get_index(); + } + + block_offsets_per_layer[layer_idx] += select_logical_idxs.size(); + } + } + } + } + } + void _set_block_indices(const std::vector& sequence_groups, + const Scheduler::Output& scheduler_output, + size_t total_num_blocks) { std::vector tensor_names = {"block_indices"}; - if (m_collect_attention_scores) { + if (m_is_use_per_layer_cache_control) { tensor_names.resize(m_num_decoder_layers); for (size_t i = 0; i < tensor_names.size(); i++) { tensor_names[i] = std::string("block_indices.") + std::to_string(i); @@ -209,31 +290,34 @@ class ModelRunner { m_request.get_tensor(name).set_shape({total_num_blocks}); } - size_t block_offset = 0; - for (size_t i = 0; i < num_sequence_groups; ++i) { - size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i]; - SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id]; - std::vector running_sequences = sequence_group->get_running_sequences(); - size_t num_running_sequences = running_sequences.size(); - - for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) { - Sequence::CPtr sequence = running_sequences[seq_id]; + _fill_indices_from_block_tables(tensor_names, sequence_groups, scheduler_output, {}); + } - size_t num_blocks = (sequence_group->get_context_len() - sequence_group->get_num_evicted_tokens() + m_block_size - 1) / m_block_size; - const auto & kv_blocks = scheduler_output.m_block_tables.at(sequence->get_id()); - - for (size_t layer_idx = 0; layer_idx < tensor_names.size(); layer_idx++) { - auto input_tensor = infer_request.get_tensor(tensor_names[layer_idx]); - auto block_indices_data = input_tensor.data() + block_offset; - for (size_t block_id = 0; block_id < num_blocks; ++block_id) - // In case no cache eviction is requested, all per-layer block tables are expected to be identical - // at all times - block_indices_data[block_id] = kv_blocks[layer_idx][block_id]->get_index(); - } + void _set_cache_rotation_coefficients(const std::vector& sequence_groups, + const Scheduler::Output& scheduler_output) { + for (size_t i = 0; i < m_num_decoder_layers; i++) { + auto tensor_name = std::string("rotation_coefficients.") + std::to_string(i); + m_request.set_tensor(tensor_name, m_cache_rotation_coefficients[i]); + } - block_offset += num_blocks; + 
std::vector rotation_indices_tensor_names(m_num_decoder_layers); + for (size_t i = 0; i < m_num_decoder_layers; i++) { + auto tensor_name = std::string("rotated_block_indices.") + std::to_string(i); + rotation_indices_tensor_names[i] = tensor_name; + size_t num_indices = 0; + for (const auto& entry : m_rotated_block_logical_indices_per_sequence_for_each_layer[i]) { + num_indices += entry.second.size(); } + auto rotated_block_indices_tensor = m_request.get_tensor(tensor_name); + rotated_block_indices_tensor.set_shape({num_indices}); } + + // NB: the order of per-sequence index filling in the function below must be the same + // as the order of `seq_id`s in which the "rotation_coefficients.N" inputs are filled + _fill_indices_from_block_tables(rotation_indices_tensor_names, + sequence_groups, + scheduler_output, + m_rotated_block_logical_indices_per_sequence_for_each_layer); } void _collect_attention_scores(const std::vector & sequence_groups, const Scheduler::Output& scheduler_output) { diff --git a/src/cpp/src/scheduler.hpp b/src/cpp/src/scheduler.hpp index 6de4adaa47..e7d0f2f060 100644 --- a/src/cpp/src/scheduler.hpp +++ b/src/cpp/src/scheduler.hpp @@ -83,6 +83,10 @@ class Scheduler { return m_block_manager.get_block_size(); } + const std::vector& get_block_tables(size_t seq_id) const { + return m_block_manager.get_block_tables(seq_id); + } + const bool has_block_table(uint64_t seq_id) { return m_block_manager.has_block_table(seq_id); } diff --git a/src/cpp/src/utils/paged_attention_transformations.cpp b/src/cpp/src/utils/paged_attention_transformations.cpp index 53690f770c..77a66f9abe 100644 --- a/src/cpp/src/utils/paged_attention_transformations.cpp +++ b/src/cpp/src/utils/paged_attention_transformations.cpp @@ -32,7 +32,9 @@ void apply_paged_attention_transformations(std::shared_ptr model, boo bool use_block_indices_inputs = per_layer_cache_control; bool use_score_outputs = per_layer_cache_control; - ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs).run_on_model(model); + bool allow_cache_rotation = per_layer_cache_control; + ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, allow_cache_rotation) + .run_on_model(model); } void set_kv_cache_type_and_shape(std::shared_ptr model, DeviceConfig& device_config) { @@ -80,4 +82,4 @@ void apply_paged_attention_transformations(std::shared_ptr model, Dev } // namespace utils } // namespace genai -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index f404e63cff..82c227d3c3 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -26,6 +26,6 @@ file(GLOB src_files "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/sequence_group.cpp" add_executable(${TEST_TARGET_NAME} ${tests_src} block_allocator.cpp) -target_link_libraries(${TEST_TARGET_NAME} PRIVATE openvino::genai gtest_main) +target_link_libraries(${TEST_TARGET_NAME} PRIVATE openvino::genai gtest_main gmock_main) target_include_directories(${TEST_TARGET_NAME} PRIVATE "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src") target_sources(${TEST_TARGET_NAME} PRIVATE ${src_files}) diff --git a/tests/cpp/cache_eviction.cpp b/tests/cpp/cache_eviction.cpp index 026a7cbe64..8fd3404c5e 100644 --- a/tests/cpp/cache_eviction.cpp +++ b/tests/cpp/cache_eviction.cpp @@ -2,9 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 #include "cache_eviction.hpp" -#include "gtest/gtest.h" #include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" const ov::genai::CacheEvictionConfig 
DEFAULT_CACHE_EVICTION_CONFIG = {32, 32, 192, ov::genai::AggregationMode::NORM_SUM}; const ov::genai::CacheEvictionConfig SHORT_RECENT_EVICTION_CONFIG = {32, 32, 72, ov::genai::AggregationMode::NORM_SUM}; @@ -145,6 +148,8 @@ struct LowScoreBlocksTestStruct { }; using CacheEvictionLowScoreBlocksParameterizedTest = ::testing::TestWithParam; + +// clang-format off const std::vector LOW_SCORE_BLOCK_EVICTION_TEST_CASES = { // low-scored blocks in evictable area { @@ -197,6 +202,7 @@ const std::vector LOW_SCORE_BLOCK_EVICTION_TEST_CASES {{8, 9, 10, 11, 13}, {8, 9, 10, 11, 12}} }, }; +// clang-format on TEST_P(CacheEvictionLowScoreBlocksParameterizedTest, EvictsLowestScoredBlocks) { auto test_struct = GetParam(); @@ -241,6 +247,8 @@ struct NormalizationSettingTestStruct { }; using CacheEvictionNormalizationSettingTest = ::testing::TestWithParam; + +// clang-format off const std::vector NORMALIZATION_SETTING_TEST_CASES = { // power of 1.1 beats the 1 / N in the normalization, low-score blocks are in the end of the evictable area { ov::genai::AggregationMode::NORM_SUM, 1.1, false, { 40, 41, 42} }, @@ -260,6 +268,7 @@ const std::vector NORMALIZATION_SETTING_TEST_CAS { ov::genai::AggregationMode::SUM, 1.1, false, { 40, 41, 42} }, { ov::genai::AggregationMode::SUM, 1.1, true, { 8, 9, 10} }, }; +// clang-format on TEST_P(CacheEvictionNormalizationSettingTest, TokenLifetimeNormalizationHasEffect) { const auto& test_struct = GetParam(); @@ -420,3 +429,338 @@ TEST_P(CacheEvictionAlgoInitializationTest, ThrowsForInvalidConfigs) { INSTANTIATE_TEST_SUITE_P(VariousInvalidInitParams, CacheEvictionAlgoInitializationTest, ::testing::ValuesIn(INVALID_ALGO_INIT_PARAMS_CASES)); + +TEST(CacheRotationCalculatorTest, CanInitializeWithBasicParams) { + EXPECT_NO_THROW(ov::genai::CacheRotationCalculator(32, 128, 64)); +} + +TEST(CacheRotationCalculatorTest, ThrowsForNonPositiveTheta) { + EXPECT_THROW(ov::genai::CacheRotationCalculator(32, 128, 64, -1.0), ov::Exception); + EXPECT_THROW(ov::genai::CacheRotationCalculator(32, 128, 64, 0.0), ov::Exception); +} + +struct CacheRotationCalculatorInitParams { + size_t block_size; + size_t max_context_length; + size_t kv_head_size; + double rope_theta; +}; + +struct CacheRotationCalculatorInputTestStruct { + CacheRotationCalculatorInitParams init_params; + std::set evicted_block_logical_indices; + size_t num_logical_blocks_before_eviction; +}; + +using CacheRotationCalculatorInvalidInputParameterizedTest = + ::testing::TestWithParam; + +// clang-format off +const std::vector CACHE_ROTATION_CALCULATOR_INVALID_INPUT_TEST_CASES = { + { // more num_logical_blocks_before_eviction than possible by max_context_length + {8, 16, 4, 1337.0}, + {1, 2, 6}, + 32 + }, + { // evicted block index out of bounds + {16, 256, 32, 665.0}, + {8, 0, 5,50}, + 9 + }, + { // more blocks attempted to evict than num_logical_blocks_before_eviction + {16, 256, 32, 665.0}, + {0, 1, 2}, + 2 + } +}; +// clang-format on + +TEST_P(CacheRotationCalculatorInvalidInputParameterizedTest, ThrowsForInvalidEvictedBlocksInput) { + const auto& test_struct = GetParam(); + const auto& init_params = test_struct.init_params; + + auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, + init_params.max_context_length, + init_params.kv_head_size, + init_params.rope_theta); + EXPECT_THROW(calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, + test_struct.num_logical_blocks_before_eviction), + ov::Exception); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputsAndInitParams, + 
CacheRotationCalculatorInvalidInputParameterizedTest, + testing::ValuesIn(CACHE_ROTATION_CALCULATOR_INVALID_INPUT_TEST_CASES)); + +struct CacheRotationCalculatorNumCoefficientsTestStruct { + CacheRotationCalculatorInitParams init_params; + std::set evicted_block_logical_indices; + size_t num_logical_blocks_before_eviction; + size_t expected_num_rotated_blocks; +}; + +// clang-format off +const std::vector CACHE_ROTATION_CALCULATOR_VALID_INPUT_TEST_CASES = { + { + {8, 512, 4, 1337.0}, + {1, 2}, + 7, + 4 + }, + { + {16, 256, 32, 665.0}, + {8, 0, 5, 3}, + 9, + 5 + }, + { // more blocks attempted to evict than num_logical_blocks_before_eviction + {16, 1024, 32, 665.0}, + {24, 25, 26, 27, 28}, + 30, + 1 + } +}; +// clang-format on + +using CacheRotationCalculatorNumCoefficientsParameterizedTest = + ::testing::TestWithParam; + +TEST_P(CacheRotationCalculatorNumCoefficientsParameterizedTest, GivesCorrectNumberOfRotationMultipliers) { + const auto& test_struct = GetParam(); + const auto& init_params = test_struct.init_params; + + auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, + init_params.max_context_length, + init_params.kv_head_size, + init_params.rope_theta); + + const auto rotation_multipliers = calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, + test_struct.num_logical_blocks_before_eviction); + + ASSERT_EQ(rotation_multipliers.size(), test_struct.expected_num_rotated_blocks); + for (const auto& block_rotation_data : rotation_multipliers) { + EXPECT_EQ(block_rotation_data.cosines.size(), block_rotation_data.sines.size()); + EXPECT_EQ(block_rotation_data.cosines.size(), init_params.block_size); + for (const auto& token_coefficients : block_rotation_data.cosines) { + EXPECT_EQ(token_coefficients.size(), init_params.kv_head_size / 2); + } + for (const auto& token_coefficients : block_rotation_data.sines) { + EXPECT_EQ(token_coefficients.size(), init_params.kv_head_size / 2); + } + } +} + +INSTANTIATE_TEST_SUITE_P(VariousInputsAndInitParams, + CacheRotationCalculatorNumCoefficientsParameterizedTest, + testing::ValuesIn(CACHE_ROTATION_CALCULATOR_VALID_INPUT_TEST_CASES)); + +struct CacheRotationCalculatorRefCoefficientsTestStruct { + CacheRotationCalculatorInitParams init_params; + std::set evicted_block_logical_indices; + size_t num_logical_blocks_before_eviction; + std::vector expected_rotation_data; +}; + +// clang-format off +const std::vector CACHE_ROTATION_CALCULATOR_REF_COEFFICIENTS_TEST_CASES = { + // 0 + { + {2, 512, 4, 1.0}, + {1, 2}, + 4, + // pre-eviction block 3 rotated left by 2 blocks, coefficients are cos(4) and -sin(4) due to theta == 1.0 + { + {1, + { + {0.75680249, 0.75680249}, // block token 0 + {0.75680249, 0.75680249} // block token 1 + }, + { + {-0.65364362, -0.65364362}, // block token 0 + {-0.65364362, -0.65364362} // block token 1 + }, + } + } + }, + + // 1 - same as 0, but adjusted theta + { + {2, 512, 4, 2.0}, + {1, 2}, + 4, + // coefficients are [cos(4 / 1), -sin(4 / 1)], [cos(4 / sqrt(2)), -sin(4 / sqrt(2))] now + { + {1, + { + {0.75680249, -0.30807174}, // block token 0 + {0.75680249, -0.30807174} // block token 1 + }, + { + {-0.65364362, -0.95136312}, // block token 0 + {-0.65364362, -0.95136312} // block token 1 + }, + } + } + }, + // 2 - same as 0, but corner case blocks + { + {2, 512, 4, 2.0}, + {0, 3}, + 4, + // delta of 2 tokens for both blocks + // coefficients are [cos(2 / 1), -sin(2 / 1)], [cos(2 / sqrt(2)), -sin(2 / sqrt(2))] + { + {0, + { + {-0.90929742, -0.98776594}, // block token 0 + {-0.90929742, 
-0.98776594} // block token 1 + }, + { + {-0.41614683, 0.15594369}, // block token 0 + {-0.41614683, 0.15594369} // block token 1 + }, + }, + {1, + { + {-0.90929742, -0.98776594}, // block token 0 + {-0.90929742, -0.98776594} // block token 1 + }, + { + {-0.41614683, 0.15594369}, // block token 0 + {-0.41614683, 0.15594369} // block token 1 + }, + } + } + }, + // 3 - same as 0, but different deltas for each rotated block + { + {2, 512, 4, 2.0}, + {0, 2}, + 4, + // delta of 2 tokens for first remaining block: + // coefficients are [cos(2 / 1), -sin(2 / 1)], [cos(2 / sqrt(2)), -sin(2 / sqrt(2))] + // and 4 tokens for second remaining block + // coefficients are [cos(4 / 1), -sin(4 / 1)], [cos(4 / sqrt(2)), -sin(4 / sqrt(2))] + { + {0, + { + {-0.90929742, -0.98776594}, // block token 0 + {-0.90929742, -0.98776594} // block token 1 + }, + { + {-0.41614683, 0.15594369}, // block token 0 + {-0.41614683, 0.15594369} // block token 1 + }, + }, + {1, + { + {0.75680249, -0.30807174}, // block token 0 + {0.75680249, -0.30807174} // block token 1 + }, + { + {-0.65364362, -0.95136312}, // block token 0 + {-0.65364362, -0.95136312} // block token 1 + }, + } + } + }, +}; +// clang-format on + +using CacheRotationCalculatorRefCoefficientsParameterizedTest = + ::testing::TestWithParam; + +void compare_rotation_data(const std::vector& test_data, + const std::vector& ref_data, + double abs_tol = 1e-8) { + ASSERT_EQ(test_data.size(), ref_data.size()); + + for (size_t i = 0; i < test_data.size(); i++) { + const auto& test_block_data = test_data[i]; + const auto& ref_block_data = ref_data[i]; + EXPECT_EQ(test_block_data.logical_block_idx, ref_block_data.logical_block_idx); + + ASSERT_EQ(test_block_data.sines.size(), ref_block_data.sines.size()); + for (size_t j = 0; j < test_block_data.sines.size(); j++) { + EXPECT_THAT(test_block_data.sines[j], + ::testing::Pointwise(::testing::DoubleNear(abs_tol), ref_block_data.sines[j])); + } + + ASSERT_EQ(test_block_data.cosines.size(), ref_block_data.cosines.size()); + for (size_t j = 0; j < test_block_data.cosines.size(); j++) { + EXPECT_THAT(test_block_data.cosines[j], + ::testing::Pointwise(::testing::DoubleNear(abs_tol), ref_block_data.cosines[j])); + } + } +} + +TEST_P(CacheRotationCalculatorRefCoefficientsParameterizedTest, CalculatedCoefficientsMatchToReference) { + const auto& test_struct = GetParam(); + const auto& init_params = test_struct.init_params; + + auto calc = ov::genai::CacheRotationCalculator(init_params.block_size, + init_params.max_context_length, + init_params.kv_head_size, + init_params.rope_theta); + + const auto rotation_multipliers = calc.get_rotation_coefficients(test_struct.evicted_block_logical_indices, + test_struct.num_logical_blocks_before_eviction); + + compare_rotation_data(rotation_multipliers, test_struct.expected_rotation_data); +} + +INSTANTIATE_TEST_SUITE_P(VariousInputsAndInitParams, + CacheRotationCalculatorRefCoefficientsParameterizedTest, + testing::ValuesIn(CACHE_ROTATION_CALCULATOR_REF_COEFFICIENTS_TEST_CASES)); + +TEST(CacheRotationCalculatorPOCRefCoefficientsTest, CalculatedCoefficientsAreSimilarToPOCResults) { + std::ifstream input_file("tests/cpp/data/cache_rotation_poc_ref_coefficients_per_block.txt", std::ios::in); + + const size_t ref_block_size = 16; + const size_t ref_max_context_length = 1024; + const size_t ref_head_size = 64; + + auto calc = ov::genai::CacheRotationCalculator(ref_block_size, ref_max_context_length, ref_head_size); + size_t num_blocks_before_eviction = 0; + std::set 
ref_evicted_logical_block_indices; + std::vector ref_data; + + input_file >> num_blocks_before_eviction; + size_t num_evicted_blocks; + input_file >> num_evicted_blocks; + for (size_t i = 0; i < num_evicted_blocks; i++) { + size_t evicted_block_idx = 0; + input_file >> evicted_block_idx; + ref_evicted_logical_block_indices.insert(evicted_block_idx); + } + + size_t num_rotated_blocks = 0; + input_file >> num_rotated_blocks; + ref_data.resize(num_rotated_blocks); + + for (size_t i = 0; i < num_rotated_blocks; i++) { + size_t logical_block_idx_after_eviction = 0; + input_file >> logical_block_idx_after_eviction; + ref_data[i].logical_block_idx = logical_block_idx_after_eviction; + std::vector coeffts(ref_head_size / 2); + + for (size_t j = 0; j < ref_head_size / 2; j++) { + input_file >> coeffts[j]; + } + ref_data[i].sines.resize(ref_block_size); + for (size_t k = 0; k < ref_block_size; k++) { + ref_data[i].sines[k] = coeffts; + } + + for (size_t j = 0; j < ref_head_size / 2; j++) { + input_file >> coeffts[j]; + } + ref_data[i].cosines.resize(ref_block_size); + for (size_t k = 0; k < ref_block_size; k++) { + ref_data[i].cosines[k] = coeffts; + } + } + + auto test_data = calc.get_rotation_coefficients(ref_evicted_logical_block_indices, num_blocks_before_eviction); + compare_rotation_data(test_data, ref_data, 1e-2); // the dump values were originally calculated in FP16 precision +} diff --git a/tests/cpp/data/cache_rotation_poc_ref_coefficients_per_block.txt b/tests/cpp/data/cache_rotation_poc_ref_coefficients_per_block.txt new file mode 100644 index 0000000000..da68dcba87 --- /dev/null +++ b/tests/cpp/data/cache_rotation_poc_ref_coefficients_per_block.txt @@ -0,0 +1,28 @@ +53 +27 +18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 +8 +18 +0.9990234375 0.36181640625 0.8564453125 0.039520263671875 0.99853515625 -0.9423828125 -0.98876953125 -0.8720703125 0.705078125 -0.830078125 0.74462890625 0.5908203125 -0.888671875 0.73095703125 -0.9853515625 0.4990234375 0.923828125 0.09765625 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.69482421875 -0.544921875 -0.418701171875 -0.318115234375 -0.240478515625 -0.1812744140625 -0.13623046875 -0.102294921875 -0.07672119140625 -0.05755615234375 +0.0311431884765625 -0.93212890625 -0.515625 0.9990234375 -0.048828125 -0.335205078125 0.1468505859375 0.4892578125 0.708984375 0.5576171875 0.66748046875 0.806640625 0.45849609375 -0.6826171875 0.1710205078125 0.86669921875 -0.382568359375 -0.99560546875 -0.7568359375 -0.2484130859375 0.2030029296875 0.51953125 0.71923828125 0.83837890625 0.908203125 0.9482421875 0.97021484375 0.9833984375 0.99072265625 0.99462890625 0.99658203125 0.99853515625 +19 +1.0 0.362060546875 0.85693359375 0.03955078125 0.99853515625 -0.9423828125 -0.9892578125 -0.87255859375 0.705078125 -0.83056640625 0.74462890625 0.59130859375 -0.88916015625 0.73095703125 -0.98486328125 0.4990234375 0.923828125 0.09796142578125 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.69482421875 -0.544921875 -0.418701171875 -0.318359375 -0.240478515625 -0.18115234375 -0.1361083984375 -0.10235595703125 -0.07672119140625 -0.057525634765625 +0.0311431884765625 -0.93212890625 -0.515625 0.9990234375 -0.0489501953125 -0.3349609375 0.1468505859375 0.48974609375 0.70947265625 0.55712890625 0.66748046875 0.80712890625 0.45849609375 -0.6826171875 0.1707763671875 0.86669921875 -0.38232421875 -0.99560546875 -0.7568359375 -0.2484130859375 0.203125 0.51953125 0.71923828125 0.83837890625 0.908203125 0.9482421875 0.970703125 0.9833984375 
0.9912109375 0.9951171875 0.99658203125 0.998046875 +20 +0.99951171875 0.36181640625 0.8564453125 0.039337158203125 0.99853515625 -0.9423828125 -0.98876953125 -0.8720703125 0.70458984375 -0.83056640625 0.744140625 0.5908203125 -0.888671875 0.73046875 -0.9853515625 0.49853515625 0.92431640625 0.09765625 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.6943359375 -0.54443359375 -0.418701171875 -0.318359375 -0.2406005859375 -0.1810302734375 -0.1363525390625 -0.10223388671875 -0.07672119140625 -0.05755615234375 +0.031036376953125 -0.9326171875 -0.515625 0.9990234375 -0.0491943359375 -0.335205078125 0.1468505859375 0.489501953125 0.708984375 0.5576171875 0.66748046875 0.806640625 0.458251953125 -0.6826171875 0.1712646484375 0.86669921875 -0.382568359375 -0.99560546875 -0.7568359375 -0.248291015625 0.203125 0.51953125 0.71923828125 0.83837890625 0.908203125 0.94775390625 0.970703125 0.9833984375 0.99072265625 0.9951171875 0.99755859375 0.99853515625 +21 +0.99951171875 0.361572265625 0.85693359375 0.03936767578125 0.99853515625 -0.9423828125 -0.98876953125 -0.8720703125 0.705078125 -0.83056640625 0.744140625 0.5908203125 -0.888671875 0.73095703125 -0.98583984375 0.498779296875 0.923828125 0.097412109375 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.69482421875 -0.544921875 -0.41845703125 -0.318115234375 -0.240478515625 -0.180908203125 -0.1361083984375 -0.10235595703125 -0.07666015625 -0.057586669921875 +0.031005859375 -0.93212890625 -0.51611328125 0.9990234375 -0.048614501953125 -0.335205078125 0.1468505859375 0.489501953125 0.70947265625 0.5576171875 0.66796875 0.806640625 0.45849609375 -0.6826171875 0.1710205078125 0.86669921875 -0.38232421875 -0.99560546875 -0.7568359375 -0.248291015625 0.2034912109375 0.51953125 0.71923828125 0.8388671875 0.908203125 0.94775390625 0.970703125 0.9833984375 0.99072265625 0.9951171875 0.9970703125 0.998046875 +22 +0.99951171875 0.36181640625 0.8564453125 0.03936767578125 0.99853515625 -0.9423828125 -0.9892578125 -0.8720703125 0.705078125 -0.830078125 0.74462890625 0.5908203125 -0.888671875 0.73095703125 -0.98486328125 0.498779296875 0.923828125 0.097900390625 -0.65380859375 -0.96826171875 -0.97900390625 -0.8544921875 -0.69482421875 -0.54443359375 -0.41845703125 -0.318359375 -0.2406005859375 -0.18115234375 -0.1361083984375 -0.10223388671875 -0.07672119140625 -0.057586669921875 +0.03094482421875 -0.93212890625 -0.515625 0.9990234375 -0.04888916015625 -0.335205078125 0.1468505859375 0.4892578125 0.70947265625 0.55712890625 0.66796875 0.806640625 0.458251953125 -0.6826171875 0.1708984375 0.86669921875 -0.38232421875 -0.9951171875 -0.75634765625 -0.248291015625 0.2030029296875 0.51953125 0.71923828125 0.83837890625 0.90771484375 0.94775390625 0.970703125 0.9833984375 0.990234375 0.9951171875 0.9970703125 0.99853515625 +23 +0.99951171875 0.36181640625 0.85693359375 0.03936767578125 0.9990234375 -0.94189453125 -0.98876953125 -0.8720703125 0.70458984375 -0.830078125 0.74462890625 0.5908203125 -0.888671875 0.73046875 -0.9853515625 0.498779296875 0.923828125 0.09783935546875 -0.6533203125 -0.96875 -0.9794921875 -0.8544921875 -0.6953125 -0.54443359375 -0.41845703125 -0.318359375 -0.240478515625 -0.18115234375 -0.1361083984375 -0.102294921875 -0.07672119140625 -0.057647705078125 +0.031005859375 -0.93212890625 -0.515625 0.9990234375 -0.048797607421875 -0.334716796875 0.14697265625 0.4892578125 0.708984375 0.5576171875 0.66796875 0.80615234375 0.45849609375 -0.6826171875 0.1708984375 0.86669921875 -0.38232421875 -0.9951171875 -0.7568359375 -0.248291015625 
0.203125 0.51953125 0.71923828125 0.83837890625 0.908203125 0.94775390625 0.97021484375 0.9833984375 0.99072265625 0.99462890625 0.9970703125 0.99853515625 +24 +0.99951171875 0.3623046875 0.8564453125 0.03955078125 0.99853515625 -0.9423828125 -0.9892578125 -0.8720703125 0.705078125 -0.83056640625 0.74462890625 0.5908203125 -0.888671875 0.73095703125 -0.9853515625 0.498779296875 0.923828125 0.097900390625 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.69482421875 -0.54443359375 -0.4189453125 -0.318359375 -0.240478515625 -0.18115234375 -0.1361083984375 -0.10223388671875 -0.0767822265625 -0.05755615234375 +0.031097412109375 -0.93212890625 -0.51611328125 0.99951171875 -0.04888916015625 -0.3349609375 0.1466064453125 0.4892578125 0.708984375 0.5576171875 0.66796875 0.80712890625 0.4580078125 -0.6826171875 0.1707763671875 0.8671875 -0.382568359375 -0.99560546875 -0.75634765625 -0.248291015625 0.203369140625 0.52001953125 0.71875 0.83837890625 0.908203125 0.94775390625 0.97021484375 0.9833984375 0.99072265625 0.9951171875 0.99658203125 0.998046875 +25 +0.99951171875 0.36181640625 0.8564453125 0.039398193359375 0.9990234375 -0.94189453125 -0.9892578125 -0.87255859375 0.705078125 -0.830078125 0.74462890625 0.5908203125 -0.88916015625 0.73095703125 -0.9853515625 0.4990234375 0.92431640625 0.0977783203125 -0.6533203125 -0.96875 -0.97900390625 -0.8544921875 -0.6943359375 -0.54443359375 -0.418701171875 -0.318359375 -0.2403564453125 -0.1812744140625 -0.13623046875 -0.102294921875 -0.07684326171875 -0.057586669921875 +0.030853271484375 -0.93212890625 -0.515625 0.9990234375 -0.04901123046875 -0.3349609375 0.146728515625 0.489501953125 0.708984375 0.5576171875 0.66796875 0.806640625 0.458251953125 -0.68212890625 0.1707763671875 0.86669921875 -0.382568359375 -0.9951171875 -0.7568359375 -0.2481689453125 0.2034912109375 0.51953125 0.71875 0.83837890625 0.90771484375 0.94775390625 0.970703125 0.9833984375 0.9912109375 0.99462890625 0.99755859375 0.998046875 diff --git a/tests/python_tests/test_cache_optimizations.py b/tests/python_tests/test_cache_optimizations.py index a34e604382..18be429549 100644 --- a/tests/python_tests/test_cache_optimizations.py +++ b/tests/python_tests/test_cache_optimizations.py @@ -22,7 +22,7 @@ def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]: file_path = TESTS_ROOT / 'data' / file_name with open(file_path, 'r') as f: - return {"prompts": [s for s in f]} + return {"questions": [s for s in f]} def get_scheduler_config(num_kv_blocks: int) -> SchedulerConfig: scheduler_config = SchedulerConfig() @@ -118,7 +118,7 @@ def test_cache_optimized_generation_is_similar_to_unoptimized(converted_model, t data_dict = load_prompts_dataset(test_struct.prompt_file) - evaluator = whowhatbench.TextEvaluator(base_model=model_cb_noopt, tokenizer=tokenizer, test_data=data_dict, + evaluator = whowhatbench.Evaluator(base_model=model_cb_noopt, tokenizer=tokenizer, test_data=data_dict, generation_config=generation_config, generation_config_base=generation_config, max_new_tokens=test_struct.max_new_tokens, seqs_per_request=seqs_per_request)
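For reference, below is a minimal, hypothetical usage sketch of the new CacheRotationCalculator; it is not part of the diff. It reproduces the situation from reference test case 0 above and shows how the returned per-block sines and cosines could be applied to a single cached key vector using the HF-style rotate_half convention that the class documentation refers to. The rotate_key_in_place helper, the dummy data, and the constants are illustrative assumptions (the include path follows the one used by the C++ tests); in the actual pipeline the coefficients are fed to the model's "rotation_coefficients.<i>" / "rotated_block_indices.<i>" inputs rather than being applied on the host.

// Illustrative sketch only - assumes the CacheRotationCalculator interface introduced in this change.
#include <cstddef>
#include <set>
#include <vector>

#include "cache_eviction.hpp"

// Applies one token's rotation coefficients to a single key vector of length head_size, using the
// rotate_half convention: first half = x * cos - y * sin, second half = y * cos + x * sin.
// The "backward" rotation direction is already encoded in the sign of the sines returned by the calculator.
template <typename TCoeff>
static void rotate_key_in_place(std::vector<float>& key,            // [head_size]
                                const std::vector<TCoeff>& cosines, // [head_size / 2]
                                const std::vector<TCoeff>& sines) { // [head_size / 2]
    size_t half = key.size() / 2;
    for (size_t p = 0; p < half; p++) {
        float x = key[p];
        float y = key[p + half];
        key[p] = static_cast<float>(x * cosines[p] - y * sines[p]);
        key[p + half] = static_cast<float>(y * cosines[p] + x * sines[p]);
    }
}

int main() {
    const size_t block_size = 2, max_context_length = 512, head_size = 4;
    ov::genai::CacheRotationCalculator calc(block_size, max_context_length, head_size, /* rope_theta = */ 1.0);

    // Same situation as reference test case 0: 4 logical blocks before eviction, blocks 1 and 2 are evicted
    auto rotation_data = calc.get_rotation_coefficients(/* evicted_block_logical_indices = */ {1, 2},
                                                        /* num_logical_blocks_before_eviction = */ 4);

    for (const auto& block : rotation_data) {
        // block.logical_block_idx is the post-eviction index of the block to rotate (1 in this example)
        for (size_t tok = 0; tok < block_size; tok++) {
            std::vector<float> dummy_key(head_size, 1.0f);  // stand-in for one token's cached key values
            rotate_key_in_place(dummy_key, block.cosines[tok], block.sines[tok]);
        }
    }
    return 0;
}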