diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h b/fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
index 12d8be97b5..0dda6e6fcb 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
@@ -8,8 +8,10 @@
 #pragma once
 
+#include <folly/Synchronized.h>
 #include <folly/container/F14Map.h>
-#include "folly/Synchronized.h"
+
+#include "fixed_block_pool.h"
 
 namespace kv_mem {
 
@@ -29,18 +31,87 @@ class SynchronizedShardedMap {
  public:
   using iterator = typename folly::F14FastMap<K, V>::const_iterator;
 
-  explicit SynchronizedShardedMap(std::size_t numShards) : shards_(numShards) {}
+  explicit SynchronizedShardedMap(std::size_t numShards,
+                                  std::size_t block_size,
+                                  std::size_t block_alignment,
+                                  std::size_t blocks_per_chunk = 8192)
+      : shards_(numShards), mempools_(numShards) {
+    // Init mempools_
+    for (auto& pool : mempools_) {
+      pool = std::make_unique<FixedBlockPool>(
+          block_size, block_alignment, blocks_per_chunk);
+    }
+  }
 
   // Get shard map by index
-  auto& by(int index) {
-    return shards_.at(index % shards_.size());
+  auto& by(int index) { return shards_.at(index % shards_.size()); }
+
+  // Get shard pool by index
+  auto* pool_by(int index) {
+    return mempools_.at(index % shards_.size()).get();
   }
 
-  auto getNumShards() {
-    return shards_.size();
+  auto getNumShards() const { return shards_.size(); }
+
+  auto getUsedMemSize() const {
+    size_t used_mem_size = 0;
+    size_t block_size = mempools_[0]->get_aligned_block_size();
+    for (size_t i = 0; i < shards_.size(); ++i) {
+      auto rlmap = shards_[i].rlock();
+      // only count the sizes of K, V and blocks that are in use
+      used_mem_size += rlmap->size() * (sizeof(K) + sizeof(V) + block_size);
+    }
+    return used_mem_size;
+  }
+
+  void save(const std::string& filename) const {
+    std::ofstream out(filename, std::ios::binary);
+    if (!out) {
+      throw std::runtime_error("Failed to open file for writing");
+    }
+
+    const std::size_t num_shards = getNumShards();
+    out.write(reinterpret_cast<const char*>(&num_shards), sizeof(num_shards));
+    out.close();
+
+    // save every mempool
+    for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
+      std::string pool_filename = filename + ".pool." + std::to_string(shard_id);
+      auto wlock = shards_[shard_id].wlock();
+      mempools_[shard_id]->serialize(pool_filename);
+    }
+  }
+
+  void load(const std::string& filename) {
+    std::ifstream in(filename, std::ios::binary);
+    if (!in) {
+      throw std::runtime_error("Failed to open file for reading");
+    }
+
+    size_t num_shards;
+    in.read(reinterpret_cast<char*>(&num_shards), sizeof(num_shards));
+    in.close();
+
+    if (num_shards != getNumShards()) {
+      throw std::runtime_error("Shard count mismatch between file and map");
+    }
+
+    for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
+      std::string pool_filename = filename + ".pool." + std::to_string(shard_id);
+      auto wlock = shards_[shard_id].wlock();
+      // first deserialize the mempool
+      mempools_[shard_id]->deserialize(pool_filename);
+      // then rebuild the map from the mempool
+      wlock->clear();
+      mempools_[shard_id]->for_each_block([&wlock](void* block) {
+        auto key = FixedBlockPool::get_key(block);
+        wlock->emplace(key, reinterpret_cast<V>(block));
+      });
+    }
+  }
 
  private:
   std::vector<folly::Synchronized<folly::F14FastMap<K, V>, M>> shards_;
+  std::vector<std::unique_ptr<FixedBlockPool>> mempools_;
 };
-} // namespace kv_mem
+} // namespace kv_mem
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
index d2daf1c26d..32225fe059 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -15,7 +15,8 @@
 
 #include "SynchronizedShardedMap.h"
 #include "deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/initializer.h"
-#include "store_value.h"
+#include "fixed_block_pool.h"
+#include "feature_evict.h"
 
 #include <folly/executors/CPUThreadPoolExecutor.h>
 #include <folly/futures/Future.h>
@@ -46,6 +47,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   /// @param max_D the maximum dimension of embedding tensor
   /// @param uniform_init_lower the lower bound of the uniform distribution
   /// @param uniform_init_upper the upper bound of the uniform distribution
+  /// @param feature_evict_config feature evict config
   /// @param num_shards number of shards for the kvstore. This is to improve
   /// parallelization. Each key value pair will be sharded into one shard.
   /// @param num_threads num of threads that kvstore needs to be run upon for
@@ -59,6 +61,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t max_D,
       double uniform_init_lower,
       double uniform_init_upper,
+      FeatureEvictConfig feature_evict_config,
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
@@ -68,10 +71,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             max_D,
             0), // l2_cache_size_gb = 0 to disable l2 cache
         max_D_(max_D),
+        feature_evict_config_(feature_evict_config),
         num_shards_(num_shards),
         weight_ttl_in_hours_(weight_ttl_in_hours),
-        kv_store_(SynchronizedShardedMap<int64_t, StoreValue<weight_type>>(
-            num_shards_)),
+        block_size_(FixedBlockPool::calculate_block_size<weight_type>(max_D)),
+        block_alignment_(FixedBlockPool::calculate_block_alignment<weight_type>()),
+        kv_store_(SynchronizedShardedMap<int64_t, weight_type*>(
+            num_shards_,
+            block_size_,
+            block_alignment_,
+            /*blocks_per_chunk=*/8192)),
         elem_size_(row_storage_bitwidth / 8) {
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max(
         num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -81,6 +90,9 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
         uniform_init_lower,
         uniform_init_upper,
         row_storage_bitwidth);
+    if (feature_evict_config_.trigger_mode != EvictTriggerMode::DISABLED) {
+      feature_evict_ = create_feature_evict(
+          feature_evict_config_, executor_.get(), kv_store_, max_D);
+    }
   }
 
   void initialize_initializers(
@@ -185,20 +197,34 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
         CHECK_EQ(indices.size(0), weights.size(0));
         {
           auto wlmap = kv_store_.by(shard_id).wlock();
+          auto* pool = kv_store_.pool_by(shard_id);
           auto indices_data_ptr = indices.data_ptr();
           for (auto index_iter = indexes.begin();
                index_iter != indexes.end();
                index_iter++) {
            const auto& id_index = *index_iter;
            auto id = int64_t(indices_data_ptr[id_index]);
-           wlmap->try_emplace(
-               id,
-               StoreValue<weight_type>(std::vector<weight_type>(
-                   weights[id_index]
-                       .template data_ptr<weight_type>(),
-                   weights[id_index]
-                       .template data_ptr<weight_type>() +
-                       weights[id_index].numel())));
+           // use mempool
+           weight_type* block = nullptr;
+           // First check if the key already exists
+           auto it = wlmap->find(id);
+           if (it != wlmap->end()) {
+             block = it->second;
+           } else {
+             // Key doesn't exist, allocate a new block and insert it.
+             block = pool->allocate_t<weight_type>();
+             wlmap->insert({id, block});
+           }
+           if (feature_evict_) {
+             feature_evict_->update_feature_statistics(block);
+           }
+           auto* data_ptr = FixedBlockPool::data_ptr(block);
+           std::copy(
+               weights[id_index].template data_ptr<weight_type>(),
+               weights[id_index].template data_ptr<weight_type>() +
+                   weights[id_index].numel(),
+               data_ptr);
          }
         }
       });
@@ -276,16 +302,12 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
                       row_storage_data_ptr));
               continue;
             }
-            const auto& cache_results =
-                cached_iter->second.getValueAndPromote();
-            CHECK_EQ(cache_results.size(), max_D_);
+            // use mempool
+            const auto* data_ptr = FixedBlockPool::data_ptr(cached_iter->second);
             std::copy(
-                reinterpret_cast<const weight_type*>(
-                    &(cache_results[0])),
-                reinterpret_cast<const weight_type*>(
-                    &(cache_results[max_D_])),
-                &(weights_data_ptr
-                    [id_index * max_D_])); // dst_start
+                data_ptr,
+                data_ptr + max_D_,
+                &(weights_data_ptr[id_index * max_D_])); // dst_start
           }
         }
       });
@@ -307,6 +329,36 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
 
   void compact() override {}
 
+  void trigger_feature_evict() {
+    if (feature_evict_) {
+      feature_evict_->trigger_evict();
+    }
+  }
+
+  void feature_evict_resume() {
+    if (feature_evict_) {
+      feature_evict_->resume();
+    }
+  }
+
+  void feature_evict_pause() {
+    if (feature_evict_) {
+      feature_evict_->pause();
+    }
+  }
+
+  void maybe_evict_by_step() {
+    if (feature_evict_config_.trigger_mode == EvictTriggerMode::ITERATION &&
+        feature_evict_config_.trigger_step_interval > 0 &&
+        ++current_iter_ % feature_evict_config_.trigger_step_interval == 0) {
+      trigger_feature_evict();
+    }
+  }
+
+  size_t get_map_used_memsize() const {
+    return kv_store_.getUsedMemSize();
+  }
+
  private:
   void fill_from_row_storage(
       int shard_id,
@@ -368,10 +420,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   int64_t max_D_;
   int64_t num_shards_;
   int64_t weight_ttl_in_hours_;
-  SynchronizedShardedMap<int64_t, StoreValue<weight_type>> kv_store_;
+  // mempool params
+  size_t block_size_;
+  size_t block_alignment_;
+  SynchronizedShardedMap<int64_t, weight_type*> kv_store_;
   std::atomic_bool is_eviction_ongoing_ = false;
   std::vector<std::unique_ptr<Initializer>> initializers_;
   int64_t elem_size_;
+  FeatureEvictConfig feature_evict_config_;
+  std::unique_ptr<FeatureEvict<weight_type>> feature_evict_;
+  int current_iter_ = 0;
 }; // class DramKVEmbeddingCache
 
 } // namespace kv_mem
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
index 0b915e50ba..2543091d6e 100644
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
@@ -26,15 +26,34 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       int64_t max_D,
       double uniform_init_lower,
       double uniform_init_upper,
+      int evict_trigger_mode,
+      int evict_trigger_strategy,
+      int64_t trigger_step_interval,
+      uint32_t ttl,
+      uint32_t count_threshold,
+      float count_decay_rate,
+      double l2_weight_threshold,
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
       int64_t weight_ttl_in_hours = 2) {
+
+    // feature evict config
+    FeatureEvictConfig feature_evict_config;
+    feature_evict_config.trigger_mode =
+        static_cast<EvictTriggerMode>(evict_trigger_mode);
+    feature_evict_config.trigger_strategy =
+        static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
+    feature_evict_config.trigger_step_interval = trigger_step_interval;
+    feature_evict_config.ttl = ttl;
+    feature_evict_config.count_threshold = count_threshold;
+    feature_evict_config.count_decay_rate = count_decay_rate;
+    feature_evict_config.l2_weight_threshold = l2_weight_threshold;
+
     if (row_storage_bitwidth == 16) {
       impl_ = std::make_shared<DramKVEmbeddingCache<at::Half>>(
           max_D,
           uniform_init_lower,
           uniform_init_upper,
+          feature_evict_config,
           num_shards,
           num_threads,
           row_storage_bitwidth,
@@ -44,6 +63,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           max_D,
           uniform_init_lower,
           uniform_init_upper,
+          feature_evict_config,
           num_shards,
           num_threads,
           row_storage_bitwidth,
@@ -67,7 +87,11 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
   }
 
   void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return impl_->set(indices, weights, count);
+    impl_->feature_evict_pause();
+    impl_->set(indices, weights, count);
+    // with the ITERATION EvictTriggerMode, eviction is triggered by step count
+    impl_->maybe_evict_by_step();
+    impl_->feature_evict_resume();
   }
 
   void flush() {
@@ -86,7 +110,9 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       at::Tensor weights,
       at::Tensor count,
       int64_t sleep_ms) {
-    return impl_->get(indices, weights, count, sleep_ms);
+    impl_->feature_evict_pause();
+    impl_->get(indices, weights, count, sleep_ms);
+    impl_->feature_evict_resume();
   }
 
   void wait_util_filling_work_done() {
@@ -97,6 +123,10 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
     return impl_->get_keys_in_range(start, end);
   }
 
+  size_t get_map_used_memsize() const {
+    return impl_->get_map_used_memsize();
+  }
+
  private:
   // friend class EmbeddingRocksDBWrapper;
   friend class ssd::KVTensorWrapper;
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h
new file mode 100644
index 0000000000..8a384b3f55
--- /dev/null
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h
@@ -0,0 +1,348 @@
+#pragma once
+
+#include <atomic>
+#include <chrono>
+#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include <fmt/format.h>
+#include <folly/Unit.h>
+#include <folly/executors/CPUThreadPoolExecutor.h>
+#include <folly/futures/Future.h>
+#include <folly/futures/Promise.h>
+
+#include "SynchronizedShardedMap.h"
+
+namespace kv_mem {
+
+enum class EvictTriggerMode {
+  DISABLED, // Do not use feature evict
+  ITERATION, // Trigger based on iteration steps
+  MANUAL // Manually triggered by upstream
+};
+
+enum class EvictTriggerStrategy { BY_TIMESTAMP, BY_COUNTER, BY_TIMESTAMP_AND_COUNTER, BY_L2WEIGHT };
+
+struct FeatureEvictConfig {
+  EvictTriggerStrategy trigger_strategy;
+  EvictTriggerMode trigger_mode;
+  int64_t trigger_step_interval;
+  uint32_t ttl;
+  uint32_t count_threshold;
+  float count_decay_rate;
+  double l2_weight_threshold;
+};
+
+template <typename weight_type>
+class FeatureEvict {
+ public:
+  FeatureEvict(folly::CPUThreadPoolExecutor* executor,
+               SynchronizedShardedMap<int64_t, weight_type*>& kv_store)
+      : executor_(executor),
+        kv_store_(kv_store),
+        evict_flag_(false),
+        evict_interrupt_(false),
+        num_shards_(kv_store.getNumShards()) {
+    init_shard_status();
+  }
+
+  virtual ~FeatureEvict() {
+    wait_completion(); // Wait for all asynchronous tasks to complete.
+  }
+
+  // Trigger asynchronous eviction.
+  // If there is an ongoing task, return directly to prevent multiple triggers.
+  // If there is no ongoing task, initialize the task state.
+  void trigger_evict() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (evict_flag_.exchange(true)) return;
+    prepare_evict();
+  }
+
+  // Resume task execution. Returns true if there is an ongoing task, false otherwise.
+  bool resume() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!evict_flag_.load()) return false;
+    evict_interrupt_.store(false);
+    for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
+      submit_shard_task(shard_id);
+    }
+    return true;
+  }
+
+  // Pause the eviction process. Returns true if there is an ongoing task,
+  // false otherwise. During the pause phase, check whether the eviction is
+  // complete.
+  bool pause() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!evict_flag_.load()) return false;
+    evict_interrupt_.store(true);
+    check_and_reset_evict_flag();
+    wait_completion();
+    return true;
+  }
+
+  // Check whether eviction is ongoing.
+  bool is_evicting() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    check_and_reset_evict_flag();
+    return evict_flag_.load();
+  }
+
+  virtual void update_feature_statistics(weight_type* block) = 0;
+
+ protected:
+  void init_shard_status() {
+    block_cursors_.resize(num_shards_);
+    block_nums_snapshot_.resize(num_shards_);
+    shards_finished_.clear();
+    for (int i = 0; i < num_shards_; ++i) {
+      block_cursors_[i] = 0;
+      block_nums_snapshot_[i] = 0;
+      shards_finished_.emplace_back(std::make_unique<std::atomic<bool>>(false));
+    }
+  }
+
+  // Initialize shard state.
+  void prepare_evict() {
+    for (int shard_id = 0; shard_id < num_shards_; ++shard_id) {
+      auto rlmap = kv_store_.by(shard_id).rlock();
+      auto* mempool = kv_store_.pool_by(shard_id);
+      block_nums_snapshot_[shard_id] =
+          mempool->get_chunks().size() * mempool->get_blocks_per_chunk();
+      block_cursors_[shard_id] = 0;
+      shards_finished_[shard_id]->store(false);
+    }
+  }
+
+  void submit_shard_task(int shard_id) {
+    if (shards_finished_[shard_id]->load()) return;
+    futures_.emplace_back(folly::via(executor_).thenValue(
+        [this, shard_id](auto&&) { process_shard(shard_id); }));
+  }
+
+  void process_shard(int shard_id) {
+    auto start_time = std::chrono::high_resolution_clock::now();
+    size_t evicted_count = 0;
+    size_t processed_count = 0;
+
+    auto wlock = kv_store_.by(shard_id).wlock();
+    auto* pool = kv_store_.pool_by(shard_id);
+
+    while (!evict_interrupt_.load() &&
+           block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) {
+      auto* block = pool->template get_block<weight_type>(block_cursors_[shard_id]++);
+      processed_count++;
+      if (block && evict_block(block)) {
+        int64_t key = FixedBlockPool::get_key(block);
+        auto it = wlock->find(key);
+        if (it != wlock->end() && block == it->second) {
+          wlock->erase(key);
+          pool->template deallocate_t<weight_type>(block);
+          evicted_count++;
+        }
+      }
+    }
+
+    // Check whether the loop ended normally.
+    if (block_cursors_[shard_id] >= block_nums_snapshot_[shard_id]) {
+      shards_finished_[shard_id]->store(true);
+    }
+
+    auto end_time = std::chrono::high_resolution_clock::now();
+    auto duration =
+        std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+
+    fmt::print(
+        "Shard {} completed: \n"
+        "  - Time taken: {}ms\n"
+        "  - Total blocks processed: {}\n"
+        "  - Blocks evicted: {}\n"
+        "  - Eviction rate: {:.2f}%\n",
+        shard_id,
+        duration.count(),
+        processed_count,
+        evicted_count,
+        (evicted_count * 100.0f) / processed_count);
+  }
+
+  virtual bool evict_block(weight_type* block) = 0;
+
+  void wait_completion() {
+    folly::collectAll(futures_).wait();
+    futures_.clear();
+  }
+
+  // Check and reset the eviction flag.
+  void check_and_reset_evict_flag() {
+    bool all_finished = true;
+    for (int i = 0; i < num_shards_; ++i) {
+      if (!shards_finished_[i]->load()) all_finished = false;
+    }
+    if (all_finished) evict_flag_.store(false);
+  }
+
+  folly::CPUThreadPoolExecutor* executor_; // Thread pool.
+  SynchronizedShardedMap<int64_t, weight_type*>& kv_store_; // Sharded map.
+  std::vector<std::size_t> block_cursors_; // Index of processed blocks.
+  std::vector<std::size_t> block_nums_snapshot_; // Snapshot of total blocks at eviction trigger.
+  std::vector<std::unique_ptr<std::atomic<bool>>> shards_finished_; // Flags indicating whether shards are finished.
+  std::atomic<bool> evict_flag_; // Indicates whether an eviction task is ongoing.
+  std::atomic<bool> evict_interrupt_; // Indicates whether the eviction task is paused.
+  std::vector<folly::Future<folly::Unit>> futures_; // Records of shard tasks.
+  std::mutex mutex_; // Interface lock to ensure thread safety for public methods.
+  int num_shards_; // Number of concurrent tasks.
+};
+
+template <typename weight_type>
+class CounterBasedEvict : public FeatureEvict<weight_type> {
+ public:
+  CounterBasedEvict(folly::CPUThreadPoolExecutor* executor,
+                    SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
+                    float decay_rate,
+                    uint32_t threshold)
+      : FeatureEvict<weight_type>(executor, kv_store),
+        decay_rate_(decay_rate),
+        threshold_(threshold) {}
+
+  void update_feature_statistics(weight_type* block) override {
+    FixedBlockPool::update_count(block);
+  }
+
+ protected:
+  bool evict_block(weight_type* block) override {
+    // Apply decay and check the threshold.
+    auto current_count = FixedBlockPool::get_count(block);
+    current_count *= decay_rate_;
+    FixedBlockPool::set_count(block, current_count);
+    return current_count < threshold_;
+  }
+
+ private:
+  float decay_rate_; // Decay rate for the block count.
+  uint32_t threshold_; // Threshold for eviction.
+};
+
+template <typename weight_type>
+class TimeBasedEvict : public FeatureEvict<weight_type> {
+ public:
+  TimeBasedEvict(folly::CPUThreadPoolExecutor* executor,
+                 SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
+                 uint32_t ttl)
+      : FeatureEvict<weight_type>(executor, kv_store), ttl_(ttl) {}
+
+  void update_feature_statistics(weight_type* block) override {
+    FixedBlockPool::update_timestamp(block);
+  }
+
+ protected:
+  bool evict_block(weight_type* block) override {
+    auto current_time = FixedBlockPool::current_timestamp();
+    return current_time - FixedBlockPool::get_timestamp(block) > ttl_;
+  }
+
+ private:
+  uint32_t ttl_; // Time-to-live for eviction.
+};
+
+template <typename weight_type>
+class TimeCounterBasedEvict : public FeatureEvict<weight_type> {
+ public:
+  TimeCounterBasedEvict(folly::CPUThreadPoolExecutor* executor,
+                        SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
+                        uint32_t ttl,
+                        float decay_rate,
+                        uint32_t threshold)
+      : FeatureEvict<weight_type>(executor, kv_store),
+        ttl_(ttl),
+        decay_rate_(decay_rate),
+        threshold_(threshold) {}
+
+  void update_feature_statistics(weight_type* block) override {
+    FixedBlockPool::update_timestamp(block);
+    FixedBlockPool::update_count(block);
+  }
+
+ protected:
+  bool evict_block(weight_type* block) override {
+    // Apply decay, then check both the count threshold and the ttl.
+    auto current_time = FixedBlockPool::current_timestamp();
+    auto current_count = FixedBlockPool::get_count(block);
+    current_count *= decay_rate_;
+    FixedBlockPool::set_count(block, current_count);
+    return (current_time - FixedBlockPool::get_timestamp(block) > ttl_) &&
+        (current_count < threshold_);
+  }
+
+ private:
+  uint32_t ttl_; // Time-to-live for eviction.
+  float decay_rate_; // Decay rate for the block count.
+  uint32_t threshold_; // Count threshold for eviction.
+};
+
+template <typename weight_type>
+class L2WeightBasedEvict : public FeatureEvict<weight_type> {
+ public:
+  L2WeightBasedEvict(folly::CPUThreadPoolExecutor* executor,
+                     SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
+                     double threshold,
+                     size_t dimension)
+      : FeatureEvict<weight_type>(executor, kv_store),
+        threshold_(threshold),
+        dimension_(dimension) {}
+
+  void update_feature_statistics([[maybe_unused]] weight_type* block) override {}
+
+ protected:
+  bool evict_block(weight_type* block) override {
+    auto l2weight = FixedBlockPool::get_l2weight(block, dimension_);
+    return l2weight < threshold_;
+  }
+
+ private:
+  double threshold_; // L2 weight threshold for eviction.
+  size_t dimension_; // Embedding dimension
+};
+
+template <typename weight_type>
+std::unique_ptr<FeatureEvict<weight_type>> create_feature_evict(
+    const FeatureEvictConfig& config,
+    folly::CPUThreadPoolExecutor* executor,
+    SynchronizedShardedMap<int64_t, weight_type*>& kv_store,
+    size_t dimension) {
+  if (executor == nullptr) {
+    throw std::invalid_argument("executor cannot be null");
+  }
+
+  switch (config.trigger_strategy) {
+    case EvictTriggerStrategy::BY_TIMESTAMP: {
+      if (config.ttl <= 0) {
+        throw std::invalid_argument("ttl must be positive");
+      }
+      return std::make_unique<TimeBasedEvict<weight_type>>(
+          executor, kv_store, config.ttl);
+    }
+
+    case EvictTriggerStrategy::BY_COUNTER: {
+      if (config.count_decay_rate <= 0 || config.count_decay_rate > 1) {
+        throw std::invalid_argument("count_decay_rate must be in range (0,1]");
+      }
+      if (config.count_threshold <= 0) {
+        throw std::invalid_argument("count_threshold must be positive");
+      }
+      return std::make_unique<CounterBasedEvict<weight_type>>(
+          executor, kv_store, config.count_decay_rate, config.count_threshold);
+    }
+
+    case EvictTriggerStrategy::BY_TIMESTAMP_AND_COUNTER: {
+      if (config.ttl <= 0) {
+        throw std::invalid_argument("ttl must be positive");
+      }
+      if (config.count_decay_rate <= 0 || config.count_decay_rate > 1) {
+        throw std::invalid_argument("count_decay_rate must be in range (0,1]");
+      }
+      if (config.count_threshold <= 0) {
+        throw std::invalid_argument("count_threshold must be positive");
+      }
+      return std::make_unique<TimeCounterBasedEvict<weight_type>>(
+          executor, kv_store, config.ttl, config.count_decay_rate, config.count_threshold);
+    }
+
+    case EvictTriggerStrategy::BY_L2WEIGHT: {
+      if (config.l2_weight_threshold <= 0) {
+        throw std::invalid_argument("l2_weight_threshold must be positive");
+      }
+      // TODO: optimizer parameters should not be included in dimension
+      return std::make_unique<L2WeightBasedEvict<weight_type>>(
+          executor, kv_store, config.l2_weight_threshold, dimension);
+    }
+
+    default:
+      throw std::runtime_error("Unknown evict trigger strategy");
+  }
+}
+
+} // namespace kv_mem
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h b/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
new file mode 100644
index 0000000000..f8acbffb09
--- /dev/null
+++ b/fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h
@@ -0,0 +1,331 @@
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <ctime>
+#include <fstream>
+
+#include <fmt/format.h>
+#include <memory_resource>
+#include <numeric>
+
+namespace kv_mem {
+static constexpr uint32_t kMaxInt31Counter = 2147483647;
+
+class FixedBlockPool : public std::pmr::memory_resource {
+ public:
+  // Chunk metadata
+  struct ChunkInfo {
+    void* ptr; // Memory block pointer
+    std::size_t size; // Total size
+    std::size_t alignment;
+  };
+
+  // Metadata structure (publicly accessible)
+  // alignas(8) keeps MetaHeader >= sizeof(void*), so a mempool block is never too small.
+  struct alignas(8) MetaHeader { // 16 bytes
+    int64_t key; // feature key, 8 bytes
+    uint32_t timestamp; // 4 bytes, in seconds; uint32 covers a range of over 120 years
+    uint32_t count : 31; // only 31 bits are used; max value is 2147483647
+    bool used : 1; // marks whether this block is in use, for memory pool traversal
+    // Can be extended with other fields: uint32_t click, etc.
+  };
+
+  // Metadata operations
+
+  // Key operations
+  static uint64_t get_key(const void* block) {
+    return reinterpret_cast<const MetaHeader*>(block)->key;
+  }
+  static void set_key(void* block, uint64_t key) {
+    reinterpret_cast<MetaHeader*>(block)->key = key;
+  }
+
+  // used operations
+  static bool get_used(const void* block) {
+    return reinterpret_cast<const MetaHeader*>(block)->used;
+  }
+  static void set_used(void* block, bool used) {
+    reinterpret_cast<MetaHeader*>(block)->used = used;
+  }
+
+  // Score operations
+  static uint32_t get_count(const void* block) {
+    return reinterpret_cast<const MetaHeader*>(block)->count;
+  }
+  static void set_count(void* block, uint32_t count) {
+    reinterpret_cast<MetaHeader*>(block)->count = count;
+  }
+  static void update_count(void* block) {
+    // Stop incrementing at the 31-bit maximum to avoid overflow
+    if (reinterpret_cast<MetaHeader*>(block)->count < kMaxInt31Counter) {
+      reinterpret_cast<MetaHeader*>(block)->count++;
+    }
+  }
+
+  // timestamp operations
+  static uint32_t get_timestamp(const void* block) {
+    return reinterpret_cast<const MetaHeader*>(block)->timestamp;
+  }
+  static void update_timestamp(void* block) {
+    reinterpret_cast<MetaHeader*>(block)->timestamp = current_timestamp();
+  }
+  static uint32_t current_timestamp() {
+    return std::time(nullptr);
+  }
+
+  // Calculate storage size
+  template <typename scalar_t>
+  static size_t calculate_block_size(size_t dimension) {
+    return sizeof(FixedBlockPool::MetaHeader) + dimension * sizeof(scalar_t);
+  }
+
+  // Calculate alignment requirements
+  template <typename scalar_t>
+  static size_t calculate_block_alignment() {
+    return std::max(alignof(FixedBlockPool::MetaHeader), alignof(scalar_t));
+  }
+
+  // Data pointer retrieval
+  template <typename scalar_t>
+  static scalar_t* data_ptr(scalar_t* block) {
+    return reinterpret_cast<scalar_t*>(
+        reinterpret_cast<char*>(block) + sizeof(FixedBlockPool::MetaHeader));
+  }
+
+  template <typename scalar_t>
+  static const scalar_t* data_ptr(const scalar_t* block) {
+    return reinterpret_cast<const scalar_t*>(
+        reinterpret_cast<const char*>(block) + sizeof(FixedBlockPool::MetaHeader));
+  }
+
+  template <typename scalar_t>
+  static scalar_t get_l2weight(scalar_t* block, size_t dimension) {
+    scalar_t* data = FixedBlockPool::data_ptr(block);
+    return std::sqrt(std::accumulate(
+        data, data + dimension, scalar_t(0),
+        [](scalar_t sum, scalar_t val) { return sum + val * val; }));
+  }
+
+  explicit FixedBlockPool(
+      std::size_t block_size, // Size of each memory block
+      std::size_t block_alignment, // Memory block alignment requirement
+      std::size_t blocks_per_chunk = 8192, // Number of blocks per chunk
+      std::pmr::memory_resource* upstream = std::pmr::new_delete_resource())
+      // Minimum block size is 8 bytes
+      : block_size_(std::max(block_size, sizeof(void*))),
+        block_alignment_(block_alignment),
+        blocks_per_chunk_(blocks_per_chunk),
+        upstream_(upstream),
+        chunks_(upstream) {
+    // Validate the minimum data size (whether it's less than 8 bytes):
+    // half type, 2 bytes, minimum embedding length 4
+    // float type, 4 bytes, minimum embedding length 2
+    // Large objects use the memory pool; small objects are placed directly in
+    // the hashtable
+    if (block_size < sizeof(void*)) {
+      // Block size must be at least able to store a pointer (for the free list)
+      throw std::invalid_argument("Block size must be at least sizeof(void*)");
+    }
+
+    // Validate that the alignment requirement is a power of 2
+    if ((block_alignment_ & (block_alignment_ - 1)) != 0) {
+      throw std::invalid_argument("Alignment must be power of two");
+    }
+
+    // Validate that block size is a multiple of alignment
+    if (block_size_ % block_alignment_ != 0) {
+      throw std::invalid_argument("Block size must align with alignment");
+    }
+
+    // Ensure block size is at least 1
+    if (block_size_ < 1) {
+      throw std::invalid_argument("Block size must be at least 1");
+    }
+  }
+
+  // Release all allocated memory during destruction
+  ~FixedBlockPool() override {
+    for (auto&& chunk : chunks_) {
+      upstream_->deallocate(chunk.ptr, chunk.size, chunk.alignment);
+    }
+  }
+
+  // Create memory block with metadata
+  template <typename scalar_t>
+  scalar_t* allocate_t() {
+    return reinterpret_cast<scalar_t*>(this->allocate(block_size_, block_alignment_));
+  }
+
+  // Destroy memory block
+  template <typename scalar_t>
+  void deallocate_t(scalar_t* block) {
+    this->deallocate(block, block_size_, block_alignment_);
+  }
+
+  template <typename scalar_t>
+  scalar_t* get_block(size_t index) {
+    char* current_chunk = static_cast<char*>(chunks_[index / blocks_per_chunk_].ptr);
+    char* block = current_chunk + block_size_ * (index % blocks_per_chunk_);
+    if (FixedBlockPool::get_used(block)) {
+      return reinterpret_cast<scalar_t*>(block);
+    } else {
+      return nullptr;
+    }
+  }
+
+  template <typename Func>
+  void for_each_block(Func&& func) const {
+    for (const auto& chunk : chunks_) {
+      char* current = static_cast<char*>(chunk.ptr);
+      for (size_t i = 0; i < blocks_per_chunk_; ++i) {
+        if (FixedBlockPool::get_used(current)) {
+          func(current);
+        }
+        current += block_size_;
+      }
+    }
+  }
+
+  void serialize(const std::string& filename) const {
+    auto start = std::chrono::high_resolution_clock::now();
+
+    std::ofstream out(filename, std::ios::binary);
+    if (!out) {
+      throw std::runtime_error("Failed to open file for writing");
+    }
+    // Write metadata
+    out.write(reinterpret_cast<const char*>(&block_size_), sizeof(block_size_));
+    out.write(reinterpret_cast<const char*>(&block_alignment_), sizeof(block_alignment_));
+    out.write(reinterpret_cast<const char*>(&blocks_per_chunk_), sizeof(blocks_per_chunk_));
+    const size_t num_chunks = chunks_.size();
+    out.write(reinterpret_cast<const char*>(&num_chunks), sizeof(num_chunks));
+
+    // Write data for each chunk
+    for (const auto& chunk : chunks_) {
+      assert(chunk.size == block_size_ * blocks_per_chunk_);
+      out.write(static_cast<const char*>(chunk.ptr),
+                static_cast<std::streamsize>(chunk.size));
+    }
+    out.flush();
+    out.close();
+    double data_size_mb =
+        static_cast<double>(block_size_ * chunks_.size() * blocks_per_chunk_) /
+        (1024.0 * 1024.0);
+
+    auto end = std::chrono::high_resolution_clock::now();
+    auto duration = std::chrono::duration<double>(end - start).count();
+
+    fmt::print("Serialized {}: size={:.3f}MB, time={}s, throughput={:.3f}MB/s\n",
+               filename,
+               data_size_mb,
+               duration,
+               (data_size_mb / duration));
+  }
+
+  void deserialize(const std::string& filename) {
+    auto start = std::chrono::high_resolution_clock::now();
+
+    std::ifstream in(filename, std::ios::binary);
+    if (!in) {
+      throw std::runtime_error("Failed to open file for reading");
+    }
+
+    // Read metadata
+    std::size_t block_size, block_alignment, blocks_per_chunk, num_chunks;
+    in.read(reinterpret_cast<char*>(&block_size), sizeof(block_size));
+    in.read(reinterpret_cast<char*>(&block_alignment), sizeof(block_alignment));
+    in.read(reinterpret_cast<char*>(&blocks_per_chunk), sizeof(blocks_per_chunk));
+    in.read(reinterpret_cast<char*>(&num_chunks), sizeof(num_chunks));
+
+    // Validate parameters
+    if (block_size != block_size_) {
+      throw std::invalid_argument("Invalid block_size in file");
+    }
+    if (block_alignment != block_alignment_) {
+      throw std::invalid_argument("Invalid block_alignment in file");
+    }
+    if (blocks_per_chunk != blocks_per_chunk_) {
+      throw std::invalid_argument("Invalid blocks_per_chunk in file");
+    }
+
+    // Read data for each chunk and rebuild the memory structure
+    const std::size_t chunk_size = block_size_ * blocks_per_chunk_;
+    for (size_t i = 0; i < num_chunks; ++i) {
+      void* chunk_ptr = upstream_->allocate(chunk_size, block_alignment_);
+      in.read(static_cast<char*>(chunk_ptr), static_cast<std::streamsize>(chunk_size));
+      // Add chunk to memory pool
+      chunks_.push_back({chunk_ptr, chunk_size, block_alignment});
+      // Rebuild free_list_
+      char* current = static_cast<char*>(chunk_ptr);
+      for (size_t j = 0; j < blocks_per_chunk; ++j) {
+        void* block = current + j * block_size;
+        if (!get_used(block)) {
+          do_deallocate(block, block_size_, block_alignment_);
+        }
+      }
+    }
+    in.close();
+
+    double data_size_mb =
+        static_cast<double>(block_size_ * chunks_.size() * blocks_per_chunk_) /
+        (1024.0 * 1024.0);
+
+    auto end = std::chrono::high_resolution_clock::now();
+    auto duration = std::chrono::duration<double>(end - start).count();
+
+    fmt::print("Deserialized {}: size={:.3f}MB, time={}s, throughput={:.3f}MB/s\n",
+               filename,
+               data_size_mb,
+               duration,
+               (data_size_mb / duration));
+  }
+
+  [[nodiscard]] const auto& get_chunks() const noexcept { return chunks_; }
+  [[nodiscard]] std::size_t get_block_size() const noexcept { return block_size_; }
+  [[nodiscard]] std::size_t get_block_alignment() const noexcept { return block_alignment_; }
+  [[nodiscard]] std::size_t get_blocks_per_chunk() const noexcept { return blocks_per_chunk_; }
+  [[nodiscard]] std::size_t get_aligned_block_size() const noexcept {
+    return (block_size_ + block_alignment_ - 1) / block_alignment_ * block_alignment_;
+  }
+
+ protected:
+  // Core allocation function
+  void* do_allocate(std::size_t bytes, std::size_t alignment) override {
+    // Only handle matching block size and alignment requirements
+    if (bytes != block_size_ || alignment != block_alignment_) {
+      throw std::bad_alloc();
+    }
+
+    // Allocate a new chunk when no blocks are available
+    if (!free_list_) {
+      allocate_chunk();
+    }
+
+    // Take a block from the head of the free list
+    void* result = free_list_;
+    free_list_ = *static_cast<void**>(free_list_);
+    FixedBlockPool::set_used(result, true);
+    return result;
+  }
+
+  // Core deallocation function
+  void do_deallocate(void* p,
+                     [[maybe_unused]] std::size_t bytes,
+                     [[maybe_unused]] std::size_t alignment) override {
+    // Insert the memory block back at the head of the free list
+    *static_cast<void**>(p) = free_list_;
+    free_list_ = p;
+    FixedBlockPool::set_used(free_list_, false);
+  }
+
+  // Resource equality comparison (only the same object is equal)
+  [[nodiscard]] bool do_is_equal(
+      const std::pmr::memory_resource& other) const noexcept override {
+    return this == &other;
+  }
+
+ private:
+  // Allocate a new memory chunk
+  void allocate_chunk() {
+    const std::size_t chunk_size = block_size_ * blocks_per_chunk_;
+
+    // Allocate aligned memory through the upstream resource
+    void* chunk_ptr = upstream_->allocate(chunk_size, block_alignment_);
+
+    // Record chunk information for later release
+    chunks_.push_back({chunk_ptr, chunk_size, block_alignment_});
+
+    // Initialize free list: link blocks in reverse order from chunk end to
+    // beginning (improves locality)
+    char* current = static_cast<char*>(chunk_ptr) + chunk_size;
+    for (std::size_t i = 0; i < blocks_per_chunk_; ++i) {
+      current -= block_size_;
+      *reinterpret_cast<void**>(current) = free_list_;
+      FixedBlockPool::set_used(current, false);
+      free_list_ = current;
+    }
+  }
+
+  // Member variables
+  const std::size_t block_size_; // Block size (not less than pointer size)
+  const std::size_t block_alignment_; // Block alignment requirement
+  const std::size_t blocks_per_chunk_; // Number of blocks per chunk
+  std::pmr::memory_resource* upstream_; // Upstream memory resource
+  std::pmr::vector<ChunkInfo> chunks_{1024}; // Records of all allocated chunks
+  void* free_list_ = nullptr; // Free block list head pointer
+};
+} // namespace kv_mem
diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/store_value.h b/fbgemm_gpu/src/dram_kv_embedding_cache/store_value.h
deleted file mode 100644
index 375c63ce46..0000000000
--- a/fbgemm_gpu/src/dram_kv_embedding_cache/store_value.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#pragma once
-#include <vector>
-#include "common/time/Time.h"
-
-namespace kv_mem {
-
-/// @ingroup embedding-dram-kvstore
-///
-/// @brief data structure to store tensor value and it's timestamp
-//
-template <typename weight_type>
-class StoreValue {
- public:
-  explicit StoreValue(std::vector<weight_type>&& value) {
-    value_ = std::move(value);
-    timestamp_ = facebook::WallClockUtil::NowInUsecFast();
-  }
-
-  explicit StoreValue(StoreValue&& pv) noexcept {
-    timestamp_ = facebook::WallClockUtil::NowInUsecFast();
-    value_ = std::move(pv.value_);
-  }
-
-  int64_t getTimestamp() const {
-    return timestamp_;
-  }
-
-  const std::vector<weight_type>& getValue() const {
-    return value_;
-  }
-
-  const std::vector<weight_type>& getValueAndPromote() {
-    timestamp_ = facebook::WallClockUtil::NowInUsecFast();
-    return value_;
-  }
-
- private:
-  StoreValue& operator=(const StoreValue&) = delete;
-  StoreValue& operator=(const StoreValue&&) = delete;
-  StoreValue(const StoreValue& other) = delete;
-
-  // cached tensor value
-  std::vector<weight_type> value_;
-
-  // last visit timestamp
-  int64_t timestamp_;
-};
-} // namespace kv_mem
diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/CMakeLists.txt b/fbgemm_gpu/test/dram_kv_embedding_cache/CMakeLists.txt
new file mode 100644
index 0000000000..9ab483eab3
--- /dev/null
+++ b/fbgemm_gpu/test/dram_kv_embedding_cache/CMakeLists.txt
@@ -0,0 +1,24 @@
+find_package(folly REQUIRED)
+find_package(gflags REQUIRED)
+
+include_directories(
+    ${FBGEMM_SOURCE_DIR}
+)
+
+set(COMMON_COMPILE_FEATURES cxx_std_17)
+set(COMMON_COMPILE_OPTIONS "-O3")
+set(COMMON_LINK_LIBRARIES gtest gtest_main Folly::folly)
+
+set(TEST_TARGETS
+    fixed_block_pool_test
+    fixed_block_pool_saver_test
+    sharded_map_test
+    feature_evict_test
+)
+
+foreach (target ${TEST_TARGETS})
+    add_executable(${target} ${CMAKE_CURRENT_SOURCE_DIR}/${target}.cpp)
+    target_compile_features(${target} PUBLIC ${COMMON_COMPILE_FEATURES})
+    target_compile_options(${target} PUBLIC ${COMMON_COMPILE_OPTIONS})
+    target_link_libraries(${target} ${COMMON_LINK_LIBRARIES})
+endforeach ()
\ No newline at end of file
diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp
new file mode 100644
index 0000000000..48a39d8b45
--- /dev/null
+++ b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp
@@ -0,0 +1,381 @@
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h"
+
+#include <fmt/format.h>
+#include <gtest/gtest.h>
+
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <thread>
+
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h"
+
+namespace kv_mem {
+static constexpr int DIMENSION = 128;
+size_t BLOCK_SIZE = FixedBlockPool::calculate_block_size<float>(DIMENSION);
+size_t BLOCK_ALIGNMENT = FixedBlockPool::calculate_block_alignment<float>();
+
+TEST(FeatureEvictTest, CounterBasedEviction) {
+  static constexpr int NUM_SHARDS = 8;
+  auto executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(4);
+  auto kv_store_ = std::make_unique<SynchronizedShardedMap<int64_t, float*>>(
+      NUM_SHARDS, BLOCK_SIZE, BLOCK_ALIGNMENT);
+
+  // Insert test data
+  for (int i = 0; i < 1000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::set_count(block, 1); // Initial score
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  for (int i = 1000; i < 2000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::set_count(block, 2); // Initial score
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  std::unique_ptr<FeatureEvict<float>> feature_evict;
+  int evict_trigger_mode = 2;
+  int evict_trigger_strategy = 1;
+  uint32_t count_threshold = 1;
+  float count_decay_rate = 0.5;
+  // feature evict config
+  FeatureEvictConfig feature_evict_config;
+  feature_evict_config.trigger_mode = static_cast<EvictTriggerMode>(evict_trigger_mode);
+  feature_evict_config.trigger_strategy = static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
+  feature_evict_config.count_threshold = count_threshold;
+  feature_evict_config.count_decay_rate = count_decay_rate;
+
+  if (feature_evict_config.trigger_mode != EvictTriggerMode::DISABLED) {
+    feature_evict = create_feature_evict(
+        feature_evict_config, executor_.get(), *kv_store_.get(), 4);
+  }
+
+  // Initial validation
+  size_t total_blocks = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    total_blocks += rlock->size();
+  }
+  ASSERT_EQ(total_blocks, 2000);
+
+  // Perform eviction
+  feature_evict->trigger_evict();
+
+  // Validate eviction process
+  while (feature_evict->is_evicting()) {
+    feature_evict->resume();
+    std::this_thread::sleep_for(std::chrono::microseconds(5));
+    feature_evict->pause();
+  }
+
+  // Validate results
+  size_t remaining = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    remaining += rlock->size();
+    // Validate score decay
+    for (const auto& [key, block] : *rlock) {
+      ASSERT_EQ(FixedBlockPool::get_count(block), 1);
+    }
+  }
+  std::cout << "remaining: " << remaining << std::endl;
+  ASSERT_EQ(remaining, 1000);
+}
+
+TEST(FeatureEvictTest, TimeBasedEviction) {
+  static constexpr int NUM_SHARDS = 8;
+  auto executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(4);
+  auto kv_store_ = std::make_unique<SynchronizedShardedMap<int64_t, float*>>(
+      NUM_SHARDS, BLOCK_SIZE, BLOCK_ALIGNMENT);
+
+  // Insert test data
+  for (int i = 0; i < 1000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::update_timestamp(block); // Initial timestamp
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+  std::this_thread::sleep_for(std::chrono::seconds(5));
+
+  for (int i = 1000; i < 2000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::update_timestamp(block); // Initial timestamp
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  std::unique_ptr<FeatureEvict<float>> feature_evict;
+  int evict_trigger_mode = 2;
+  int evict_trigger_strategy = 0;
+  uint32_t ttl = 4;
+  // feature evict config
+  FeatureEvictConfig feature_evict_config;
+  feature_evict_config.trigger_mode = static_cast<EvictTriggerMode>(evict_trigger_mode);
+  feature_evict_config.trigger_strategy = static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
+  feature_evict_config.ttl = ttl;
+
+  if (feature_evict_config.trigger_mode != EvictTriggerMode::DISABLED) {
+    feature_evict = create_feature_evict(
+        feature_evict_config, executor_.get(), *kv_store_.get(), 4);
+  }
+
+  // Initial validation
+  size_t total_blocks = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    total_blocks += rlock->size();
+  }
+  ASSERT_EQ(total_blocks, 2000);
+
+  // Perform eviction
+  feature_evict->trigger_evict();
+
+  // Validate eviction process
+  while (feature_evict->is_evicting()) {
+    feature_evict->resume();
+    std::this_thread::sleep_for(std::chrono::microseconds(5));
+    feature_evict->pause();
+  }
+
+  // Validate results
+  size_t remaining = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    remaining += rlock->size();
+  }
+  std::cout << "remaining: " << remaining << std::endl;
+  ASSERT_EQ(remaining, 1000);
+}
+
+TEST(FeatureEvictTest, TimeCounterBasedEviction) {
+  static constexpr int NUM_SHARDS = 8;
+  auto executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(4);
+  auto kv_store_ = std::make_unique<SynchronizedShardedMap<int64_t, float*>>(
+      NUM_SHARDS, BLOCK_SIZE, BLOCK_ALIGNMENT);
+
+  // Insert test data
+  for (int i = 0; i < 500; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::update_timestamp(block); // Initial timestamp
+    FixedBlockPool::set_count(block, 1);
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+  std::this_thread::sleep_for(std::chrono::seconds(5));
+  for (int i = 500; i < 1000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::update_timestamp(block); // Initial timestamp
+    FixedBlockPool::set_count(block, 1);
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  for (int i = 1000; i < 2000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    FixedBlockPool::set_key(block, i);
+    FixedBlockPool::update_timestamp(block); // Initial timestamp
+    FixedBlockPool::set_count(block, 2);
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  std::unique_ptr<FeatureEvict<float>> feature_evict;
+  int evict_trigger_mode = 2;
+  int evict_trigger_strategy = 2;
+  uint32_t ttl = 4;
+  uint32_t count_threshold = 1;
+  float count_decay_rate = 0.5;
+
+  // feature evict config
+  FeatureEvictConfig feature_evict_config;
+  feature_evict_config.trigger_mode = static_cast<EvictTriggerMode>(evict_trigger_mode);
+  feature_evict_config.trigger_strategy = static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
+  feature_evict_config.ttl = ttl;
+  feature_evict_config.count_threshold = count_threshold;
+  feature_evict_config.count_decay_rate = count_decay_rate;
+
+  if (feature_evict_config.trigger_mode != EvictTriggerMode::DISABLED) {
+    feature_evict = create_feature_evict(
+        feature_evict_config, executor_.get(), *kv_store_.get(), 4);
+  }
+
+  // Initial validation
+  size_t total_blocks = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    total_blocks += rlock->size();
+  }
+  ASSERT_EQ(total_blocks, 2000);
+
+  // Perform eviction
+  feature_evict->trigger_evict();
+
+  // Validate eviction process
+  while (feature_evict->is_evicting()) {
+    feature_evict->resume();
+    std::this_thread::sleep_for(std::chrono::microseconds(5));
+    feature_evict->pause();
+  }
+
+  // Validate results
+  size_t remaining = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    remaining += rlock->size();
+  }
+  std::cout << "remaining: " << remaining << std::endl;
+  ASSERT_EQ(remaining, 1500);
+}
+
+TEST(FeatureEvictTest, L2WeightBasedEviction) {
+  static constexpr int NUM_SHARDS = 8;
+  auto executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(4);
+  auto kv_store_ = std::make_unique<SynchronizedShardedMap<int64_t, float*>>(
+      NUM_SHARDS, BLOCK_SIZE, BLOCK_ALIGNMENT);
+  int dim = 4;
+  std::vector<float> weight1(dim, 1.0);
+  // Insert test data
+  for (int i = 0; i < 1000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    auto* data_ptr = FixedBlockPool::data_ptr(block);
+    FixedBlockPool::set_key(block, i);
+    std::copy(weight1.begin(), weight1.end(), data_ptr);
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+  std::vector<float> weight2(dim, 2.0);
+  for (int i = 1000; i < 2000; ++i) {
+    int shard_id = i % NUM_SHARDS;
+    auto wlock = kv_store_->by(shard_id).wlock();
+    auto* pool = kv_store_->pool_by(shard_id);
+    auto* block = pool->allocate_t<float>();
+    auto* data_ptr = FixedBlockPool::data_ptr(block);
+    FixedBlockPool::set_key(block, i);
+    std::copy(weight2.begin(), weight2.end(), data_ptr);
+    FixedBlockPool::set_used(block, true);
+    wlock->insert({i, block});
+  }
+
+  std::unique_ptr<FeatureEvict<float>> feature_evict;
+  int evict_trigger_mode = 2;
+  int evict_trigger_strategy = 3;
+  double l2_weight_threshold = 3.0;
+  // feature evict config
+  FeatureEvictConfig feature_evict_config;
+  feature_evict_config.trigger_mode = static_cast<EvictTriggerMode>(evict_trigger_mode);
+  feature_evict_config.trigger_strategy = static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
+  feature_evict_config.l2_weight_threshold = l2_weight_threshold;
+
+  if (feature_evict_config.trigger_mode != EvictTriggerMode::DISABLED) {
+    feature_evict = create_feature_evict(
+        feature_evict_config, executor_.get(), *kv_store_.get(), dim);
+  }
+
+  // Initial validation
+  size_t total_blocks = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    total_blocks += rlock->size();
+  }
+  ASSERT_EQ(total_blocks, 2000);
+
+  // Perform eviction
+  feature_evict->trigger_evict();
+
+  // Validate eviction process
+  while (feature_evict->is_evicting()) {
+    feature_evict->resume();
+    std::this_thread::sleep_for(std::chrono::microseconds(5));
+    feature_evict->pause();
+  }
+
+  // Validate results
+  size_t remaining = 0;
+  for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+    auto rlock = kv_store_->by(shard_id).rlock();
+    remaining += rlock->size();
+  }
+  std::cout << "remaining: " << remaining << std::endl;
+  ASSERT_EQ(remaining, 1000);
+}
+
+TEST(FeatureEvictTest, PerformanceTest) {
+  static constexpr int NUM_SHARDS = 1;
+  // Test configurations
+  const std::vector<int> test_sizes = {
+      100'000, 500'000, 1'000'000, 5'000'000, 10'000'000};
+
+  fmt::print("\nPerformance Test Results:\n");
+  fmt::print("{:<15} {:<15} {:<15}\n", "Size", "Time(ms)", "Evict rate");
+  fmt::print("{:-<45}\n", ""); // separator line
+
+  for (const auto& size : test_sizes) {
+    // Create executor and store for each test size
+    auto executor = std::make_unique<folly::CPUThreadPoolExecutor>(8);
+    auto kv_store = std::make_unique<SynchronizedShardedMap<int64_t, float*>>(
+        NUM_SHARDS, BLOCK_SIZE, BLOCK_ALIGNMENT, 1000);
+
+    // Insert test data with different initial scores
+    for (int i = 0; i < size; ++i) {
+      int shard_id = i % NUM_SHARDS;
+      auto wlock = kv_store->by(shard_id).wlock();
+      auto* pool = kv_store->pool_by(shard_id);
+      auto* block = pool->allocate_t<float>();
+      FixedBlockPool::set_key(block, i);
+      FixedBlockPool::set_count(block, (i % 2) ? 1 : 2); // Alternate between scores
+      FixedBlockPool::set_used(block, true);
+      wlock->insert({i, block});
+    }
+
+    // Measure eviction time
+    std::vector<double> execution_times;
+    CounterBasedEvict<float> evictor(executor.get(), *kv_store.get(), 0.5f, 1);
+
+    auto start_time = std::chrono::high_resolution_clock::now();
+
+    // Perform eviction
+    evictor.trigger_evict();
+    evictor.resume();
+    while (evictor.is_evicting()) {
+      std::this_thread::sleep_for(std::chrono::microseconds(1));
+    }
+
+    auto end_time = std::chrono::high_resolution_clock::now();
+    auto duration =
+        std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
+
+    std::size_t current_size = 0;
+    for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) {
+      auto wlock = kv_store->by(shard_id).wlock();
+      current_size += wlock->size();
+    }
+    double eviction_rate =
+        static_cast<double>(size - current_size) / static_cast<double>(size);
+
+    // Print results
+    fmt::print("{:<15d} {:<15d} {:<15.2f}\n", size, duration, eviction_rate);
+  }
+}
+} // namespace kv_mem
\ No newline at end of file
diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_saver_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_saver_test.cpp
new file mode 100644
index 0000000000..44ef79bc0a
--- /dev/null
+++ b/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_saver_test.cpp
@@ -0,0 +1,157 @@
+#include <chrono>
+#include <filesystem>
+#include <random>
+#include <unordered_map>
+
+#include <gtest/gtest.h>
+
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h"
+
+namespace kv_mem {
+void removeFileIfExists(const std::string& filename) {
+  if (std::filesystem::exists(filename)) {
+    std::filesystem::remove(filename);
+  }
+}
+class FixedBlockPoolTest : public ::testing::Test {
+ protected:
+  static constexpr size_t kDimension = 128; // embedding dimension
+  using scalar_t = float; // data type
+
+  void SetUp() override {
+    block_size_ = kv_mem::FixedBlockPool::calculate_block_size<scalar_t>(kDimension);
+    block_alignment_ = kv_mem::FixedBlockPool::calculate_block_alignment<scalar_t>();
+    pool_ = std::make_unique<kv_mem::FixedBlockPool>(block_size_, block_alignment_);
+  }
+
+  // Generate random data
+  void generateRandomData(std::size_t num_blocks) {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_int_distribution<uint64_t> key_dist(1, UINT64_MAX);
+    std::uniform_real_distribution<scalar_t> val_dist(-1.0, 1.0);
+
+    for (size_t i = 0; i < num_blocks; ++i) {
+      auto* block = pool_->allocate_t<scalar_t>();
+      uint64_t key = key_dist(gen);
+
+      // Set metadata
+      kv_mem::FixedBlockPool::set_key(block, key);
+      kv_mem::FixedBlockPool::set_count(block, i % 100);
+      kv_mem::FixedBlockPool::update_timestamp(block);
+
+      // Set embedding data
+      auto* data = kv_mem::FixedBlockPool::data_ptr(block);
+      for (size_t j = 0; j < kDimension; ++j) {
+        data[j] = val_dist(gen);
+      }
+
+      // Record for verification
+      original_data_[key] = std::vector<scalar_t>(data, data + kDimension);
kDimension); + } + } + + // Verify data correctness + bool verifyData() { + size_t verified_count = 0; + + // Traverse all chunks to verify data + for (const auto& chunk : pool_->get_chunks()) { + char* current = static_cast(chunk.ptr); + size_t blocks_in_chunk = chunk.size / block_size_; + + for (size_t i = 0; i < blocks_in_chunk; ++i) { + void* block = current + i * block_size_; + if (kv_mem::FixedBlockPool::get_used(block)) { + uint64_t key = kv_mem::FixedBlockPool::get_key(block); + auto* data = kv_mem::FixedBlockPool::data_ptr(reinterpret_cast(block)); + + // Find and compare original data + auto it = original_data_.find(key); + if (it == original_data_.end()) { + return false; + } + + if (!std::equal(data, data + kDimension, it->second.begin())) { + return false; + } + + verified_count++; + } + } + } + + return verified_count == original_data_.size(); + } + + // Performance test helper function + template + double measureTime(Func&& func) { + auto start = std::chrono::high_resolution_clock::now(); + func(); + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start).count(); + } + + std::unique_ptr pool_; + size_t block_size_{}; + size_t block_alignment_{}; + std::unordered_map> original_data_; +}; + +// Correctness test +TEST_F(FixedBlockPoolTest, SerializationCorrectness) { + // 1. Generate random data + generateRandomData(1000); + + // 2. Serialize + const std::string filename = "test_pool.bin"; + pool_->serialize(filename); + + // 3. Create a new memory pool and deserialize + auto new_pool = std::make_unique(block_size_, block_alignment_); + new_pool->deserialize(filename); + + // 4. Verify data + pool_ = std::move(new_pool); + EXPECT_TRUE(verifyData()); +} + +// Edge case test +TEST_F(FixedBlockPoolTest, SerializationEdgeCases) { + // 1. Empty pool serialization test + const std::string empty_filename = "empty_pool.bin"; + pool_->serialize(empty_filename); + + auto new_pool = std::make_unique(block_size_, block_alignment_); + EXPECT_NO_THROW(new_pool->deserialize(empty_filename)); + + // 2. File not found test + EXPECT_THROW(pool_->deserialize("nonexistent_file.bin"), std::runtime_error); + + // 3. 
+  generateRandomData(1000);
+  const std::string filename = "test_pool.bin";
+  pool_->serialize(filename);
+
+  auto wrong_pool = std::make_unique<kv_mem::FixedBlockPool>(
+      block_size_ * 2, // Incorrect block size
+      block_alignment_);
+  EXPECT_THROW(wrong_pool->deserialize(filename), std::invalid_argument);
+}
+
+// Performance test
+TEST_F(FixedBlockPoolTest, SerializationPerformance) {
+  const std::size_t num_blocks = 20'000'000;
+  generateRandomData(num_blocks);
+  const std::string filename = "test_pool.bin";
+  removeFileIfExists(filename);
+
+  pool_->serialize(filename);
+
+  auto new_pool = std::make_unique<kv_mem::FixedBlockPool>(block_size_, block_alignment_);
+  new_pool->deserialize(filename);
+
+  std::remove(filename.c_str());
+}
+
+} // namespace kv_mem
\ No newline at end of file
diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_test.cpp
new file mode 100644
index 0000000000..f725d29a77
--- /dev/null
+++ b/fbgemm_gpu/test/dram_kv_embedding_cache/fixed_block_pool_test.cpp
@@ -0,0 +1,363 @@
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h"
+
+#include <chrono>
+#include <cstdint>
+#include <cstring>
+#include <memory_resource>
+#include <vector>
+
+#include <fmt/format.h>
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+namespace kv_mem {
+
+double test_std_vector(size_t vector_size, size_t repeat_count) {
+  float sum = 0.0f; // Prevent optimization
+  std::vector<std::vector<float>> all_vectors; // Store all vectors to prevent release
+  all_vectors.reserve(repeat_count);
+
+  auto start = std::chrono::high_resolution_clock::now();
+  for (size_t i = 0; i < repeat_count; ++i) {
+    all_vectors.emplace_back(vector_size);
+    auto& vec = all_vectors.back();
+
+    for (size_t j = 0; j < vector_size; ++j) {
+      vec[j] = static_cast<float>(j);
+    }
+
+    // Simple usage to prevent optimization
+    sum += vec[0];
+  }
+
+  auto end = std::chrono::high_resolution_clock::now();
+  return std::chrono::duration<double, std::milli>(end - start).count();
+}
+
+// Testing memory pool allocation
+double test_pool_vector(size_t vector_size, size_t repeat_count) {
+  // Create a memory pool large enough
+  FixedBlockPool pool(vector_size * sizeof(float), alignof(float), 8092);
+  std::pmr::polymorphic_allocator<float> alloc(&pool);
+
+  auto start = std::chrono::high_resolution_clock::now();
+  float sum = 0.0f; // Prevent optimization
+  for (size_t i = 0; i < repeat_count; ++i) {
+    float* arr = alloc.allocate(vector_size);
+
+    for (size_t j = 0; j < vector_size; ++j) {
+      arr[j] = static_cast<float>(j);
+    }
+
+    // Simple usage to prevent optimization
+    sum += arr[0];
+
+    // Removed deallocate statement, no longer releasing memory to avoid memory
+    // reuse
+    // alloc.deallocate(arr, vector_size);
+  }
+
+  auto end = std::chrono::high_resolution_clock::now();
+  return std::chrono::duration<double, std::milli>(end - start).count();
+}
+
+void benchmark_memory_allocators() {
+  fmt::print(
+      "====== Testing performance difference between memory pool and "
+      "native vector allocation for 10 million times ======\n");
+
+  // Vector sizes to test (in number of float elements)
+  std::vector<size_t> vector_sizes = {4, 8, 16, 32, 64, 128, 256};
+
+  // Repeat count (10 million times)
+  const size_t repeat_count = 10'000'000;
+
+  for (const auto& size : vector_sizes) {
+    fmt::print("Vector size: {} floats ({} bytes)\n", size, size * sizeof(float));
+    // Testing standard vector
+    double std_time = test_std_vector(size, repeat_count);
+    fmt::print("  Standard vector: {:.2f} ms\n", std_time);
+
+    // Testing memory pool
+    double pool_time = test_pool_vector(size, repeat_count);
+    fmt::print("  Memory pool: {:.2f} ms\n", pool_time);
+
+    // Calculate speed improvement
+void benchmark_memory_allocators() {
+  fmt::print(
+      "====== Testing performance difference between memory pool and "
+      "native vector allocation for 10 million times ======\n");
+
+  // Vector sizes to test (in number of float elements)
+  std::vector<size_t> vector_sizes = {4, 8, 16, 32, 64, 128, 256};
+
+  // Repeat count (10 million times)
+  const size_t repeat_count = 10'000'000;
+
+  for (const auto& size : vector_sizes) {
+    fmt::print("Vector size: {} floats ({} bytes)\n", size, size * sizeof(float));
+    // Testing standard vector
+    double std_time = test_std_vector(size, repeat_count);
+    fmt::print("  Standard vector: {:.2f} ms\n", std_time);
+
+    // Testing memory pool
+    double pool_time = test_pool_vector(size, repeat_count);
+    fmt::print("  Memory pool: {:.2f} ms\n", pool_time);
+
+    // Calculate speed improvement
+    double speedup = std_time / pool_time;
+    fmt::print("  Speed improvement: {:.2f}x\n\n", speedup);
+    fmt::print("============================\n");
+  }
+}
+
+// Benchmark: memory pool vs. std::vector allocation
+TEST(FixedBlockPoolTest, benchmark_memory_allocators) { benchmark_memory_allocators(); }
+
+// Test constructor normal case
+TEST(FixedBlockPoolTest, ConstructorNormal) {
+  EXPECT_NO_THROW({ kv_mem::FixedBlockPool pool(16, 8); });
+}
+
+// Test constructor exception cases
+TEST(FixedBlockPoolTest, ConstructorExceptions) {
+  // Block size smaller than pointer size
+  EXPECT_THROW({ kv_mem::FixedBlockPool pool(1, 1); }, std::invalid_argument);
+
+  // Alignment not a power of 2
+  EXPECT_THROW({ kv_mem::FixedBlockPool pool(16, 3); }, std::invalid_argument);
+
+  // Block size not a multiple of alignment
+  EXPECT_THROW({ kv_mem::FixedBlockPool pool(10, 8); }, std::invalid_argument);
+}
+
+// Test basic memory allocation and deallocation
+TEST(FixedBlockPoolTest, BasicAllocation) {
+  const size_t block_size = 16;
+  const size_t alignment = 8;
+  kv_mem::FixedBlockPool pool(block_size, alignment);
+
+  void* p = pool.allocate(block_size, alignment);
+  EXPECT_NE(p, nullptr);
+
+  // Verify allocated memory is usable
+  std::memset(p, 0xAB, block_size);
+
+  pool.deallocate(p, block_size, alignment);
+}
+
+// Test multiple allocations and deallocations
+TEST(FixedBlockPoolTest, MultipleAllocations) {
+  const size_t block_size = 32;
+  const size_t alignment = 8;
+  kv_mem::FixedBlockPool pool(block_size, alignment);
+
+  std::vector<void*> blocks;
+  const int NUM_BLOCKS = 100;
+
+  // Allocate multiple blocks
+  for (int i = 0; i < NUM_BLOCKS; ++i) {
+    void* p = pool.allocate(block_size, alignment);
+    EXPECT_NE(p, nullptr);
+    // Write some data
+    *static_cast<int*>(p) = i;
+    blocks.push_back(p);
+  }
+
+  // Verify data
+  for (int i = 0; i < NUM_BLOCKS; ++i) {
+    EXPECT_EQ(*static_cast<int*>(blocks[i]), i);
+  }
+
+  // Release all blocks
+  for (auto p : blocks) {
+    pool.deallocate(p, block_size, alignment);
+  }
+}
+
+// Test cross-chunk allocation (each chunk has only 10 blocks)
+TEST(FixedBlockPoolTest, CrossChunkAllocation) {
+  const size_t block_size = 16;
+  const size_t alignment = 8;
+  const size_t blocks_per_chunk = 10;
+  kv_mem::FixedBlockPool pool(block_size, alignment, blocks_per_chunk);
+
+  std::vector<void*> blocks;
+  const int NUM_BLOCKS = 25; // Exceeds 2 chunks
+
+  // Allocate blocks beyond a single chunk capacity
+  for (int i = 0; i < NUM_BLOCKS; ++i) {
+    void* p = pool.allocate(block_size, alignment);
+    EXPECT_NE(p, nullptr);
+    blocks.push_back(p);
+  }
+
+  // Release all blocks
+  for (auto p : blocks) {
+    pool.deallocate(p, block_size, alignment);
+  }
+}
+
+// Test memory alignment
+TEST(FixedBlockPoolTest, MemoryAlignment) {
+  const size_t block_size = 64;
+  const size_t alignment = 32;
+  kv_mem::FixedBlockPool pool(block_size, alignment);
+
+  void* p = pool.allocate(block_size, alignment);
+  EXPECT_NE(p, nullptr);
+
+  // Verify address is aligned to specified alignment
+  uintptr_t addr = reinterpret_cast<uintptr_t>(p);
+  EXPECT_EQ(addr % alignment, 0);
+
+  pool.deallocate(p, block_size, alignment);
+}
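// For illustration: the ConstructorExceptions test above (block size smaller
// than a pointer) and the ReuseAfterDeallocation test below both follow from
// one presumed design: freed blocks are threaded onto an intrusive LIFO free
// list, with the "next" pointer stored inside the freed block itself. A
// minimal sketch of that idea (member names hypothetical, not the actual
// FixedBlockPool internals):
struct FreeListSketch {
  void* head = nullptr;
  void push(void* block) { // deallocate: link the block in at the head
    *static_cast<void**>(block) = head;
    head = block;
  }
  void* pop() { // allocate: hand back the most recently freed block
    void* block = head;
    if (block != nullptr) {
      head = *static_cast<void**>(block);
    }
    return block;
  }
};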
+// Test error handling - allocating blocks with mismatched size or alignment
+TEST(FixedBlockPoolTest, ErrorHandling) {
+  const size_t block_size = 16;
+  const size_t alignment = 8;
+  kv_mem::FixedBlockPool pool(block_size, alignment);
+
+  // Try to allocate memory with incorrect size
+  EXPECT_THROW({ [[maybe_unused]] void* p = pool.allocate(block_size * 2, alignment); }, std::bad_alloc);
+
+  // Try to allocate memory with incorrect alignment
+  EXPECT_THROW({ [[maybe_unused]] void* p = pool.allocate(block_size, alignment * 2); }, std::bad_alloc);
+}
+
+// Test memory reuse after deallocation
+TEST(FixedBlockPoolTest, ReuseAfterDeallocation) {
+  const size_t block_size = 16;
+  const size_t alignment = 8;
+  kv_mem::FixedBlockPool pool(block_size, alignment);
+
+  void* p1 = pool.allocate(block_size, alignment);
+  void* p2 = pool.allocate(block_size, alignment);
+
+  // Release the first block
+  pool.deallocate(p1, block_size, alignment);
+
+  // Reallocate, should get the recently freed block (due to LIFO order)
+  void* p3 = pool.allocate(block_size, alignment);
+  EXPECT_EQ(p3, p1);
+
+  // Cleanup
+  pool.deallocate(p2, block_size, alignment);
+  pool.deallocate(p3, block_size, alignment);
+}
+
+// Test custom upstream memory resource
+TEST(FixedBlockPoolTest, CustomUpstreamResource) {
+  const size_t block_size = 16;
+  const size_t alignment = 8;
+
+  // Use custom memory resource that tracks allocations
+  int allocate_count = 0;
+  int deallocate_count = 0;
+
+  class CountingResource : public std::pmr::memory_resource {
+   public:
+    CountingResource(int& alloc_count, int& dealloc_count)
+        : alloc_count_(alloc_count), dealloc_count_(dealloc_count) {}
+
+   protected:
+    void* do_allocate(size_t bytes, size_t alignment) override {
+      ++alloc_count_;
+      return std::pmr::new_delete_resource()->allocate(bytes, alignment);
+    }
+
+    void do_deallocate(void* p, size_t bytes, size_t alignment) override {
+      ++dealloc_count_;
+      std::pmr::new_delete_resource()->deallocate(p, bytes, alignment);
+    }
+
+    bool do_is_equal(const std::pmr::memory_resource& other) const noexcept override {
+      return this == &other;
+    }
+
+   private:
+    int& alloc_count_;
+    int& dealloc_count_;
+  };
+
+  CountingResource upstream(allocate_count, deallocate_count);
+  {
+    kv_mem::FixedBlockPool pool(block_size, alignment, 1024, &upstream);
+
+    // Allocate some blocks to trigger chunk allocation
+    std::vector<void*> blocks;
+    for (int i = 0; i < 10; ++i) {
+      blocks.push_back(pool.allocate(block_size, alignment));
+    }
+
+    // Verify upstream resource was called
+    EXPECT_GT(allocate_count, 0);
+    EXPECT_EQ(deallocate_count, 0);
+
+    // Release all blocks
+    for (auto p : blocks) {
+      pool.deallocate(p, block_size, alignment);
+    }
+  }
+  // Destructor should release all chunks
+  EXPECT_GT(deallocate_count, 0);
+}
+
+TEST(FixedBlockPool, BasicFunctionality) {
+  constexpr int dim = 4;
+  size_t block_size = FixedBlockPool::calculate_block_size(dim);
+  size_t alignment = FixedBlockPool::calculate_block_alignment();
+
+  // Initialize memory pool
+  FixedBlockPool pool(block_size, alignment, 1024);
+
+  // Test memory allocation (null check must precede the first dereference)
+  auto* block = pool.allocate_t<float>();
+  ASSERT_NE(block, nullptr);
+  FixedBlockPool::update_timestamp(block);
+
+  // Verify metadata header: the stored timestamp must not be in the future
+  int64_t ts1 = FixedBlockPool::get_timestamp(block);
+  EXPECT_LE(ts1, FixedBlockPool::current_timestamp());
+
+  // Test data pointer offset
+  float* data = FixedBlockPool::data_ptr(block);
+  ASSERT_EQ(reinterpret_cast<char*>(data) - reinterpret_cast<char*>(block),
+            sizeof(FixedBlockPool::MetaHeader));
+
+  // Test timestamp update
+  FixedBlockPool::update_timestamp(block);
+  int64_t ts2 = FixedBlockPool::get_timestamp(block);
+  EXPECT_GE(ts2, ts1); // New timestamp should be greater or equal
+
+  // Test memory deallocation
+  EXPECT_NO_THROW(pool.deallocate_t(block));
+}
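// For illustration: the assertions above pin down the block layout -- a
// MetaHeader at offset 0, payload immediately after. The field set below is
// inferred from the accessors these tests call (get_key/set_key,
// get_timestamp/update_timestamp, get_used); the real MetaHeader may order
// or pack its members differently.
struct MetaHeaderSketch {
  uint64_t key;      // identity, used to rebuild the map on load
  int64_t timestamp; // last-touch time, used for feature eviction
  bool used;         // distinguishes live blocks from free-list entries
};
// data_ptr() then just skips the header:
//   float* data = reinterpret_cast<float*>(
//       reinterpret_cast<char*>(block) + sizeof(MetaHeader));
// which matches the offset assertion in BasicFunctionality above.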
+TEST(FixedBlockPool, MultiDimensionTest) {
+  // Test memory alignment for different dimensions
+  const std::vector<int> test_dims = {1, 4, 16, 64, 256};
+  for (int dim : test_dims) {
+    size_t block_size = FixedBlockPool::calculate_block_size(dim);
+    size_t alignment = FixedBlockPool::calculate_block_alignment();
+
+    // Verify alignment requirements
+    EXPECT_EQ(alignment % alignof(FixedBlockPool::MetaHeader), 0);
+    EXPECT_EQ(alignment % alignof(float), 0);
+
+    // Verify block size calculation
+    const size_t expected_size = sizeof(FixedBlockPool::MetaHeader) + dim * sizeof(float);
+    EXPECT_EQ(block_size, expected_size);
+  }
+}
+
+TEST(FixedBlockPool, TimestampPrecision) {
+  // Test timestamp precision accuracy
+  constexpr int test_iterations = 1000;
+  int64_t prev_ts = FixedBlockPool::current_timestamp();
+
+  for (int i = 0; i < test_iterations; ++i) {
+    int64_t curr_ts = FixedBlockPool::current_timestamp();
+    EXPECT_GE(curr_ts, prev_ts); // Timestamps should be monotonically increasing
+    prev_ts = curr_ts;
+  }
+}
+
+TEST(FixedBlockPool, DataIntegrity) {
+  // Test data storage integrity
+  constexpr int dim = 8;
+  std::vector<float> src_data(dim, 3.14f);
+
+  size_t block_size = FixedBlockPool::calculate_block_size(dim);
+  size_t alignment = FixedBlockPool::calculate_block_alignment();
+  FixedBlockPool pool(block_size, alignment, 1024);
+
+  // Allocate and write data
+  auto* block = pool.allocate_t<float>();
+  auto* data_ptr = FixedBlockPool::data_ptr(block);
+  std::copy(src_data.begin(), src_data.end(), data_ptr);
+
+  // Verify data consistency
+  for (int i = 0; i < dim; ++i) {
+    EXPECT_FLOAT_EQ(data_ptr[i], src_data[i]);
+  }
+  pool.deallocate_t(block);
+}
+
+} // namespace kv_mem
\ No newline at end of file
diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/sharded_map_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/sharded_map_test.cpp
new file mode 100644
index 0000000000..4445d1d4a2
--- /dev/null
+++ b/fbgemm_gpu/test/dram_kv_embedding_cache/sharded_map_test.cpp
@@ -0,0 +1,229 @@
+#include <fmt/format.h>
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <chrono>
+#include <cstdio>
+#include <memory_resource>
+#include <string>
+#include <vector>
+
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h"
+#include "fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h"
+
+namespace kv_mem {
+
+std::vector<float> generateFixedEmbedding(int dimension) {
+  return std::vector<float>(dimension, 1.0f);
+}
+
+void memPoolEmbedding(int dimension, size_t numInserts, size_t numLookups) {
+  const size_t numShards = 1;
+
+  SynchronizedShardedMap<int64_t, float*> embeddingMap(numShards,
+                                                       dimension * sizeof(float), // block_size
+                                                       alignof(float),            // block_alignment
+                                                       8192);                     // blocks_per_chunk
+  double insertTime, lookupTime;
+  {
+    std::vector<float> fixedEmbedding = generateFixedEmbedding(dimension);
+
+    auto wlock = embeddingMap.by(0).wlock();
+    auto* pool = embeddingMap.pool_by(0);
+    std::pmr::polymorphic_allocator<float> alloc(pool);
+
+    auto startInsert = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < numInserts; i++) {
+      float* arr = alloc.allocate(dimension);
+      std::copy(fixedEmbedding.begin(), fixedEmbedding.end(), arr);
+      wlock->insert_or_assign(i, arr);
+    }
+    auto endInsert = std::chrono::high_resolution_clock::now();
+    insertTime = std::chrono::duration<double, std::milli>(endInsert - startInsert).count();
+  }
+
+  std::vector<float> lookEmbedding(dimension);
+  size_t hitCount = 0;
+  {
+    auto rlock = embeddingMap.by(0).rlock();
+    auto startLookup = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < numLookups; i++) {
+      auto it = rlock->find(i % numInserts);
+      if (it != rlock->end()) {
+        hitCount++;
+        std::copy(it->second, it->second + dimension, lookEmbedding.data());
+      }
+    }
+    auto endLookup = std::chrono::high_resolution_clock::now();
+    lookupTime = std::chrono::duration<double, std::milli>(endLookup - startLookup).count();
+  }
+
+  fmt::print("{:<20}{:<20.2f}{:<20.2f}{:<20.2f}\n",
+             dimension,
+             insertTime,
+             lookupTime,
+             100.0 * static_cast<double>(hitCount) / static_cast<double>(numLookups));
+}
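// For illustration: the benchmarks here pin everything to shard 0. With
// several shards, the lock and the pool must be taken from the same shard
// index so that a block is always freed by the pool that allocated it. A
// minimal sketch of that pairing, choosing the shard by key modulo as
// save_and_restore() below does (helper name hypothetical):
void insert_sharded(SynchronizedShardedMap<int64_t, float*>& map,
                    int64_t key,
                    const std::vector<float>& embedding) {
  const int shard_id = static_cast<int>(key % map.getNumShards());
  auto wlock = map.by(shard_id).wlock(); // this shard's map, write-locked
  auto* pool = map.pool_by(shard_id);    // the matching shard's pool
  auto* block = pool->allocate_t<float>();
  std::copy(embedding.begin(), embedding.end(), FixedBlockPool::data_ptr(block));
  FixedBlockPool::set_key(block, key);
  wlock->insert_or_assign(key, block);
}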
+void memPoolEmbeddingWithTime(int dimension, size_t numInserts, size_t numLookups) {
+  const size_t numShards = 1;
+  size_t block_size = FixedBlockPool::calculate_block_size(dimension);
+  size_t block_alignment = FixedBlockPool::calculate_block_alignment();
+
+  SynchronizedShardedMap<int64_t, float*> embeddingMap(numShards,
+                                                       block_size,      // block_size
+                                                       block_alignment, // block_alignment
+                                                       8192);           // blocks_per_chunk
+  double insertTime, lookupTime;
+  {
+    std::vector<float> fixedEmbedding = generateFixedEmbedding(dimension);
+
+    auto wlock = embeddingMap.by(0).wlock();
+    auto* pool = embeddingMap.pool_by(0);
+
+    auto startInsert = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < numInserts; i++) {
+      auto* block = pool->allocate_t<float>();
+      auto* data_ptr = FixedBlockPool::data_ptr(block);
+      std::copy(fixedEmbedding.begin(), fixedEmbedding.end(), data_ptr);
+      wlock->insert_or_assign(i, block);
+    }
+    auto endInsert = std::chrono::high_resolution_clock::now();
+    insertTime = std::chrono::duration<double, std::milli>(endInsert - startInsert).count();
+  }
+
+  std::vector<float> lookEmbedding(dimension);
+  size_t hitCount = 0;
+  {
+    auto rlock = embeddingMap.by(0).rlock();
+    auto startLookup = std::chrono::high_resolution_clock::now();
+    for (size_t i = 0; i < numLookups; i++) {
+      auto it = rlock->find(i % numInserts);
+      if (it != rlock->end()) {
+        hitCount++;
+        const float* data_ptr = FixedBlockPool::data_ptr(it->second);
+        // update timestamp
+        FixedBlockPool::update_timestamp(it->second);
+        std::copy(data_ptr, data_ptr + dimension, lookEmbedding.data());
+      }
+    }
+    auto endLookup = std::chrono::high_resolution_clock::now();
+    lookupTime = std::chrono::duration<double, std::milli>(endLookup - startLookup).count();
+  }
+
+  // Print results
+  fmt::print("{:<20}{:<20.2f}{:<20.2f}{:<20.2f}\n",
+             dimension,
+             insertTime,
+             lookupTime,
+             100.0 * static_cast<double>(hitCount) / static_cast<double>(numLookups));
+}
+
+void memPoolEmbeddingMemSize(int dimension, size_t numInserts) {
+  const size_t numShards = 4;
+  size_t block_size = FixedBlockPool::calculate_block_size(dimension);
+  size_t block_alignment = FixedBlockPool::calculate_block_alignment();
+
+  SynchronizedShardedMap<int64_t, float*> embeddingMap(numShards,
+                                                       block_size,      // block_size
+                                                       block_alignment, // block_alignment
+                                                       8192);           // blocks_per_chunk
+  {
+    std::vector<float> fixedEmbedding = generateFixedEmbedding(dimension);
+
+    auto wlock = embeddingMap.by(0).wlock();
+    auto* pool = embeddingMap.pool_by(0);
+
+    for (size_t i = 0; i < numInserts; i++) {
+      auto* block = pool->allocate_t<float>();
+      auto* data_ptr = FixedBlockPool::data_ptr(block);
+      std::copy(fixedEmbedding.begin(), fixedEmbedding.end(), data_ptr);
+      wlock->insert_or_assign(i, block);
+    }
+  }
+  size_t totalMemory = embeddingMap.getUsedMemSize();
+  fmt::print("{:<20}{:<20}{:<20.2f}\n",
+             dimension,
+             numInserts,
+             static_cast<double>(totalMemory) / (1024 * 1024)); // MB
+}
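// For illustration, the MB column printed above can be sanity-checked by
// hand. getUsedMemSize() charges each live entry one key, one value pointer,
// and one aligned block. Assuming (illustratively) a 24-byte MetaHeader and
// no extra alignment padding, for dim = 64 and 1'000'000 inserts:
//   block = 24 + 64 * sizeof(float)                   = 24 + 256 = 280 bytes
//   entry = sizeof(int64_t) + sizeof(float*) + block  = 8 + 8 + 280 = 296 bytes
//   total = 296 * 1'000'000 bytes / 2^20             ~= 282.3 as printed
// The exact figure depends on the real header size and alignment padding.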
fmt::print("======================= mempool with time ====================================\n"); + fmt::print("{:<20}{:<20}{:<20}{:<20}\n", "dim", "insert time (ms)", "find time (ms)", "hit rate (%)"); + for (int dim : dimensions) { + memPoolEmbeddingWithTime(dim, numInserts, numLookups); + } + fmt::print("\n\n"); + + fmt::print("======================= memory usage statistics ====================================\n"); + fmt::print("{:<20}{:<20}{:<20}\n","dim", "numInserts", "total memory (MB)"); + for (int dim : dimensions) { + memPoolEmbeddingMemSize(dim, numInserts); + } + return 0; +} + +void save_and_restore() { + const int numShards = 4; + const std::size_t dimension = 32; + const std::size_t block_size = FixedBlockPool::calculate_block_size(dimension); + const std::size_t block_alignment = FixedBlockPool::calculate_block_alignment(); + const int numItems = 1'000'000; + const std::string filename = "test_map.bin"; + + SynchronizedShardedMap original_map(numShards, block_size, block_alignment); + + std::vector test_embedding = generateFixedEmbedding(dimension); + for (int i = 0; i < numItems; ++i) { + int shard_id = i % numShards; + auto wlock = original_map.by(shard_id).wlock(); + auto* pool = original_map.pool_by(shard_id); + + auto* block = pool->allocate_t(); + auto* data_ptr = FixedBlockPool::data_ptr(block); + std::copy(test_embedding.begin(), test_embedding.end(), data_ptr); + + FixedBlockPool::set_key(block, i); + wlock->insert({i, block}); + } + + original_map.save(filename); + + SynchronizedShardedMap restored_map(numShards, block_size, block_alignment); + restored_map.load(filename); + + for (int64_t i = 0; i < numItems; ++i) { + int shard_id = i % numShards; + auto rlock = restored_map.by(shard_id).rlock(); + + auto it = rlock->find(i); + ASSERT_NE(it, rlock->end()) << "Key " << i << " not found after load"; + + float* block = it->second; + ASSERT_EQ(FixedBlockPool::get_key(block), i); + + const float* data_ptr = FixedBlockPool::data_ptr(block); + for (std::size_t j = 0; j < dimension; ++j) { + ASSERT_FLOAT_EQ(data_ptr[j], test_embedding[j]) << "Data mismatch at position " << j << " for key " << i; + } + } + + std::remove(filename.c_str()); + for (int i = 0; i < numShards; ++i) { + std::remove((filename + ".pool." + std::to_string(i)).c_str()); + } +}; + +TEST(SynchronizedShardedMap, save_and_restore) { save_and_restore(); } +TEST(SynchronizedShardedMap, benchmark) { benchmark(); } + +} // namespace kv_mem \ No newline at end of file diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index fdd9af4ebd..6ea7dd12aa 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -560,7 +560,7 @@ class EmbeddingQuantizeFixedNumberTest : public testing::TestWithParam { 1, 1, 1, 1, // All the same. Range: 0, min: 1 -64, -2.75, 61.625, 191, // Range: 255, min: -64. Picking 61.625 because it differs under FP16 (will become 61.5). }; - assert(float_test_input.size() == row * col); + assert(float_test_input.size() == static_cast(row * col)); float16_test_input.resize(float_test_input.size()); std::transform(