Skip to content

Commit

Permalink
Fix clang-tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
Beihao-Zhou committed Jul 7, 2024
1 parent c6ad9f4 commit 5e2efc1
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 209 deletions.
154 changes: 78 additions & 76 deletions src/search/hnsw_indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

namespace redis {

Node::Node(const NodeKey& key, uint16_t level) : key(key), level(level) {}
Node::Node(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {}

StatusOr<HnswNodeFieldMetadata> Node::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) {
StatusOr<HnswNodeFieldMetadata> Node::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const {
auto node_index_key = search_key.ConstructHnswNode(level, key);
rocksdb::PinnableSlice value;
auto s = storage->Get(rocksdb::ReadOptions(), storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value);
Expand All @@ -49,7 +49,7 @@ StatusOr<HnswNodeFieldMetadata> Node::DecodeMetadata(const SearchKey& search_key
}

void Node::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
std::string updated_metadata;
node_meta->Encode(&updated_metadata);
batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), search_key.ConstructHnswNode(level, key), updated_metadata);
Expand Down Expand Up @@ -96,10 +96,10 @@ Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& sear
return Status::OK();
}

VectorItem::VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}
VectorItem::VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}
VectorItem::VectorItem(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(vector), metadata(metadata) {}
VectorItem::VectorItem(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(std::move(key)), vector(std::move(vector)), metadata(metadata) {}

bool VectorItem::operator==(const VectorItem& other) const { return key == other.key; }

Expand Down Expand Up @@ -146,23 +146,25 @@ StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& rig
}

HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage)
: search_key_(search_key), metadata_(vector), storage_(storage) {
m_level_normalization_factor_ = 1.0 / std::log(metadata_->m);
: search_key(search_key),
metadata(vector),
storage(storage),
m_level_normalization_factor(1.0 / std::log(metadata->m)) {
std::random_device rand_dev;
generator_ = std::mt19937(rand_dev());
generator = std::mt19937(rand_dev());
}

uint16_t HnswIndex::RandomizeLayer() {
std::uniform_real_distribution<double> level_dist(0.0, 1.0);
double r = level_dist(generator_);
double r = level_dist(generator);
double log_val = -std::log(r);
double layer_val = log_val * m_level_normalization_factor_;
double layer_val = log_val * m_level_normalization_factor;
return static_cast<uint16_t>(std::floor(layer_val));
}

StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) {
auto prefix = search_key_.ConstructHnswLevelNodePrefix(level);
util::UniqueIterator it(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search);
StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) const {
auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
util::UniqueIterator it(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
it->Seek(prefix);

Slice node_key;
Expand Down Expand Up @@ -193,51 +195,52 @@ StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const std:
}

Status HnswIndex::AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key1, Slice());
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key1, Slice());
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
}

auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key2, Slice());
auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key2, Slice());
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
}
return Status::OK();
}

Status HnswIndex::RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key1);
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key1);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}

auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key2);
auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key2);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}
return Status::OK();
}

StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem& vec,
const std::vector<VectorItem>& vertors, uint16_t layer) {
const std::vector<VectorItem>& vertors,
uint16_t layer) const {
std::vector<std::pair<double, VectorItem>> distances;
distances.reserve(vertors.size());
for (const auto& candidate : vertors) {
auto dist = GET_OR_RET(ComputeSimilarity(vec, candidate));
distances.push_back({dist, candidate});
distances.emplace_back(dist, candidate);
}

std::sort(distances.begin(), distances.end());
std::vector<VectorItem> selected_vs;

selected_vs.reserve(vertors.size());
uint16_t m_max = layer != 0 ? metadata_->m : 2 * metadata_->m;
uint16_t m_max = layer != 0 ? metadata->m : 2 * metadata->m;
for (auto i = 0; i < std::min(m_max, (uint16_t)distances.size()); i++) {
selected_vs.push_back(distances[i].second);
}
Expand All @@ -246,7 +249,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem& v

StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const VectorItem& target_vector,
uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points) {
const std::vector<NodeKey>& entry_points) const {
std::vector<VectorItem> candidates;
std::unordered_set<NodeKey> visited;
std::priority_queue<std::pair<double, VectorItem>, std::vector<std::pair<double, VectorItem>>, std::greater<>>
Expand All @@ -255,9 +258,9 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V

for (const auto& entry_point_key : entry_points) {
Node entry_node = Node(entry_point_key, level);
auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key_, storage_));
auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key, storage));

auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata_);
auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata);
auto dist = GET_OR_RET(ComputeSimilarity(target_vector, entry_point_vector));

explore_heap.push(std::make_pair(dist, entry_point_vector));
Expand All @@ -273,7 +276,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}

auto current_node = Node(current_vector.key, level);
current_node.DecodeNeighbours(search_key_, storage_);
current_node.DecodeNeighbours(search_key, storage);

for (const auto& neighbour_key : current_node.neighbours) {
if (visited.find(neighbour_key) != visited.end()) {
Expand All @@ -282,8 +285,8 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
visited.insert(neighbour_key);

auto neighbour_node = Node(neighbour_key, level);
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_));
auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata_);
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata);

auto dist = GET_OR_RET(ComputeSimilarity(target_vector, neighbour_node_vector));
explore_heap.push(std::make_pair(dist, neighbour_node_vector));
Expand All @@ -306,27 +309,26 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
uint16_t target_level) {
auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search);
auto inserted_vector_item = VectorItem(std::string(key), vector, metadata_);
auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
auto inserted_vector_item = VectorItem(std::string(key), vector, metadata);
std::vector<VectorItem> nearest_vec_items;

if (metadata_->num_levels != 0) {
auto level = metadata_->num_levels - 1;
if (metadata->num_levels != 0) {
auto level = metadata->num_levels - 1;

auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
std::vector<NodeKey> entry_points{default_entry_node};

for (; level > target_level; level--) {
nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata_->ef_runtime, entry_points));
nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_runtime, entry_points));
entry_points = {nearest_vec_items[0].key};
}

for (; level >= 0; level--) {
nearest_vec_items =
GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata_->ef_construction, entry_points));
nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_construction, entry_points));
auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
auto node = Node(std::string(key), level);
auto m_max = level == 0 ? 2 * metadata_->m : metadata_->m;
auto m_max = level == 0 ? 2 * metadata->m : metadata->m;

std::unordered_set<NodeKey> connected_edges_set;
std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;
Expand All @@ -348,7 +350,7 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericA

for (const auto& candidate_vec : candidate_vec_items) {
auto candidate_node = Node(candidate_vec.key, level);
auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key_, storage_));
auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key, storage));
uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;

if (has_room_for_more_edges(candidate_node_num_neighbours) ||
Expand All @@ -359,9 +361,9 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericA
}

// Re-evaluate the neighbours for the candidate node
candidate_node.DecodeNeighbours(search_key_, storage_);
candidate_node.DecodeNeighbours(search_key, storage);
auto candidate_node_neighbour_vec_items =
GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key_, storage_, metadata_));
GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key, storage, metadata));
candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
auto sorted_neighbours = GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));

Expand Down Expand Up @@ -392,51 +394,51 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericA

// Update inserted node metadata
HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
node.PutMetadata(&node_metadata, search_key_, storage_, batch);
node.PutMetadata(&node_metadata, search_key, storage, batch);

// Update modified nodes metadata
for (const auto& node_edges : deleted_edges_map) {
auto& current_node_key = node_edges.first;
auto current_node = Node(current_node_key, level);
auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key_, storage_));
auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
if (connected_edges_set.count(current_node_key) != 0) {
new_num_neighbours++;
connected_edges_set.erase(current_node_key);
}
current_node_metadata.num_neighbours = new_num_neighbours;
current_node.PutMetadata(&current_node_metadata, search_key_, storage_, batch);
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch);
}

for (const auto& current_node_key : connected_edges_set) {
auto current_node = Node(current_node_key, level);
HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key_, storage_));
HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
current_node_metadata.num_neighbours++;
current_node.PutMetadata(&current_node_metadata, search_key_, storage_, batch);
current_node.PutMetadata(&current_node_metadata, search_key, storage, batch);
}

