From 49b745d0ec95fc9a3261882a64546178e20a9644 Mon Sep 17 00:00:00 2001 From: Beihao Zhou Date: Tue, 25 Jun 2024 07:24:22 +0000 Subject: [PATCH] Handle returned Status & Integrate kqir::Value --- src/search/hnsw_indexer.h | 318 +++++++++++++++++------------------ src/search/indexer.cc | 31 ++-- src/search/indexer.h | 4 +- src/search/search_encoding.h | 27 ++- 4 files changed, 193 insertions(+), 187 deletions(-) diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h index a91c5a135ec..cd5bf8e1399 100644 --- a/src/search/hnsw_indexer.h +++ b/src/search/hnsw_indexer.h @@ -20,66 +20,92 @@ #pragma once +#include + #include #include #include +#include #include #include #include #include -#include #include "db_util.h" #include "parse_util.h" #include "search/indexer.h" #include "search/search_encoding.h" +#include "search/value.h" #include "storage/redis_metadata.h" #include "storage/storage.h" - namespace redis { -struct VectorItem { - std::string key; - // TODO: use template based on VectorType - std::vector vector; - HnswVectorFieldMetadata* metadata; +class HnswIndex; - VectorItem(std::string_view key, std::string_view vector_str, HnswVectorFieldMetadata* metadata) : key(key), metadata(metadata) { - Decode(vector_str); - } +struct Node { + using NodeKey = std::string; + + NodeKey key; + uint16_t level; + std::vector neighbours; - // TODO: move it to util - void Decode(std::string_view vector_str) { - std::string trimmed = std::string(vector_str); - trimmed.erase(0, 1); // remove the first '[' - trimmed.erase(trimmed.size() - 1, 1); // remove the last ']' + Node(const NodeKey& key, uint16_t level) : key(key), level(level) {} + + StatusOr DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) { + auto node_index_key = search_key.ConstructHnswNode(level, key); + rocksdb::PinnableSlice value; + auto s = storage->Get(rocksdb::ReadOptions(), storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value); + if (!s.ok()) return {Status::NotOK, s.ToString()}; - std::istringstream iss(trimmed); - std::string num; + HnswNodeFieldMetadata metadata; + s = metadata.Decode(&value); + if (!s.ok()) return {Status::NotOK, s.ToString()}; + return metadata; + } - vector.clear(); + void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage, + ObserverOrUniquePtr& batch) { + std::string updated_metadata; + node_meta->Encode(&updated_metadata); + batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), search_key.ConstructHnswNode(level, key), + updated_metadata); + } - while (std::getline(iss, num, ',')) { - try { - double value = std::stod(num); - vector.push_back(value); - } catch (const std::invalid_argument& ia) { - throw std::runtime_error("Invalid number in vector string: " + num); + void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) { + auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key); + util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { + if (!iter->key().starts_with(edge_prefix)) { + break; } + auto neighbour_key = iter->key().ToString().substr(edge_prefix.size()); + neighbours.push_back(std::move(neighbour_key)); } } + + friend class HnswIndex; }; -auto ComputeDistance(const VectorItem& left, const VectorItem& right) { - if (left.metadata->distance_metric != right.metadata->distance_metric) - // throw error - ; +struct VectorItem { + using NodeKey = Node::NodeKey; + + NodeKey key; + kqir::NumericArray vector; + const HnswVectorFieldMetadata* metadata; + + VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata) + : key(key), vector(std::move(vector)), metadata(metadata) {} + VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata) + : key(key), vector(std::move(vector)), metadata(metadata) {} + + bool operator<(const VectorItem& other) const { return key < other.key; } +}; + +StatusOr ComputeDistance(const VectorItem& left, const VectorItem& right) { + if (left.metadata->distance_metric != right.metadata->distance_metric || left.metadata->dim != right.metadata->dim) + return {Status::InvalidArgument, "Vectors must be of the same metric and dimension to compute distance."}; - if (left.metadata->dim != right.metadata->dim) - // throw error - ; - auto metric = left.metadata->distance_metric; auto dim = left.metadata->dim; @@ -112,124 +138,82 @@ auto ComputeDistance(const VectorItem& left, const VectorItem& right) { return 1.0 - similarity; } default: - // throw error - return 0.0; + __builtin_unreachable(); } } -bool operator<(const VectorItem& lhs, const VectorItem& rhs) { - if (lhs.key != rhs.key) { - return lhs.key < rhs.key; - } - if (!lhs.vector.empty() && !rhs.vector.empty()) { - return lhs.vector[0] < rhs.vector[0]; - } - return false; -} - -struct Node { - using NodeKey = std::string; - - NodeKey key; - uint16_t level; - std::vector neighbours; - - Node(const NodeKey& key, uint16_t level) : key(key), level(level) {} - - HnswNodeFieldMetadata DecodeNodeMetadata(const SearchKey& search_key, engine::Storage *storage) { - auto node_index_key = search_key.ConstructHnswNode(level, key); - std::string value; - rocksdb::Status s = storage->Get(rocksdb::ReadOptions(), node_index_key, &value); - HnswNodeFieldMetadata metadata; - Slice input(value); - s = metadata.Decode(&input); - return metadata; - } - - void DecodeNeighbours(const SearchKey& search_key, engine::Storage *storage) { - auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key); - util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); - for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { - if (!iter->key().starts_with(edge_prefix)) { - break; - } - auto neighbour_key = iter->key().ToString().substr(edge_prefix.size()); - neighbours.push_back(neighbour_key); - } - } -}; - class HnswIndex { public: using NodeKey = Node::NodeKey; SearchKey search_key_; HnswVectorFieldMetadata* metadata_; + engine::Storage* storage_ = nullptr; + std::mt19937 generator_; double m_level_normalization_factor_; - engine::Storage *storage = nullptr; - HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage *storage) - : search_key_(search_key), metadata_(vector), storage(storage) { + HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage) + : search_key_(search_key), metadata_(vector), storage_(storage) { m_level_normalization_factor_ = 1.0 / std::log(metadata_->m); std::random_device rand_dev; generator_ = std::mt19937(rand_dev()); } - int RandomizeLayer() { + uint16_t RandomizeLayer() { std::uniform_real_distribution level_dist(0.0, 1.0); - return static_cast(std::floor(-std::log(level_dist(generator_)) * m_level_normalization_factor_)); + return static_cast(std::floor(-std::log(level_dist(generator_)) * m_level_normalization_factor_)); } - NodeKey DefaultEntryPoint(uint16_t level) { + StatusOr DefaultEntryPoint(uint16_t level) { auto prefix = search_key_.ConstructHnswLevelNodePrefix(level); - util::UniqueIterator it(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + util::UniqueIterator it(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search); it->Seek(prefix); + Slice node_key; + Slice node_key_dst; if (it->Valid() && it->key().starts_with(prefix)) { - Slice node_key_dst; - auto node_key = Slice(it->key().ToString().substr(prefix.size())); + node_key = Slice(it->key().ToString().substr(prefix.size())); if (!GetSizedString(&node_key, &node_key_dst)) { - // error handling - return ""; + return {Status::NotFound, fmt::format("No node found in layer {}", level)}; } - return node_key_dst.ToString(); } - return ""; + return node_key_dst.ToString(); } - void Connect(uint16_t layer, NodeKey node_key1, NodeKey node_key2, - ObserverOrUniquePtr &batch, rocksdb::ColumnFamilyHandle* cf_handle) { + Status Connect(uint16_t layer, const NodeKey& node_key1, const NodeKey& node_key2, + ObserverOrUniquePtr& batch) { + auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search); auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2); batch->Put(cf_handle, edge_index_key1, Slice()); auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1); batch->Put(cf_handle, edge_index_key2, Slice()); - Node node1 = Node(node_key1, layer); - HnswNodeFieldMetadata node1_metadata = node1.DecodeNodeMetadata(search_key_, storage); + auto node1 = Node(node_key1, layer); + HnswNodeFieldMetadata node1_metadata = GET_OR_RET(node1.DecodeMetadata(search_key_, storage_)); node1_metadata.num_neighbours += 1; - std::string node1_updated_metadata; - node1_metadata.Encode(&node1_updated_metadata); - batch->Put(cf_handle, node_key1, node1_updated_metadata); + node1.PutMetadata(&node1_metadata, search_key_, storage_, batch); - Node node2 = Node(node_key2, layer); - HnswNodeFieldMetadata node2_metadata = node2.DecodeNodeMetadata(search_key_, storage); + auto node2 = Node(node_key2, layer); + HnswNodeFieldMetadata node2_metadata = GET_OR_RET(node2.DecodeMetadata(search_key_, storage_)); node2_metadata.num_neighbours += 1; - std::string node2_updated_metadata; - node2_metadata.Encode(&node2_updated_metadata); - batch->Put(cf_handle, node_key2, node2_updated_metadata); + node2.PutMetadata(&node1_metadata, search_key_, storage_, batch); + + return Status::OK(); } - void ResetEdges(const VectorItem& vec, const std::vector& neighbour_vertors, uint16_t layer, - ObserverOrUniquePtr &batch, rocksdb::ColumnFamilyHandle* cf_handle) { + // Assume the new_neighbour_vertors is a subset of the original neighbours + Status PruneEdges(const VectorItem& vec, const std::vector& new_neighbour_vertors, uint16_t layer, + ObserverOrUniquePtr& batch) { + auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search); std::unordered_set neighbours; - for (const auto& neighbour_vector : neighbour_vertors) { + for (const auto& neighbour_vector : new_neighbour_vertors) { neighbours.insert(neighbour_vector.key); } auto edge_prefix = search_key_.ConstructHnswEdgeWithSingleEnd(layer, vec.key); - util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search); + util::UniqueIterator iter(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search); for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) { if (!iter->key().starts_with(edge_prefix)) { break; @@ -242,18 +226,20 @@ class HnswIndex { } Node node = Node(vec.key, layer); - HnswNodeFieldMetadata node_metadata = node.DecodeNodeMetadata(search_key_, storage); + HnswNodeFieldMetadata node_metadata = GET_OR_RET(node.DecodeMetadata(search_key_, storage_)); node_metadata.num_neighbours = neighbours.size(); - std::string node_updated_metadata; - node_metadata.Encode(&node_updated_metadata); - batch->Put(cf_handle, vec.key, node_updated_metadata); + node.PutMetadata(&node_metadata, search_key_, storage_, batch); + + return Status::OK(); } - std::vector SelectNeighbors(const VectorItem& vec, const std::vector& vertors, uint16_t layer) { + StatusOr> SelectNeighbors(const VectorItem& vec, const std::vector& vertors, + uint16_t layer) { std::vector> distances; distances.reserve(vertors.size()); for (const auto& candidate : vertors) { - distances.push_back( { ComputeDistance(vec, candidate), candidate } ); + auto dist = GET_OR_RET(ComputeDistance(vec, candidate)); + distances.push_back({dist, candidate}); } std::sort(distances.begin(), distances.end()); @@ -267,8 +253,8 @@ class HnswIndex { return selected_vs; } - std::vector SearchLayer(uint16_t level, const VectorItem& base_vector, uint32_t ef_runtime, - const std::vector& entry_points) { + StatusOr> SearchLayer(uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime, + const std::vector& entry_points) { std::vector candidates; std::unordered_set visited; std::priority_queue, std::vector>, std::greater<>> @@ -277,34 +263,39 @@ class HnswIndex { for (const auto& entry_point_key : entry_points) { Node entry_node = Node(entry_point_key, level); - HnswNodeFieldMetadata node_metadata = entry_node.DecodeNodeMetadata(search_key_, storage); - auto entry_point_vector = VectorItem(entry_point_key, node_metadata.vector, metadata_); - auto dist = ComputeDistance(base_vector, entry_point_vector); - explore_heap.push({dist, entry_point_vector}); - result_heap.push({dist, entry_point_vector}); + auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key_, storage_)); + + auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata_); + auto dist = GET_OR_RET(ComputeDistance(target_vector, entry_point_vector)); + + explore_heap.push(std::make_pair(dist, entry_point_vector)); + result_heap.push(std::make_pair(dist, std::move(entry_point_vector))); visited.insert(entry_point_key); } while (!explore_heap.empty()) { - auto [dist, current_vector] = explore_heap.top(); + auto& [dist, current_vector] = explore_heap.top(); explore_heap.pop(); if (dist > result_heap.top().first) { break; } - auto node = Node(current_vector.key, level); - node.DecodeNeighbours(search_key_, storage); - for (const auto& neighbour_key : node.neighbours) { + auto current_node = Node(current_vector.key, level); + current_node.DecodeNeighbours(search_key_, storage_); + + for (const auto& neighbour_key : current_node.neighbours) { if (visited.find(neighbour_key) != visited.end()) { continue; } visited.insert(neighbour_key); - Node neighbour_node = Node(neighbour_key, level); - HnswNodeFieldMetadata neighbour_metadata = neighbour_node.DecodeNodeMetadata(search_key_, storage); - auto neighbour_node_vector = VectorItem(neighbour_key, neighbour_metadata.vector, metadata_); - auto dist = ComputeDistance(current_vector, neighbour_node_vector); - explore_heap.push({dist, neighbour_node_vector}); - result_heap.push({dist, neighbour_node_vector}); + + auto neighbour_node = Node(neighbour_key, level); + auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_)); + auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata_); + + auto dist = GET_OR_RET(ComputeDistance(target_vector, neighbour_node_vector)); + explore_heap.push(std::make_pair(dist, neighbour_node_vector)); + result_heap.push(std::make_pair(dist, neighbour_node_vector)); while (result_heap.size() > ef_runtime) { result_heap.pop(); } @@ -318,70 +309,71 @@ class HnswIndex { return candidates; } - void InsertVectorEntry(std::string_view key, std::string_view vector_str, ObserverOrUniquePtr &batch) { - auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search); - auto vector_item = VectorItem(key, vector_str, metadata_); - int target_level = RandomizeLayer(); - std::vector nearest_elements; + Status InsertVectorEntry(std::string_view key, kqir::NumericArray vector, + ObserverOrUniquePtr& batch) { + auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search); + auto inserted_vector_item = VectorItem(std::string(key), vector, metadata_); + auto target_level = RandomizeLayer(); + std::vector nearest_vec_items; if (metadata_->num_levels != 0) { auto level = metadata_->num_levels - 1; - std::vector entry_points{DefaultEntryPoint(level)}; + auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level)); + std::vector entry_points{default_entry_node}; for (; level > target_level; level--) { - nearest_elements = SearchLayer(level, vector_item, metadata_->ef_runtime, entry_points); - entry_points = {nearest_elements[0].key}; + nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata_->ef_runtime, entry_points)); + entry_points = {nearest_vec_items[0].key}; } for (; level >= 0; level--) { - nearest_elements = SearchLayer(level, vector_item, metadata_->ef_construction, entry_points); - auto connect_vec_items = SelectNeighbors(vector_item, nearest_elements, level); + nearest_vec_items = + GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata_->ef_construction, entry_points)); + auto connect_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level)); + for (const auto& connected_vec_item : connect_vec_items) { - Connect(level, vector_item.key, connected_vec_item.key, batch, cf_handle); + GET_OR_RET(Connect(level, inserted_vector_item.key, connected_vec_item.key, batch)); } for (const auto& connected_vec_item : connect_vec_items) { auto connected_node = Node(connected_vec_item.key, level); - auto connected_node_metadata = connected_node.DecodeNodeMetadata(search_key_, storage); + auto connected_node_metadata = GET_OR_RET(connected_node.DecodeMetadata(search_key_, storage_)); + uint16_t connected_node_num_neighbours = connected_node_metadata.num_neighbours; auto m_max = level == 0 ? 2 * metadata_->m : metadata_->m; + if (connected_node_num_neighbours <= m_max) continue; - if (connected_node_num_neighbours <= m_max) { - continue; - } - - connected_node.DecodeNeighbours(search_key_, storage); + connected_node.DecodeNeighbours(search_key_, storage_); std::vector connected_node_neighbour_vec_items; for (const auto& connected_node_neighbour_key : connected_node.neighbours) { Node connected_node_neighbour = Node(connected_node_neighbour_key, level); - auto connected_node_neighbour_metadata = connected_node_neighbour.DecodeNodeMetadata(search_key_, storage); - auto neighbour_vector = VectorItem(connected_node_neighbour_key, connected_node_neighbour_metadata.vector, metadata_); + auto connected_node_neighbour_metadata = + GET_OR_RET(connected_node_neighbour.DecodeMetadata(search_key_, storage_)); + auto neighbour_vector = VectorItem(connected_node_neighbour_key, + std::move(connected_node_neighbour_metadata.vector), metadata_); connected_node_neighbour_vec_items.push_back(neighbour_vector); } - auto new_neighbors = SelectNeighbors(connected_vec_item, connected_node_neighbour_vec_items, level); - ResetEdges(connected_vec_item, new_neighbors, level, batch, cf_handle); + auto new_neighbors = + GET_OR_RET(SelectNeighbors(connected_vec_item, connected_node_neighbour_vec_items, level)); + GET_OR_RET(PruneEdges(connected_vec_item, new_neighbors, level, batch)); } entry_points.clear(); - for (const auto& new_entry_point : nearest_elements) { - entry_points.push_back(new_entry_point.key); + for (const auto& new_entry_point : nearest_vec_items) { + entry_points.push_back(std::move(new_entry_point.key)); } } } else { - auto node_index_key = search_key_.ConstructHnswNode(0, key); - HnswNodeFieldMetadata node_metadata(0, vector_str); - std::string encoded_metadata; - node_metadata.Encode(&encoded_metadata); - batch->Put(cf_handle, node_index_key, encoded_metadata); + auto node = Node(std::string(key), 0); + HnswNodeFieldMetadata node_metadata(0, vector); + node.PutMetadata(&node_metadata, search_key_, storage_, batch); metadata_->num_levels = 1; } - while (metadata_->num_levels - 1 < target_level) { - auto node_index_key = search_key_.ConstructHnswNode(metadata_->num_levels, key); - HnswNodeFieldMetadata node_metadata(0, vector_str); - std::string encoded_metadata; - node_metadata.Encode(&encoded_metadata); - batch->Put(cf_handle, node_index_key, encoded_metadata); + while (target_level > metadata_->num_levels - 1) { + auto node = Node(std::string(key), metadata_->num_levels); + HnswNodeFieldMetadata node_metadata(0, vector); + node.PutMetadata(&node_metadata, search_key_, storage_, batch); metadata_->num_levels++; } @@ -389,6 +381,8 @@ class HnswIndex { metadata_->Encode(&encoded_index_metadata); auto index_meta_key = search_key_.ConstructFieldMeta(); batch->Put(cf_handle, index_meta_key, encoded_index_metadata); + + return Status::OK(); } }; diff --git a/src/search/indexer.cc b/src/search/indexer.cc index 4fb2416ef3f..d43d05b6695 100644 --- a/src/search/indexer.cc +++ b/src/search/indexer.cc @@ -25,14 +25,13 @@ #include "db_util.h" #include "parse_util.h" +#include "search/hnsw_indexer.h" #include "search/search_encoding.h" #include "search/value.h" #include "storage/redis_metadata.h" #include "storage/storage.h" #include "string_util.h" #include "types/redis_hash.h" -#include "search/hnsw_indexer.h" - namespace redis { @@ -59,10 +58,6 @@ StatusOr FieldValueRetriever::Create(IndexOnDataType type, } } -// placeholders, remove them after vector indexing is implemented -static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; } -static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; } - StatusOr FieldValueRetriever::ParseFromJson(const jsoncons::json &val, const redis::IndexFieldMetadata *type) { if (auto numeric [[maybe_unused]] = dynamic_cast(type)) { @@ -84,8 +79,8 @@ StatusOr FieldValueRetriever::ParseFromJson(const jsoncons::json &v } else { return {Status::NotOK, "json value should be string or array of strings for tag fields"}; } - } else if (IsVectorType(type)) { - size_t dim = GetVectorDim(type); + } else if (auto vector = dynamic_cast(type)) { + const auto dim = vector->dim; if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"}; if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"}; std::vector nums; @@ -109,8 +104,8 @@ StatusOr FieldValueRetriever::ParseFromHash(const std::string &valu const char delim[] = {tag->separator, '\0'}; auto vec = util::Split(value, delim); return kqir::MakeValue(vec); - } else if (IsVectorType(type)) { - const size_t dim = GetVectorDim(type); + } else if (auto vector = dynamic_cast(type)) { + const auto dim = vector->dim; if (value.size() != dim * sizeof(double)) { return {Status::NotOK, "field value is too short or too long to be parsed as a vector"}; } @@ -248,7 +243,7 @@ Status IndexUpdater::UpdateTagIndex(std::string_view key, const kqir::Value &ori Status IndexUpdater::UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, const SearchKey &search_key, const NumericFieldMetadata *num) const { CHECK(original.IsNull() || original.Is()); - CHECK(original.IsNull() || original.Is()); + CHECK(current.IsNull() || current.Is()); auto *storage = indexer->storage; auto batch = storage->GetWriteBatchBase(); @@ -271,18 +266,22 @@ Status IndexUpdater::UpdateNumericIndex(std::string_view key, const kqir::Value return Status::OK(); } -Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, std::string_view original, std::string_view current, - const SearchKey &search_key, HnswVectorFieldMetadata *vector) const { +Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const kqir::Value &original, + const kqir::Value ¤t, const SearchKey &search_key, + HnswVectorFieldMetadata *vector) const { + CHECK(original.IsNull() || original.Is()); + CHECK(current.IsNull() || current.Is()); + auto *storage = indexer->storage; auto batch = storage->GetWriteBatchBase(); - if (!original.empty()) { + if (!original.IsNull()) { // TODO: delete } - if (!current.empty()) { + if (!current.IsNull()) { auto hnsw = HnswIndex(search_key, vector, indexer->storage); - hnsw.InsertVectorEntry(key, current, batch); + GET_OR_RET(hnsw.InsertVectorEntry(key, current.Get(), batch)); } auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); diff --git a/src/search/indexer.h b/src/search/indexer.h index ce0bdfef0dd..e5e0aa4fb50 100644 --- a/src/search/indexer.h +++ b/src/search/indexer.h @@ -89,8 +89,8 @@ struct IndexUpdater { const SearchKey &search_key, const TagFieldMetadata *tag) const; Status UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, const SearchKey &search_key, const NumericFieldMetadata *num) const; - Status UpdateHnswVectorIndex(std::string_view key, std::string_view original, std::string_view current, - const SearchKey &search_key, HnswVectorFieldMetadata *vector) const; + Status UpdateHnswVectorIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, + const SearchKey &search_key, HnswVectorFieldMetadata *vector) const; }; struct GlobalIndexer { diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index ea10c2477c7..147a2d07e12 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -33,6 +33,7 @@ enum class IndexOnDataType : uint8_t { }; inline constexpr auto kErrorInsufficientLength = "insufficient length while decoding metadata"; +inline constexpr auto kErrorIncorrectLength = "length is too short or too long to be parsed as a vector"; class IndexMetadata { public: @@ -410,24 +411,36 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata { struct HnswNodeFieldMetadata { uint16_t num_neighbours; - std::string vector; + std::vector vector; HnswNodeFieldMetadata() {} - HnswNodeFieldMetadata(uint16_t num_neighbours, std::string_view vector) : num_neighbours(num_neighbours), vector(vector) {} + HnswNodeFieldMetadata(uint16_t num_neighbours, std::vector vector) + : num_neighbours(num_neighbours), vector(vector) {} void Encode(std::string *dst) const { PutFixed16(dst, uint16_t(num_neighbours)); - PutSizedString(dst, vector); + PutFixed16(dst, uint16_t(vector.size())); + for (size_t i = 0; i < vector.size(); ++i) { + PutDouble(dst, vector[i]); + } } rocksdb::Status Decode(Slice *input) { - if (input->size() < 2 + 4) { + if (input->size() < 2 + 2) { return rocksdb::Status::Corruption(kErrorInsufficientLength); } GetFixed16(input, (uint16_t *)(&num_neighbours)); - Slice value; - GetSizedString(input, &value); - vector = value.ToString(); + + uint16_t dim; + GetFixed16(input, (uint16_t *)(&dim)); + + if (input->size() != dim * sizeof(double)) { + return rocksdb::Status::Corruption(kErrorIncorrectLength); + } + + for (size_t i = 0; i < dim; ++i) { + GetDouble(input, &vector[i]); + } return rocksdb::Status::OK(); } };