Commit: Add DeleteVectorEntry

Beihao-Zhou committed Jul 7, 2024
1 parent 7bbcd93 commit 8934d87
Showing 5 changed files with 224 additions and 14 deletions.
73 changes: 67 additions & 6 deletions src/search/hnsw_indexer.cc
@@ -210,14 +210,14 @@ Status HnswIndex::AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,

Status HnswIndex::RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
-  auto edge_index_key = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
-  auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
+  auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
+  auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key1);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}

-  edge_index_key = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
-  s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
+  auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
+  s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key2);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}
@@ -331,12 +331,12 @@ Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericArray vector,
std::unordered_set<NodeKey> connected_edges_set;
std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;

-  // Check against if candidate node has room for more outgoing edges
+  // Check if candidate node has room for more outgoing edges
auto has_room_for_more_edges = [&](int candidate_node_num_neighbours) {
return candidate_node_num_neighbours < m_max;
};

-  // Check against if candidate node has room after some other nodes' are pruned in current batch
+  // Check if candidate node has room after some other nodes' edges are pruned in the current batch
auto has_room_after_deletions = [&](const Node& candidate_node, int candidate_node_num_neighbours) {
auto it = deleted_edges_map.find(candidate_node.key);
if (it != deleted_edges_map.end()) {
@@ -448,4 +448,65 @@ Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
return InsertVectorEntryInternal(key, vector, batch, target_level);
}

Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
std::string node_key(key);
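  // Walk up from the base layer and remove this node from every level it occupies.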
for (uint16_t level = 0; level < metadata_->num_levels; level++) {
auto node = Node(node_key, level);
auto node_metadata_status = node.DecodeMetadata(search_key_, storage_);
if (!node_metadata_status.IsOK()) {
break;
}

auto node_metadata = std::move(node_metadata_status).GetValue();
auto node_index_key = search_key_.ConstructHnswNode(level, key);
auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), node_index_key);
if (!s.ok()) {
return {Status::NotOK, s.ToString()};
}

node.DecodeNeighbours(search_key_, storage_);
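    // For each neighbour, drop both directions of the edge and decrement its edge count.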
for (const auto& neighbour_key : node.neighbours) {
GET_OR_RET(RemoveEdge(node_key, neighbour_key, level, batch));
auto neighbour_node = Node(neighbour_key, level);
HnswNodeFieldMetadata neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_));
neighbour_node_metadata.num_neighbours--;
neighbour_node.PutMetadata(&neighbour_node_metadata, search_key_, storage_, batch);
}
}

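  // Returns true if any node other than `skip_key` still remains at the given level.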
auto hasOtherNodesAtLevel = [&](uint16_t level, std::string_view skip_key) -> bool {
auto prefix = search_key_.ConstructHnswLevelNodePrefix(level);
util::UniqueIterator it(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search);
it->Seek(prefix);

    Slice node_key;
    Slice node_key_dst;
    while (it->Valid() && it->key().starts_with(prefix)) {
      // Slice the key suffix in place; wrapping a temporary std::string in a
      // Slice would leave it dangling.
      node_key = it->key();
      node_key.remove_prefix(prefix.size());
      if (!GetSizedString(&node_key, &node_key_dst)) {
        // Advance even on a malformed key so the scan cannot loop forever.
        it->Next();
        continue;
      }
      if (node_key_dst.ToString() != skip_key) {
        return true;
      }
      it->Next();
    }
return false;
};

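  // Shrink the index from the top: drop any level left empty by this deletion.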
while (metadata_->num_levels > 0) {
if (hasOtherNodesAtLevel(metadata_->num_levels - 1, key)) {
break;
}
metadata_->num_levels--;
}

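  // Persist the updated level count in the index metadata.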
std::string encoded_index_metadata;
metadata_->Encode(&encoded_index_metadata);
auto index_meta_key = search_key_.ConstructFieldMeta();
batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), index_meta_key, encoded_index_metadata);

return Status::OK();
}

} // namespace redis
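
A minimal sketch of driving the new API, mirroring the call sites in indexer.cc and the test below; `storage` and `hnsw` are assumed to name an initialized engine::Storage* and HnswIndex:

  auto batch = storage->GetWriteBatchBase();
  GET_OR_RET(hnsw.DeleteVectorEntry("n2", batch));  // stage all deletes into one batch
  auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch());  // commit atomically
  if (!s.ok()) return {Status::NotOK, s.ToString()};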
3 changes: 3 additions & 0 deletions src/search/hnsw_indexer.h
@@ -45,6 +45,8 @@ struct Node {
void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage);

// For testing purposes
Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
@@ -100,6 +102,7 @@ class HnswIndex {
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t layer);
Status InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
};

} // namespace redis
15 changes: 9 additions & 6 deletions src/search/indexer.cc
@@ -272,20 +272,23 @@ Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const kqir::Value& original,
CHECK(original.IsNull() || original.Is<kqir::NumericArray>());
CHECK(current.IsNull() || current.Is<kqir::NumericArray>());

-  auto *storage = indexer->storage;
-  auto batch = storage->GetWriteBatchBase();
+  auto storage = indexer->storage;
+  auto hnsw = HnswIndex(search_key, vector, storage);

if (!original.IsNull()) {
-    // TODO(Beihao): implement vector deletion
+    auto batch = storage->GetWriteBatchBase();
+    GET_OR_RET(hnsw.DeleteVectorEntry(key, batch));
+    auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch());
+    if (!s.ok()) return {Status::NotOK, s.ToString()};
}

if (!current.IsNull()) {
-    auto hnsw = HnswIndex(search_key, vector, indexer->storage);
+    auto batch = storage->GetWriteBatchBase();
GET_OR_RET(hnsw.InsertVectorEntry(key, current.Get<kqir::NumericArray>(), batch));
+    auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch());
+    if (!s.ok()) return {Status::NotOK, s.ToString()};
}

-  auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch());
-  if (!s.ok()) return {Status::NotOK, s.ToString()};
return Status::OK();
}

127 changes: 125 additions & 2 deletions tests/cppunit/hnsw_index_test.cc
@@ -280,7 +280,7 @@ TEST_F(HnswIndexTest, SearchLayer) {
EXPECT_EQ(candidates[2].key, "node2");

// Test with a single entry point
entry_points = {"node1"};
entry_points = {"node5"};
auto s10 = hnsw_index_->SearchLayer(layer, target_vector, ef_runtime, entry_points);
ASSERT_TRUE(s10.IsOK());
candidates = s10.GetValue();
@@ -304,7 +304,7 @@
EXPECT_EQ(candidates[4].key, "node3");
}

-TEST_F(HnswIndexTest, InsertVectorEntry) {
+TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) {
std::vector<double> vec1 = {11.0, 12.0, 13.0};
std::vector<double> vec2 = {14.0, 15.0, 16.0};
std::vector<double> vec3 = {17.0, 18.0, 19.0};
@@ -511,4 +511,127 @@ TEST_F(HnswIndexTest, InsertVectorEntry) {
expected_set = {"n2", "n3", "n4"};
actual_set = {node5_layer1.neighbours.begin(), node5_layer1.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s20 = node1_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s20.IsOK());
node1_layer0_meta = s20.GetValue();
EXPECT_EQ(node1_layer0_meta.num_neighbours, 4);
node1_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n2", "n3", "n4", "n5"};
actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

redis::Node node5_layer0(key5, 0);
auto s21 = node5_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s21.IsOK());
auto node5_layer0_meta = s21.GetValue();
EXPECT_EQ(node5_layer0_meta.num_neighbours, 4);
node5_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n2", "n3", "n4"};
actual_set = {node5_layer0.neighbours.begin(), node5_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

// Delete n2
batch = storage_->GetWriteBatchBase();
auto s22 = hnsw_index_->DeleteVectorEntry(key2, batch);
ASSERT_TRUE(s22.IsOK());
s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch());
ASSERT_TRUE(s.ok());

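  // Level 3 held only n2, so deleting it should shrink the index to 3 levels.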
index_meta_key = hnsw_index_->search_key_.ConstructFieldMeta();
s = storage_->Get(rocksdb::ReadOptions(), hnsw_index_->storage_->GetCFHandle(ColumnFamilyID::Search), index_meta_key,
&value);
ASSERT_TRUE(s.ok());
decoded_metadata.Decode(&value);
ASSERT_TRUE(decoded_metadata.num_levels == 3);

auto s23 = node2_layer3.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
EXPECT_TRUE(!s23.IsOK());

auto s24 = node2_layer2.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
EXPECT_TRUE(!s24.IsOK());

auto s25 = node2_layer1.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
EXPECT_TRUE(!s25.IsOK());

auto s26 = node2_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
EXPECT_TRUE(!s26.IsOK());

auto s27 = node3_layer2.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s27.IsOK());
node3_layer2_meta = s27.GetValue();
EXPECT_EQ(node3_layer2_meta.num_neighbours, 0);

auto s28 = node1_layer1.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s28.IsOK());
node1_layer1_meta = s28.GetValue();
EXPECT_EQ(node1_layer1_meta.num_neighbours, 2);
node1_layer1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n3", "n4"};
actual_set = {node1_layer1.neighbours.begin(), node1_layer1.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s29 = node3_layer1.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s29.IsOK());
node3_layer1_meta = s29.GetValue();
EXPECT_EQ(node3_layer1_meta.num_neighbours, 2);
node3_layer1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n5"};
actual_set = {node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s30 = node4_layer1.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s30.IsOK());
node4_layer1_meta = s30.GetValue();
EXPECT_EQ(node4_layer1_meta.num_neighbours, 2);
node4_layer1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n5"};
actual_set = {node4_layer1.neighbours.begin(), node4_layer1.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s31 = node5_layer1.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s31.IsOK());
node5_layer1_meta = s31.GetValue();
EXPECT_EQ(node5_layer1_meta.num_neighbours, 2);
node5_layer1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n3", "n4"};
actual_set = {node5_layer1.neighbours.begin(), node5_layer1.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s32 = node1_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s32.IsOK());
node1_layer0_meta = s32.GetValue();
EXPECT_EQ(node1_layer0_meta.num_neighbours, 3);
node1_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n3", "n4", "n5"};
actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

redis::Node node3_layer0(key3, 0);
auto s33 = node3_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s33.IsOK());
auto node3_layer0_meta = s33.GetValue();
EXPECT_EQ(node3_layer0_meta.num_neighbours, 3);
node3_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n4", "n5"};
actual_set = {node3_layer0.neighbours.begin(), node3_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s34 = node4_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s34.IsOK());
node4_layer0_meta = s34.GetValue();
EXPECT_EQ(node4_layer0_meta.num_neighbours, 3);
node4_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n3", "n5"};
actual_set = {node4_layer0.neighbours.begin(), node4_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);

auto s35 = node5_layer0.DecodeMetadata(hnsw_index_->search_key_, hnsw_index_->storage_);
ASSERT_TRUE(s35.IsOK());
node5_layer0_meta = s35.GetValue();
EXPECT_EQ(node5_layer0_meta.num_neighbours, 3);
node5_layer0.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_);
expected_set = {"n1", "n3", "n4"};
actual_set = {node5_layer0.neighbours.begin(), node5_layer0.neighbours.end()};
EXPECT_EQ(actual_set, expected_set);
}
20 changes: 20 additions & 0 deletions tests/cppunit/hnsw_node_test.cc
@@ -1,3 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#include <encoding.h>
#include <gtest/gtest.h>
#include <test_base.h>