entry_points.clear();
for (const auto& new_entry_point : nearest_vec_items) {
entry_points.push_back(std::move(new_entry_point.key));
entry_points.push_back(new_entry_point.key);
}
}
} else {
auto node = Node(std::string(key), 0);
HnswNodeFieldMetadata node_metadata(0, vector);
node.PutMetadata(&node_metadata, search_key_, storage_, batch);
metadata_->num_levels = 1;
node.PutMetadata(&node_metadata, search_key, storage, batch);
metadata->num_levels = 1;
}

while (target_level > metadata_->num_levels - 1) {
auto node = Node(std::string(key), metadata_->num_levels);
while (target_level > metadata->num_levels - 1) {
auto node = Node(std::string(key), metadata->num_levels);
HnswNodeFieldMetadata node_metadata(0, vector);
node.PutMetadata(&node_metadata, search_key_, storage_, batch);
metadata_->num_levels++;
node.PutMetadata(&node_metadata, search_key, storage, batch);
metadata->num_levels++;
}

std::string encoded_index_metadata;
metadata_->Encode(&encoded_index_metadata);
auto index_meta_key = search_key_.ConstructFieldMeta();
metadata->Encode(&encoded_index_metadata);
auto index_meta_key = search_key.ConstructFieldMeta();
batch->Put(cf_handle, index_meta_key, encoded_index_metadata);

return Status::OK();
Expand All @@ -445,38 +447,38 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericA
Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto target_level = RandomizeLayer();
return InsertVectorEntryInternal(key, vector, batch, target_level);
return InsertVectorEntryInternal(key, std::move(vector), batch, target_level);
}

Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
std::string node_key(key);
for (uint16_t level = 0; level < metadata_->num_levels; level++) {
for (uint16_t level = 0; level < metadata->num_levels; level++) {
auto node = Node(node_key, level);
auto node_metadata_status = node.DecodeMetadata(search_key_, storage_);
auto node_metadata_status = node.DecodeMetadata(search_key, storage);
if (!node_metadata_status.IsOK()) {
break;
}

auto node_metadata = std::move(node_metadata_status).GetValue();
auto node_index_key = search_key_.ConstructHnswNode(level, key);
auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), node_index_key);
auto node_index_key = search_key.ConstructHnswNode(level, key);
auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), node_index_key);
if (!s.ok()) {
return {Status::NotOK, s.ToString()};
}

node.DecodeNeighbours(search_key_, storage_);
node.DecodeNeighbours(search_key, storage);
for (const auto& neighbour_key : node.neighbours) {
GET_OR_RET(RemoveEdge(node_key, neighbour_key, level, batch));
auto neighbour_node = Node(neighbour_key, level);
HnswNodeFieldMetadata neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_));
HnswNodeFieldMetadata neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
neighbour_node_metadata.num_neighbours--;
neighbour_node.PutMetadata(&neighbour_node_metadata, search_key_, storage_, batch);
neighbour_node.PutMetadata(&neighbour_node_metadata, search_key, storage, batch);
}
}

auto hasOtherNodesAtLevel = [&](uint16_t level, std::string_view skip_key) -> bool {
auto prefix = search_key_.ConstructHnswLevelNodePrefix(level);
util::UniqueIterator it(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search);
auto has_other_nodes_at_level = [&](uint16_t level, std::string_view skip_key) -> bool {
auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
util::UniqueIterator it(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
it->Seek(prefix);

Slice node_key;
Expand All @@ -494,17 +496,17 @@ Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<ro
return false;
};

while (metadata_->num_levels > 0) {
if (hasOtherNodesAtLevel(metadata_->num_levels - 1, key)) {
while (metadata->num_levels > 0) {
if (has_other_nodes_at_level(metadata->num_levels - 1, key)) {
break;
}
metadata_->num_levels--;
metadata->num_levels--;
}

std::string encoded_index_metadata;
metadata_->Encode(&encoded_index_metadata);
auto index_meta_key = search_key_.ConstructFieldMeta();
batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), index_meta_key, encoded_index_metadata);
metadata->Encode(&encoded_index_metadata);
auto index_meta_key = search_key.ConstructFieldMeta();
batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, encoded_index_metadata);

return Status::OK();
}
Expand Down
Loading

0 comments on commit 5e2efc1

Please sign in to comment.