diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc index 9a65a90ae7b..7b7dae541fa 100644 --- a/src/search/hnsw_indexer.cc +++ b/src/search/hnsw_indexer.cc @@ -56,17 +56,84 @@ void Node::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search } void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) { + neighbours.clear(); auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key); util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { if (!iter->key().starts_with(edge_prefix)) { break; } - auto neighbour_key = iter->key().ToString().substr(edge_prefix.size()); - neighbours.push_back(std::move(neighbour_key)); + auto neighbour_edge = iter->key(); + neighbour_edge.remove_prefix(edge_prefix.size()); + Slice neighbour; + GetSizedString(&neighbour_edge, &neighbour); + neighbours.push_back(neighbour.ToString()); } } +Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch) { + auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key); + batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key, Slice()); + + HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage)); + node_metadata.num_neighbours++; + PutMetadata(&node_metadata, search_key, storage, batch); + return Status::OK(); +} + +Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch) { + auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key); + auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key); + if (!s.ok()) { + return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())}; + } + + HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage)); + node_metadata.num_neighbours--; + PutMetadata(&node_metadata, search_key, storage, batch); + return Status::OK(); +} + +Status Node::UpdateNeighbours(std::vector& neighbours, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch, + std::unordered_set& deleted_neighbours) { + deleted_neighbours.clear(); + auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search); + auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key); + std::unordered_set to_be_added{neighbours.begin(), neighbours.end()}; + + util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { + if (!iter->key().starts_with(edge_prefix)) { + break; + } + auto neighbour_edge = iter->key(); + neighbour_edge.remove_prefix(edge_prefix.size()); + Slice neighbour; + GetSizedString(&neighbour_edge, &neighbour); + auto neighbour_key = neighbour.ToString(); + + if (to_be_added.count(neighbour_key) == 0) { + batch->Delete(cf_handle, iter->key()); + deleted_neighbours.insert(neighbour_key); + } else { + to_be_added.erase(neighbour_key); + } + } + + for (const auto& neighbour : to_be_added) { + auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour); + batch->Put(cf_handle, edge_index_key, Slice()); + } + + HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage)); + node_metadata.num_neighbours = static_cast(neighbours.size()); + PutMetadata(&node_metadata, search_key, storage, batch); + return Status::OK(); +} + VectorItem::VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata) : key(key), vector(std::move(vector)), metadata(metadata) {} VectorItem::VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata) @@ -74,7 +141,7 @@ VectorItem::VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const Hn bool VectorItem::operator<(const VectorItem& other) const { return key < other.key; } -StatusOr ComputeDistance(const VectorItem& left, const VectorItem& right) { +StatusOr ComputeSimilarity(const VectorItem& left, const VectorItem& right) { if (left.metadata->distance_metric != right.metadata->distance_metric || left.metadata->dim != right.metadata->dim) return {Status::InvalidArgument, "Vectors must be of the same metric and dimension to compute distance."}; @@ -99,14 +166,14 @@ StatusOr ComputeDistance(const VectorItem& left, const VectorItem& right } case DistanceMetric::COSINE: { double dist = 0.0; - double norma = 0.0; - double normb = 0.0; + double norm_left = 0.0; + double norm_right = 0.0; for (auto i = 0; i < dim; i++) { dist += left.vector[i] * right.vector[i]; - norma += left.vector[i] * right.vector[i]; - normb += left.vector[i] * right.vector[i]; + norm_left += left.vector[i] * left.vector[i]; + norm_right += right.vector[i] * right.vector[i]; } - auto similarity = dist / std::sqrt(norma * normb); + auto similarity = dist / std::sqrt(norm_left * norm_right); return 1.0 - similarity; } default: @@ -136,61 +203,48 @@ StatusOr HnswIndex::DefaultEntryPoint(uint16_t level) { if (it->Valid() && it->key().starts_with(prefix)) { node_key = Slice(it->key().ToString().substr(prefix.size())); if (!GetSizedString(&node_key, &node_key_dst)) { - return {Status::NotFound, fmt::format("No node found in layer {}", level)}; + return {Status::NotOK, fmt::format("fail to decode the default node key layer {}", level)}; } + return node_key_dst.ToString(); } - return node_key_dst.ToString(); + return {Status::NotFound, fmt::format("No node found in layer {}", level)}; } Status HnswIndex::Connect(uint16_t layer, const NodeKey& node_key1, const NodeKey& node_key2, ObserverOrUniquePtr& batch) { - auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search); - auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2); - batch->Put(cf_handle, edge_index_key1, Slice()); - - auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1); - batch->Put(cf_handle, edge_index_key2, Slice()); - auto node1 = Node(node_key1, layer); - HnswNodeFieldMetadata node1_metadata = GET_OR_RET(node1.DecodeMetadata(search_key_, storage_)); - node1_metadata.num_neighbours += 1; - node1.PutMetadata(&node1_metadata, search_key_, storage_, batch); + GET_OR_RET(node1.AddNeighbour(node_key2, search_key_, storage_, batch)); auto node2 = Node(node_key2, layer); - HnswNodeFieldMetadata node2_metadata = GET_OR_RET(node2.DecodeMetadata(search_key_, storage_)); - node2_metadata.num_neighbours += 1; - node2.PutMetadata(&node1_metadata, search_key_, storage_, batch); + GET_OR_RET(node2.AddNeighbour(node_key1, search_key_, storage_, batch)); return Status::OK(); } -// Assume the new_neighbour_vertors is a subset of the original neighbours -Status HnswIndex::PruneEdges(const VectorItem& vec, const std::vector& new_neighbour_vertors, +Status HnswIndex::PruneEdges(const VectorItem& vec, const std::vector& new_neighbour_vectors, uint16_t layer, ObserverOrUniquePtr& batch) { - auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search); - std::unordered_set neighbours; - for (const auto& neighbour_vector : new_neighbour_vertors) { - neighbours.insert(neighbour_vector.key); - } - - auto edge_prefix = search_key_.ConstructHnswEdgeWithSingleEnd(layer, vec.key); - util::UniqueIterator iter(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search); - for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { - if (!iter->key().starts_with(edge_prefix)) { - break; - } - auto neighbour_key = iter->key().ToString().substr(edge_prefix.size()); - - if (neighbours.count(neighbour_key) == 0) { - batch->Delete(cf_handle, iter->key()); + auto node = Node(vec.key, layer); + node.DecodeNeighbours(search_key_, storage_); + std::unordered_set original_neighbours{node.neighbours.begin(), node.neighbours.end()}; + + uint16_t neighbours_sz = static_cast(new_neighbour_vectors.size()); + std::vector neighbours(neighbours_sz); + for (auto i = 0; i < neighbours_sz; i++) { + auto neighbour_key = new_neighbour_vectors[i].key; + if (original_neighbours.count(neighbour_key) == 0) { + return {Status::InvalidArgument, + fmt::format("Node \"{}\" is not a neighbour of \"{}\" and can't be pruned", neighbour_key, vec.key)}; } + neighbours[i] = new_neighbour_vectors[i].key; } - Node node = Node(vec.key, layer); - HnswNodeFieldMetadata node_metadata = GET_OR_RET(node.DecodeMetadata(search_key_, storage_)); - node_metadata.num_neighbours = neighbours.size(); - node.PutMetadata(&node_metadata, search_key_, storage_, batch); + std::unordered_set deleted_neighbours; + GET_OR_RET(node.UpdateNeighbours(neighbours, search_key_, storage_, batch, deleted_neighbours)); + for (const auto& key : deleted_neighbours) { + auto neighbour_node = Node(key, layer); + GET_OR_RET(neighbour_node.RemoveNeighbour(vec.key, search_key_, storage_, batch)); + } return Status::OK(); } @@ -199,7 +253,7 @@ StatusOr> HnswIndex::SelectNeighbors(const VectorItem& v std::vector> distances; distances.reserve(vertors.size()); for (const auto& candidate : vertors) { - auto dist = GET_OR_RET(ComputeDistance(vec, candidate)); + auto dist = GET_OR_RET(ComputeSimilarity(vec, candidate)); distances.push_back({dist, candidate}); } @@ -228,7 +282,7 @@ StatusOr> HnswIndex::SearchLayer(uint16_t level, const V auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key_, storage_)); auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata_); - auto dist = GET_OR_RET(ComputeDistance(target_vector, entry_point_vector)); + auto dist = GET_OR_RET(ComputeSimilarity(target_vector, entry_point_vector)); explore_heap.push(std::make_pair(dist, entry_point_vector)); result_heap.push(std::make_pair(dist, std::move(entry_point_vector))); @@ -255,7 +309,7 @@ StatusOr> HnswIndex::SearchLayer(uint16_t level, const V auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_)); auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata_); - auto dist = GET_OR_RET(ComputeDistance(target_vector, neighbour_node_vector)); + auto dist = GET_OR_RET(ComputeSimilarity(target_vector, neighbour_node_vector)); explore_heap.push(std::make_pair(dist, neighbour_node_vector)); result_heap.push(std::make_pair(dist, neighbour_node_vector)); while (result_heap.size() > ef_runtime) { diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h index da03273f7be..6fcc6dd078c 100644 --- a/src/search/hnsw_indexer.h +++ b/src/search/hnsw_indexer.h @@ -45,6 +45,13 @@ struct Node { void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage, ObserverOrUniquePtr& batch); void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage); + Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch); + Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch); + Status UpdateNeighbours(std::vector& neighbours, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch, + std::unordered_set& deleted_neighbours); friend class HnswIndex; }; @@ -62,7 +69,7 @@ struct VectorItem { bool operator<(const VectorItem& other) const; }; -StatusOr ComputeDistance(const VectorItem& left, const VectorItem& right); +StatusOr ComputeSimilarity(const VectorItem& left, const VectorItem& right); class HnswIndex { public: diff --git a/src/search/indexer.cc b/src/search/indexer.cc index d43d05b6695..38034c62046 100644 --- a/src/search/indexer.cc +++ b/src/search/indexer.cc @@ -276,7 +276,7 @@ Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const kqir::Val auto batch = storage->GetWriteBatchBase(); if (!original.IsNull()) { - // TODO: delete + // TODO(Beihao): implement vector deletion } if (!current.IsNull()) { diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index ae7ffbae430..bdcda397d84 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -82,7 +82,6 @@ enum class IndexFieldType : uint8_t { }; enum class VectorType : uint8_t { - FLOAT32 = 0, FLOAT64 = 1, }; @@ -418,9 +417,9 @@ struct HnswNodeFieldMetadata { : num_neighbours(num_neighbours), vector(std::move(vector)) {} void Encode(std::string *dst) const { - PutFixed16(dst, uint16_t(num_neighbours)); - PutFixed16(dst, uint16_t(vector.size())); - for (auto element : vector) { + PutFixed16(dst, num_neighbours); + PutFixed16(dst, static_cast(vector.size())); + for (double element : vector) { PutDouble(dst, element); } } @@ -437,8 +436,9 @@ struct HnswNodeFieldMetadata { if (input->size() != dim * sizeof(double)) { return rocksdb::Status::Corruption(kErrorIncorrectLength); } + vector.resize(dim); - for (size_t i = 0; i < dim; ++i) { + for (auto i = 0; i < dim; ++i) { GetDouble(input, &vector[i]); } return rocksdb::Status::OK(); diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc new file mode 100644 index 00000000000..f361ef9caec --- /dev/null +++ b/tests/cppunit/hnsw_index_test.cc @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include +#include + +#include +#include +#include + +#include "search/hnsw_indexer.h" +#include "search/indexer.h" +#include "search/search_encoding.h" +#include "search/value.h" +#include "storage/storage.h" + +struct HnswIndexTest : TestBase { + redis::HnswVectorFieldMetadata metadata_; + std::string ns = "hnsw_test_ns"; + std::string idx_name = "hnsw_test_idx"; + std::string key = "vector"; + std::unique_ptr hnsw_index_; + + HnswIndexTest() { + metadata_.vector_type = redis::VectorType::FLOAT64; + metadata_.dim = 4; + metadata_.m = 3; + metadata_.distance_metric = redis::DistanceMetric::L2; + auto search_key_ = redis::SearchKey(ns, idx_name, key); + hnsw_index_ = std::make_unique(search_key_, &metadata_, storage_.get()); + } + + void TearDown() override { hnsw_index_.reset(); } +}; + +TEST_F(HnswIndexTest, ComputeSimilarity) { + redis::VectorItem vec1 = {"1", {1.0, 1.2, 1.4, 1.6}, hnsw_index_->metadata_}; + redis::VectorItem vec2 = {"2", {3.0, 3.2, 3.4, 3.6}, hnsw_index_->metadata_}; + redis::VectorItem vec3 = {"3", {1.0, 1.2, 1.4, 1.6}, hnsw_index_->metadata_}; // identical to vec1 + + auto s1 = redis::ComputeSimilarity(vec1, vec3); + ASSERT_TRUE(s1.IsOK()); + double similarity = s1.GetValue(); + EXPECT_EQ(similarity, 0.0); + + auto s2 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s2.IsOK()); + similarity = s2.GetValue(); + EXPECT_EQ(similarity, 4.0); + + hnsw_index_->metadata_->distance_metric = redis::DistanceMetric::IP; + auto s3 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s3.IsOK()); + similarity = s3.GetValue(); + EXPECT_NEAR(similarity, -17.36, 1e-5); + + hnsw_index_->metadata_->distance_metric = redis::DistanceMetric::COSINE; + double expected_res = + (1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4 + 1.6 * 3.6) / + std::sqrt((1.0 * 1.0 + 1.2 * 1.2 + 1.4 * 1.4 + 1.6 * 1.6) * (3.0 * 3.0 + 3.2 * 3.2 + 3.4 * 3.4 + 3.6 * 3.6)); + auto s4 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s4.IsOK()); + similarity = s4.GetValue(); + EXPECT_NEAR(similarity, 1 - expected_res, 1e-5); + + hnsw_index_->metadata_->distance_metric = redis::DistanceMetric::L2; +} + +TEST_F(HnswIndexTest, DefaultEntryPointNotFound) { + auto initial_result = hnsw_index_->DefaultEntryPoint(0); + ASSERT_EQ(initial_result.GetCode(), Status::NotFound); +} + +TEST_F(HnswIndexTest, ConnectNodes) { + uint16_t layer = 0; + std::string node_key1 = "node1"; + std::string node_key2 = "node2"; + + redis::Node node1(node_key1, layer); + redis::Node node2(node_key2, layer); + + redis::HnswNodeFieldMetadata node1_metadata(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata node2_metadata(0, {1, 2, 3}); + + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&node1_metadata, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node2.PutMetadata(&node2_metadata, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Connect two nodes + batch = storage_->GetWriteBatchBase(); + auto connect_status = hnsw_index_->Connect(layer, node_key1, node_key2, batch); + ASSERT_EQ(connect_status.Msg(), Status::ok_msg); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], node_key2); + + node2.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_); + EXPECT_EQ(node2.neighbours.size(), 1); + EXPECT_EQ(node2.neighbours[0], node_key1); +} + +TEST_F(HnswIndexTest, PruneEdges) { + uint16_t layer = 1; + std::string node_key1 = "node1"; + std::string node_key2 = "node2"; + std::string node_key3 = "node3"; + + redis::Node node1(node_key1, layer); + redis::Node node2(node_key2, layer); + redis::Node node3(node_key3, layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node2.PutMetadata(&metadata2, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node3.PutMetadata(&metadata3, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + auto batch2 = storage_->GetWriteBatchBase(); + auto s1 = node1.AddNeighbour("node2", hnsw_index_->search_key_, hnsw_index_->storage_, batch2); + ASSERT_TRUE(s1.IsOK()); + auto s2 = node2.AddNeighbour("node1", hnsw_index_->search_key_, hnsw_index_->storage_, batch2); + ASSERT_TRUE(s2.IsOK()); + auto s3 = node2.AddNeighbour("node3", hnsw_index_->search_key_, hnsw_index_->storage_, batch2); + ASSERT_TRUE(s3.IsOK()); + auto s4 = node3.AddNeighbour("node2", hnsw_index_->search_key_, hnsw_index_->storage_, batch2); + ASSERT_TRUE(s4.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch2->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Prune edges for node2, keeping only node3 as a neighbor + auto batch3 = storage_->GetWriteBatchBase(); + std::vector new_neighbours = {redis::VectorItem("node3", {7, 8, 9}, &metadata_)}; + auto s5 = hnsw_index_->PruneEdges(redis::VectorItem("node2", {4, 5, 6}, &metadata_), new_neighbours, layer, batch3); + ASSERT_TRUE(s5.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch3->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_); + EXPECT_EQ(node1.neighbours.size(), 0); + + node2.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_); + EXPECT_EQ(node2.neighbours.size(), 1); + EXPECT_EQ(node2.neighbours[0], "node3"); + + node3.DecodeNeighbours(hnsw_index_->search_key_, hnsw_index_->storage_); + EXPECT_EQ(node3.neighbours.size(), 1); + EXPECT_EQ(node3.neighbours[0], "node2"); + + // Prune edges for node3 with non-existing + auto batch4 = storage_->GetWriteBatchBase(); + new_neighbours = {redis::VectorItem("node1", {1, 2, 3}, &metadata_)}; + auto s6 = hnsw_index_->PruneEdges(redis::VectorItem("node3", {7, 8, 9}, &metadata_), new_neighbours, layer, batch4); + ASSERT_EQ(s6.GetCode(), Status::InvalidArgument); +} + +TEST_F(HnswIndexTest, SelectNeighbors) { + redis::VectorItem vec1 = {"1", {1.0, 1.0, 1.0, 1.0}, hnsw_index_->metadata_}; + redis::VectorItem vec2 = {"2", {2.0, 2.0, 2.0, 2.0}, hnsw_index_->metadata_}; + redis::VectorItem vec3 = {"3", {3.0, 3.0, 3.0, 3.0}, hnsw_index_->metadata_}; + redis::VectorItem vec4 = {"4", {4.0, 4.0, 4.0, 4.0}, hnsw_index_->metadata_}; + redis::VectorItem vec5 = {"5", {5.0, 5.0, 5.0, 5.0}, hnsw_index_->metadata_}; + redis::VectorItem vec6 = {"6", {6.0, 6.0, 6.0, 6.0}, hnsw_index_->metadata_}; + redis::VectorItem vec7 = {"7", {7.0, 7.0, 7.0, 7.0}, hnsw_index_->metadata_}; + + std::vector candidates = {vec3, vec2}; + auto s1 = hnsw_index_->SelectNeighbors(vec1, candidates, 1); + ASSERT_TRUE(s1.IsOK()); + auto selected = s1.GetValue(); + EXPECT_EQ(selected.size(), candidates.size()); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + + candidates = {vec4, vec2, vec5, vec7, vec3, vec6}; + auto s2 = hnsw_index_->SelectNeighbors(vec1, candidates, 1); + ASSERT_TRUE(s2.IsOK()); + selected = s2.GetValue(); + EXPECT_EQ(selected.size(), 3); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + EXPECT_EQ(selected[2].key, vec4.key); + + candidates = {vec4, vec2, vec5, vec7, vec3, vec6}; + auto s3 = hnsw_index_->SelectNeighbors(vec1, candidates, 0); + ASSERT_TRUE(s3.IsOK()); + selected = s3.GetValue(); + EXPECT_EQ(selected.size(), 6); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + EXPECT_EQ(selected[2].key, vec4.key); + EXPECT_EQ(selected[3].key, vec5.key); + EXPECT_EQ(selected[4].key, vec6.key); + EXPECT_EQ(selected[5].key, vec7.key); +} + +TEST_F(HnswIndexTest, SearchLayer) { + uint16_t layer = 3; + std::string node_key1 = "node1"; + std::string node_key2 = "node2"; + std::string node_key3 = "node3"; + std::string node_key4 = "node4"; + std::string node_key5 = "node5"; + + redis::Node node1(node_key1, layer); + redis::Node node2(node_key2, layer); + redis::Node node3(node_key3, layer); + redis::Node node4(node_key4, layer); + redis::Node node5(node_key5, layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1.0, 2.0, 3.0}); + redis::HnswNodeFieldMetadata metadata2(0, {4.0, 5.0, 6.0}); + redis::HnswNodeFieldMetadata metadata3(0, {7.0, 8.0, 9.0}); + redis::HnswNodeFieldMetadata metadata4(0, {2.0, 3.0, 4.0}); + redis::HnswNodeFieldMetadata metadata5(0, {5.0, 6.0, 7.0}); + + // Add Nodes + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node2.PutMetadata(&metadata2, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node3.PutMetadata(&metadata3, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node4.PutMetadata(&metadata4, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + node5.PutMetadata(&metadata5, hnsw_index_->search_key_, hnsw_index_->storage_, batch); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Add Neighbours + batch = storage_->GetWriteBatchBase(); + auto s1 = node1.AddNeighbour("node2", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s1.IsOK()); + auto s2 = node1.AddNeighbour("node4", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s2.IsOK()); + auto s3 = node2.AddNeighbour("node1", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s3.IsOK()); + auto s4 = node2.AddNeighbour("node3", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s1.IsOK()); + auto s5 = node3.AddNeighbour("node2", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s5.IsOK()); + auto s6 = node3.AddNeighbour("node5", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s6.IsOK()); + auto s7 = node4.AddNeighbour("node1", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s7.IsOK()); + auto s8 = node5.AddNeighbour("node3", hnsw_index_->search_key_, hnsw_index_->storage_, batch); + ASSERT_TRUE(s8.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + redis::VectorItem target_vector("target", {2.0, 3.0, 4.0}, hnsw_index_->metadata_); + std::vector entry_points = {"node3", "node2"}; + uint32_t ef_runtime = 3; + + auto result = hnsw_index_->SearchLayer(layer, target_vector, ef_runtime, entry_points); + ASSERT_TRUE(result.IsOK()); + auto candidates = result.GetValue(); + + // TODO(Beihao): Found bug for comparison, may implement customized comparison function + // ASSERT_EQ(candidates.size(), ef_runtime); + // EXPECT_EQ(candidates[0].key, "node4"); + // EXPECT_EQ(candidates[1].key, "node1"); + // EXPECT_EQ(candidates[2].key, "node2"); +} + +TEST_F(HnswIndexTest, InsertVectorEntry) { + // TODO(Beihao): Consider how to test in a robust way +} \ No newline at end of file diff --git a/tests/cppunit/hnsw_node_test.cc b/tests/cppunit/hnsw_node_test.cc new file mode 100644 index 00000000000..4df469342e0 --- /dev/null +++ b/tests/cppunit/hnsw_node_test.cc @@ -0,0 +1,177 @@ +#include +#include +#include + +#include +#include +#include + +#include "search/hnsw_indexer.h" +#include "search/indexer.h" +#include "search/search_encoding.h" +#include "storage/storage.h" + +class NodeTest : public TestBase { + protected: + std::string ns = "node_test_ns"; + std::string idx_name = "node_test_idx"; + std::string key = "vector"; + redis::SearchKey search_key_; + + NodeTest() : search_key_(ns, idx_name, key) {} + + void TearDown() override {} +}; + +TEST_F(NodeTest, PutAndDecodeMetadata) { + uint16_t layer = 0; + redis::Node node1("node1", layer); + redis::Node node2("node2", layer); + redis::Node node3("node3", layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, search_key_, storage_.get(), batch); + node2.PutMetadata(&metadata2, search_key_, storage_.get(), batch); + node3.PutMetadata(&metadata3, search_key_, storage_.get(), batch); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + auto decoded_metadata1 = node1.DecodeMetadata(search_key_, storage_.get()); + ASSERT_TRUE(decoded_metadata1.IsOK()); + ASSERT_EQ(decoded_metadata1.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata1.GetValue().vector, std::vector({1, 2, 3})); + + auto decoded_metadata2 = node2.DecodeMetadata(search_key_, storage_.get()); + ASSERT_TRUE(decoded_metadata2.IsOK()); + ASSERT_EQ(decoded_metadata2.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata2.GetValue().vector, std::vector({4, 5, 6})); + + auto decoded_metadata3 = node3.DecodeMetadata(search_key_, storage_.get()); + ASSERT_TRUE(decoded_metadata3.IsOK()); + ASSERT_EQ(decoded_metadata3.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata3.GetValue().vector, std::vector({7, 8, 9})); + + // Prepare edges between node1 and node2 + batch = storage_->GetWriteBatchBase(); + auto edge1 = search_key_.ConstructHnswEdge(layer, "node1", "node2"); + auto edge2 = search_key_.ConstructHnswEdge(layer, "node2", "node1"); + auto edge3 = search_key_.ConstructHnswEdge(layer, "node2", "node3"); + auto edge4 = search_key_.ConstructHnswEdge(layer, "node3", "node2"); + + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge1, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge2, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge3, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge4, Slice()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], "node2"); + + node2.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 2); + std::unordered_set expected_neighbours = {"node1", "node3"}; + std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); + EXPECT_EQ(actual_neighbours, expected_neighbours); + + node3.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node3.neighbours.size(), 1); + EXPECT_EQ(node3.neighbours[0], "node2"); +} + +TEST_F(NodeTest, ModifyNeighbours) { + uint16_t layer = 1; + redis::Node node1("node1", layer); + redis::Node node2("node2", layer); + redis::Node node3("node3", layer); + redis::Node node4("node4", layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + redis::HnswNodeFieldMetadata metadata4(0, {10, 11, 12}); + + // Add Nodes + auto batch1 = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, search_key_, storage_.get(), batch1); + node2.PutMetadata(&metadata2, search_key_, storage_.get(), batch1); + node3.PutMetadata(&metadata3, search_key_, storage_.get(), batch1); + node4.PutMetadata(&metadata4, search_key_, storage_.get(), batch1); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch1->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Add Edges + auto batch2 = storage_->GetWriteBatchBase(); + auto s1 = node1.AddNeighbour("node2", search_key_, storage_.get(), batch2); + ASSERT_TRUE(s1.IsOK()); + auto s2 = node2.AddNeighbour("node1", search_key_, storage_.get(), batch2); + ASSERT_TRUE(s2.IsOK()); + auto s3 = node2.AddNeighbour("node3", search_key_, storage_.get(), batch2); + ASSERT_TRUE(s3.IsOK()); + auto s4 = node3.AddNeighbour("node2", search_key_, storage_.get(), batch2); + ASSERT_TRUE(s4.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch2->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], "node2"); + + node2.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 2); + std::unordered_set expected_neighbours = {"node1", "node3"}; + std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); + EXPECT_EQ(actual_neighbours, expected_neighbours); + + node3.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node3.neighbours.size(), 1); + EXPECT_EQ(node3.neighbours[0], "node2"); + + // Remove Edges + auto batch3 = storage_->GetWriteBatchBase(); + auto s5 = node2.RemoveNeighbour("node3", search_key_, storage_.get(), batch3); + ASSERT_TRUE(s5.IsOK()); + + s = storage_->Write(storage_->DefaultWriteOptions(), batch3->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node2.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 1); + EXPECT_EQ(node2.neighbours[0], "node1"); + + // Update neighbours with fully new edge + auto batch4 = storage_->GetWriteBatchBase(); + std::vector new_neighbours = {"node3"}; + std::unordered_set deleted_neighbours; + auto s6 = node1.UpdateNeighbours(new_neighbours, search_key_, storage_.get(), batch4, deleted_neighbours); + ASSERT_TRUE(s5.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch4->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], "node3"); + + EXPECT_EQ(deleted_neighbours.size(), 1); + EXPECT_TRUE(deleted_neighbours.count("node2")); + + // Update neighbours with existing neighbours included + auto batch5 = storage_->GetWriteBatchBase(); + new_neighbours = {"node3", "node4"}; + auto s7 = node1.UpdateNeighbours(new_neighbours, search_key_, storage_.get(), batch5, deleted_neighbours); + ASSERT_TRUE(s6.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch5->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key_, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 2); + expected_neighbours = {new_neighbours.begin(), new_neighbours.end()}; + actual_neighbours = {node1.neighbours.begin(), node1.neighbours.end()}; + EXPECT_EQ(actual_neighbours, expected_neighbours); + EXPECT_EQ(deleted_neighbours.size(), 0); +}