85 changes: 78 additions & 7 deletions fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
@@ -8,8 +8,10 @@

#pragma once

#include <cstddef>
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <folly/Synchronized.h>
#include <folly/container/F14Map.h>

#include "fixed_block_pool.h"

namespace kv_mem {

@@ -29,18 +31,87 @@ class SynchronizedShardedMap {
public:
using iterator = typename folly::F14FastMap<K, V>::const_iterator;

explicit SynchronizedShardedMap(std::size_t numShards,
std::size_t block_size,
std::size_t block_alignment,
std::size_t blocks_per_chunk = 8192)
: shards_(numShards), mempools_(numShards) {
// Initialize one fixed-block pool per shard
for (auto& pool : mempools_) {
pool = std::make_unique<kv_mem::FixedBlockPool>(
block_size, block_alignment, blocks_per_chunk);
}
}

// Get shard map by index
auto& by(int index) { return shards_.at(index % shards_.size()); }

// Get shard pool by index
auto* pool_by(int index) {
return mempools_.at(index % shards_.size()).get();
}

auto getNumShards() const { return shards_.size(); }

auto getUsedMemSize() const {
size_t used_mem_size = 0;
size_t block_size = mempools_[0]->get_aligned_block_size();
for (size_t i = 0; i < shards_.size(); ++i) {
auto rlmap = shards_[i].rlock();
// only count the K, V, and block storage of entries actually in use
used_mem_size += rlmap->size() * (sizeof(K) + sizeof(V) + block_size);
}
return used_mem_size;
}

void save(const std::string& filename) const {
std::ofstream out(filename, std::ios::binary);
if (!out) {
throw std::runtime_error("Failed to open file for writing");
}

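// On-disk layout: `filename` holds only the shard count; each shard's
// mempool is serialized separately to `filename + ".pool." + shard_id`.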
const std::size_t num_shards = getNumShards();
out.write(reinterpret_cast<const char*>(&num_shards), sizeof(num_shards));
out.close();

// save every mempool
for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
std::string pool_filename = filename + ".pool." + std::to_string(shard_id);
auto rlock = shards_[shard_id].rlock(); // hold the shard lock while serializing
mempools_[shard_id]->serialize(pool_filename);
}
}

void load(const std::string& filename) {
std::ifstream in(filename, std::ios::binary);
if (!in) {
throw std::runtime_error("Failed to open file for reading");
}

size_t num_shards;
in.read(reinterpret_cast<char*>(&num_shards), sizeof(num_shards));
in.close();

if (num_shards != getNumShards()) {
throw std::runtime_error("Shard count mismatch between file and map");
}

for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
std::string pool_filename = filename + ".pool." + std::to_string(shard_id);
auto wlock = shards_[shard_id].wlock();
// first deserialize mempool
mempools_[shard_id]->deserialize(pool_filename);
// load map from mempool
wlock->clear();
mempools_[shard_id]->for_each_block([&wlock](void* block) {
auto key = FixedBlockPool::get_key(block);
wlock->emplace(key, reinterpret_cast<V>(block));
});
}
}

private:
std::vector<folly::Synchronized<folly::F14FastMap<K, V>, M>> shards_;
std::vector<std::unique_ptr<FixedBlockPool>> mempools_;
};
} // namespace kv_mem
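Reviewer note: to make the new API concrete, here is a minimal usage sketch of the pool-backed map. It assumes FixedBlockPool's calculate_block_size / calculate_block_alignment / allocate_t work as their uses in this PR suggest (allocate_t returning a typed block pointer); the snippet is illustrative and not part of the change.

```cpp
// Sketch only: allocate a row block from a shard's pool and index it.
#include <cstdint>
#include <iostream>

#include "SynchronizedShardedMap.h"

int main() {
  using kv_mem::FixedBlockPool;
  constexpr std::size_t kShards = 4;
  constexpr int64_t kMaxD = 128;

  kv_mem::SynchronizedShardedMap<int64_t, float*> map(
      kShards,
      FixedBlockPool::calculate_block_size<float>(kMaxD),
      FixedBlockPool::calculate_block_alignment<float>(),
      /*blocks_per_chunk=*/8192);

  const int64_t id = 42;
  const int shard_id = static_cast<int>(id % map.getNumShards());
  {
    // The shard lock must be held across both the pool allocation and the
    // map insert, matching how set() pairs by() with pool_by().
    auto wlock = map.by(shard_id).wlock();
    auto* pool = map.pool_by(shard_id);
    float* block = pool->allocate_t();  // assumed FixedBlockPool API
    wlock->insert({id, block});
  }
  std::cout << "used bytes: " << map.getUsedMemSize() << "\n";
  return 0;
}
```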
100 changes: 79 additions & 21 deletions fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h
@@ -15,7 +15,8 @@

#include "SynchronizedShardedMap.h"
#include "deeplearning/fbgemm/fbgemm_gpu/src/ssd_split_embeddings_cache/initializer.h"
#include "store_value.h"
#include "fixed_block_pool.h"
#include "feature_evict.h"

#include <ATen/core/ivalue.h>
#include <caffe2/torch/fb/distributed/wireSerializer/WireSerializer.h>
@@ -46,6 +47,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
/// @param max_D the maximum dimension of the embedding tensor
/// @param uniform_init_lower the lower bound of the uniform distribution
/// @param uniform_init_upper the upper bound of the uniform distribution
/// @param feature_evict_config configuration for feature eviction (trigger
/// mode, strategy, and thresholds)
/// @param num_shards number of shards for the kvstore. This is to improve
/// parallelization. Each key value pair will be sharded into one shard.
/// @param num_threads num of threads that kvstore needs to be run upon for
@@ -59,6 +61,7 @@
int64_t max_D,
double uniform_init_lower,
double uniform_init_upper,
FeatureEvictConfig feature_evict_config,
int64_t num_shards = 8,
int64_t num_threads = 32,
int64_t row_storage_bitwidth = 32,
@@ -68,10 +71,16 @@
max_D,
0), // l2_cache_size_gb =0 to disable l2 cache
max_D_(max_D),
num_shards_(num_shards),
weight_ttl_in_hours_(weight_ttl_in_hours),
block_size_(FixedBlockPool::calculate_block_size<weight_type>(max_D)),
block_alignment_(FixedBlockPool::calculate_block_alignment<weight_type>()),
kv_store_(SynchronizedShardedMap<int64_t, weight_type*>(
num_shards_,
block_size_,
block_alignment_,
/*blocks_per_chunk=*/8192)),
elem_size_(row_storage_bitwidth / 8),
feature_evict_config_(feature_evict_config) {
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -81,6 +90,9 @@
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth);
if (feature_evict_config_.trigger_mode != EvictTriggerMode::DISABLED) {
feature_evict_ = create_feature_evict(
    feature_evict_config_, executor_.get(), kv_store_, max_D);
}
}

void initialize_initializers(
@@ -185,20 +197,34 @@
CHECK_EQ(indices.size(0), weights.size(0));
{
auto wlmap = kv_store_.by(shard_id).wlock();
auto* pool = kv_store_.pool_by(shard_id);
auto indices_data_ptr = indices.data_ptr<index_t>();
for (auto index_iter = indexes.begin();
index_iter != indexes.end();
index_iter++) {
const auto& id_index = *index_iter;
auto id = int64_t(indices_data_ptr[id_index]);
// use mempool
weight_type* block = nullptr;
// First check if the key already exists
auto it = wlmap->find(id);
if (it != wlmap->end()) {
block = it->second;
} else {
// Key doesn't exist, allocate new block and insert.
block = pool->allocate_t();
wlmap->insert({id, block});
}
if (feature_evict_) {
feature_evict_->update_feature_statistics(block);
}
auto* data_ptr = FixedBlockPool::data_ptr<weight_type>(block);
const auto* src = weights[id_index].template data_ptr<weight_type>();
std::copy(src, src + weights[id_index].numel(), data_ptr);
}
}
});
@@ -276,16 +302,12 @@
row_storage_data_ptr));
continue;
}
// use mempool
const auto* data_ptr =
    FixedBlockPool::data_ptr<weight_type>(cached_iter->second);
std::copy(
    data_ptr,
    data_ptr + max_D_,
    &(weights_data_ptr[id_index * max_D_])); // dst_start
}
}
});
Expand All @@ -307,6 +329,36 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

void compact() override {}

void trigger_feature_evict() {
if (feature_evict_) {
feature_evict_->trigger_evict();
}
}

void feature_evict_resume() {
if (feature_evict_) {
feature_evict_->resume();
}
}

void feature_evict_pause() {
if (feature_evict_) {
feature_evict_->pause();
}
}

void maybe_evict_by_step() {
if (feature_evict_config_.trigger_mode == EvictTriggerMode::ITERATION &&
feature_evict_config_.trigger_step_interval > 0 &&
++current_iter_ % feature_evict_config_.trigger_step_interval == 0) {
trigger_feature_evict();
}
}

size_t get_map_used_memsize() const {
return kv_store_.getUsedMemSize();
}

private:
void fill_from_row_storage(
int shard_id,
@@ -368,10 +420,16 @@
int64_t max_D_;
int64_t num_shards_;
int64_t weight_ttl_in_hours_;
// mempool params
size_t block_size_;
size_t block_alignment_;
SynchronizedShardedMap<int64_t, weight_type*> kv_store_;
std::atomic_bool is_eviction_ongoing_ = false;
std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
int64_t elem_size_;
FeatureEvictConfig feature_evict_config_;
std::unique_ptr<FeatureEvict<weight_type>> feature_evict_;
int64_t current_iter_ = 0;
}; // class DramKVEmbeddingCache

} // namespace kv_mem
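Reviewer note: maybe_evict_by_step() is a plain modulo counter: with EvictTriggerMode::ITERATION, every trigger_step_interval-th call fires one eviction pass. A standalone model of just that arithmetic (all names here are illustrative):

```cpp
// Standalone model of the ITERATION trigger used by maybe_evict_by_step().
#include <cstdint>
#include <iostream>

struct IterationTrigger {
  int64_t interval;         // mirrors trigger_step_interval
  int64_t current_iter{0};  // mirrors current_iter_

  // Returns true when this call should kick off an eviction pass.
  bool step() { return interval > 0 && ++current_iter % interval == 0; }
};

int main() {
  IterationTrigger trigger{/*interval=*/3};
  for (int i = 1; i <= 7; ++i) {
    std::cout << "set() #" << i
              << (trigger.step() ? " -> trigger_feature_evict()" : "")
              << "\n";
  }
  // Fires on calls 3 and 6.
  return 0;
}
```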
fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h
@@ -26,15 +26,34 @@
int64_t max_D,
double uniform_init_lower,
double uniform_init_upper,
int evict_trigger_mode,
int evict_trigger_strategy,
int64_t trigger_step_interval,
uint32_t ttl,
uint32_t count_threshold,
float count_decay_rate,
double l2_weight_threshold,
int64_t num_shards = 8,
int64_t num_threads = 32,
int64_t row_storage_bitwidth = 32,
int64_t weight_ttl_in_hours = 2) {

// feature evict config
FeatureEvictConfig feature_evict_config;
feature_evict_config.trigger_mode =
    static_cast<EvictTriggerMode>(evict_trigger_mode);
feature_evict_config.trigger_strategy =
    static_cast<EvictTriggerStrategy>(evict_trigger_strategy);
feature_evict_config.trigger_step_interval = trigger_step_interval;
feature_evict_config.ttl = ttl;
feature_evict_config.count_threshold = count_threshold;
feature_evict_config.count_decay_rate = count_decay_rate;
feature_evict_config.l2_weight_threshold = l2_weight_threshold;

if (row_storage_bitwidth == 16) {
impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
max_D,
uniform_init_lower,
uniform_init_upper,
feature_evict_config,
num_shards,
num_threads,
row_storage_bitwidth,
@@ -44,6 +63,7 @@
max_D,
uniform_init_lower,
uniform_init_upper,
feature_evict_config,
num_shards,
num_threads,
row_storage_bitwidth,
@@ -67,7 +87,11 @@
}

void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
impl_->feature_evict_pause();
impl_->set(indices, weights, count);
// with EvictTriggerMode::ITERATION, this triggers eviction every N set() calls
impl_->maybe_evict_by_step();
impl_->feature_evict_resume();
}

void flush() {
@@ -86,7 +110,9 @@
at::Tensor weights,
at::Tensor count,
int64_t sleep_ms) {
impl_->feature_evict_pause();
impl_->get(indices, weights, count, sleep_ms);
impl_->feature_evict_resume();
}

void wait_util_filling_work_done() {
@@ -97,6 +123,10 @@
return impl_->get_keys_in_range(start, end);
}

size_t get_map_used_memsize() const {
return impl_->get_map_used_memsize();
}

private:
// friend class EmbeddingRocksDBWrapper;
friend class ssd::KVTensorWrapper;
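Reviewer note: a caller-side sketch of the new eviction-aware flow. set() now brackets the write with feature_evict_pause()/feature_evict_resume() and counts one ITERATION step; get() pauses eviction for the read so blocks cannot be freed mid-copy. The header name, unqualified class name, and tensor shapes below are assumptions for illustration.

```cpp
// Illustrative driver; assumes a wrapper built with max_D = 128, float rows
// (row_storage_bitwidth = 32), and EvictTriggerMode::ITERATION.
#include <ATen/ATen.h>

#include "dram_kv_embedding_cache_wrapper.h" // assumed header name

void write_and_read(DramKVEmbeddingCacheWrapper& cache) {
  const int64_t max_D = 128;
  auto indices = at::tensor({1, 2, 3}, at::kLong);
  auto weights = at::randn({indices.size(0), max_D});
  auto count = at::tensor({indices.size(0)}, at::kLong);

  // Pauses eviction, writes the rows, counts one ITERATION step, resumes.
  cache.set(indices, weights, count);

  // Reads back into a fresh buffer; eviction stays paused during the copy.
  auto out = at::empty_like(weights);
  cache.get(indices, out, count, /*sleep_ms=*/0);
}
```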