
Commit

Fix & Modularize & Add majority unit tests
Beihao-Zhou committed Jul 1, 2024
1 parent 4716336 commit a98376f
Showing 6 changed files with 586 additions and 55 deletions.
150 changes: 102 additions & 48 deletions src/search/hnsw_indexer.cc
@@ -56,25 +56,92 @@ void Node::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search
}

void Node::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) {
neighbours.clear();
auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
if (!iter->key().starts_with(edge_prefix)) {
break;
}
auto neighbour_key = iter->key().ToString().substr(edge_prefix.size());
neighbours.push_back(std::move(neighbour_key));
auto neighbour_edge = iter->key();
neighbour_edge.remove_prefix(edge_prefix.size());
Slice neighbour;
GetSizedString(&neighbour_edge, &neighbour);
neighbours.push_back(neighbour.ToString());
}
}

Status Node::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key, Slice());

HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
node_metadata.num_neighbours++;
PutMetadata(&node_metadata, search_key, storage, batch);
return Status::OK();
}

Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}

HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
node_metadata.num_neighbours--;
PutMetadata(&node_metadata, search_key, storage, batch);
return Status::OK();
}

Status Node::UpdateNeighbours(std::vector<NodeKey>& neighbours, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
std::unordered_set<NodeKey>& deleted_neighbours) {
deleted_neighbours.clear();
auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
std::unordered_set<NodeKey> to_be_added{neighbours.begin(), neighbours.end()};

util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
if (!iter->key().starts_with(edge_prefix)) {
break;
}
auto neighbour_edge = iter->key();
neighbour_edge.remove_prefix(edge_prefix.size());
Slice neighbour;
GetSizedString(&neighbour_edge, &neighbour);
auto neighbour_key = neighbour.ToString();

if (to_be_added.count(neighbour_key) == 0) {
batch->Delete(cf_handle, iter->key());
deleted_neighbours.insert(neighbour_key);
} else {
to_be_added.erase(neighbour_key);
}
}

for (const auto& neighbour : to_be_added) {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour);
batch->Put(cf_handle, edge_index_key, Slice());
}

HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
node_metadata.num_neighbours = static_cast<uint16_t>(neighbours.size());
PutMetadata(&node_metadata, search_key, storage, batch);
return Status::OK();
}

VectorItem::VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}
VectorItem::VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}

bool VectorItem::operator<(const VectorItem& other) const { return key < other.key; }

StatusOr<double> ComputeDistance(const VectorItem& left, const VectorItem& right) {
StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right) {
if (left.metadata->distance_metric != right.metadata->distance_metric || left.metadata->dim != right.metadata->dim)
return {Status::InvalidArgument, "Vectors must be of the same metric and dimension to compute distance."};

@@ -99,14 +166,14 @@ StatusOr<double> ComputeDistance(const VectorItem& left, const VectorItem& right
}
case DistanceMetric::COSINE: {
double dist = 0.0;
double norma = 0.0;
double normb = 0.0;
double norm_left = 0.0;
double norm_right = 0.0;
for (auto i = 0; i < dim; i++) {
dist += left.vector[i] * right.vector[i];
norma += left.vector[i] * right.vector[i];
normb += left.vector[i] * right.vector[i];
norm_left += left.vector[i] * left.vector[i];
norm_right += right.vector[i] * right.vector[i];
}
auto similarity = dist / std::sqrt(norma * normb);
auto similarity = dist / std::sqrt(norm_left * norm_right);
return 1.0 - similarity;
}
default:
@@ -136,61 +203,48 @@ StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) {
if (it->Valid() && it->key().starts_with(prefix)) {
node_key = Slice(it->key().ToString().substr(prefix.size()));
if (!GetSizedString(&node_key, &node_key_dst)) {
return {Status::NotFound, fmt::format("No node found in layer {}", level)};
return {Status::NotOK, fmt::format("fail to decode the default node key layer {}", level)};
}
return node_key_dst.ToString();
}
return node_key_dst.ToString();
return {Status::NotFound, fmt::format("No node found in layer {}", level)};
}

Status HnswIndex::Connect(uint16_t layer, const NodeKey& node_key1, const NodeKey& node_key2,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search);
auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
batch->Put(cf_handle, edge_index_key1, Slice());

auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
batch->Put(cf_handle, edge_index_key2, Slice());

auto node1 = Node(node_key1, layer);
HnswNodeFieldMetadata node1_metadata = GET_OR_RET(node1.DecodeMetadata(search_key_, storage_));
node1_metadata.num_neighbours += 1;
node1.PutMetadata(&node1_metadata, search_key_, storage_, batch);
GET_OR_RET(node1.AddNeighbour(node_key2, search_key_, storage_, batch));

auto node2 = Node(node_key2, layer);
HnswNodeFieldMetadata node2_metadata = GET_OR_RET(node2.DecodeMetadata(search_key_, storage_));
node2_metadata.num_neighbours += 1;
node2.PutMetadata(&node1_metadata, search_key_, storage_, batch);
GET_OR_RET(node2.AddNeighbour(node_key1, search_key_, storage_, batch));

return Status::OK();
}

// Assume the new_neighbour_vertors is a subset of the original neighbours
Status HnswIndex::PruneEdges(const VectorItem& vec, const std::vector<VectorItem>& new_neighbour_vertors,
Status HnswIndex::PruneEdges(const VectorItem& vec, const std::vector<VectorItem>& new_neighbour_vectors,
uint16_t layer, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search);
std::unordered_set<NodeKey> neighbours;
for (const auto& neighbour_vector : new_neighbour_vertors) {
neighbours.insert(neighbour_vector.key);
}

auto edge_prefix = search_key_.ConstructHnswEdgeWithSingleEnd(layer, vec.key);
util::UniqueIterator iter(storage_, storage_->DefaultScanOptions(), ColumnFamilyID::Search);
for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
if (!iter->key().starts_with(edge_prefix)) {
break;
}
auto neighbour_key = iter->key().ToString().substr(edge_prefix.size());

if (neighbours.count(neighbour_key) == 0) {
batch->Delete(cf_handle, iter->key());
auto node = Node(vec.key, layer);
node.DecodeNeighbours(search_key_, storage_);
std::unordered_set original_neighbours{node.neighbours.begin(), node.neighbours.end()};

uint16_t neighbours_sz = static_cast<uint16_t>(new_neighbour_vectors.size());
std::vector<NodeKey> neighbours(neighbours_sz);
for (auto i = 0; i < neighbours_sz; i++) {
auto neighbour_key = new_neighbour_vectors[i].key;
if (original_neighbours.count(neighbour_key) == 0) {
return {Status::InvalidArgument,
fmt::format("Node \"{}\" is not a neighbour of \"{}\" and can't be pruned", neighbour_key, vec.key)};
}
neighbours[i] = new_neighbour_vectors[i].key;
}

Node node = Node(vec.key, layer);
HnswNodeFieldMetadata node_metadata = GET_OR_RET(node.DecodeMetadata(search_key_, storage_));
node_metadata.num_neighbours = neighbours.size();
node.PutMetadata(&node_metadata, search_key_, storage_, batch);
std::unordered_set<NodeKey> deleted_neighbours;
GET_OR_RET(node.UpdateNeighbours(neighbours, search_key_, storage_, batch, deleted_neighbours));

for (const auto& key : deleted_neighbours) {
auto neighbour_node = Node(key, layer);
GET_OR_RET(neighbour_node.RemoveNeighbour(vec.key, search_key_, storage_, batch));
}
return Status::OK();
}

@@ -199,7 +253,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem& v
std::vector<std::pair<double, VectorItem>> distances;
distances.reserve(vertors.size());
for (const auto& candidate : vertors) {
auto dist = GET_OR_RET(ComputeDistance(vec, candidate));
auto dist = GET_OR_RET(ComputeSimilarity(vec, candidate));
distances.push_back({dist, candidate});
}

@@ -228,7 +282,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key_, storage_));

auto entry_point_vector = VectorItem(entry_point_key, std::move(entry_node_metadata.vector), metadata_);
auto dist = GET_OR_RET(ComputeDistance(target_vector, entry_point_vector));
auto dist = GET_OR_RET(ComputeSimilarity(target_vector, entry_point_vector));

explore_heap.push(std::make_pair(dist, entry_point_vector));
result_heap.push(std::make_pair(dist, std::move(entry_point_vector)));
@@ -255,7 +309,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key_, storage_));
auto neighbour_node_vector = VectorItem(neighbour_key, std::move(neighbour_node_metadata.vector), metadata_);

auto dist = GET_OR_RET(ComputeDistance(target_vector, neighbour_node_vector));
auto dist = GET_OR_RET(ComputeSimilarity(target_vector, neighbour_node_vector));
explore_heap.push(std::make_pair(dist, neighbour_node_vector));
result_heap.push(std::make_pair(dist, neighbour_node_vector));
while (result_heap.size() > ef_runtime) {
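Note on the COSINE branch above: the corrected code computes 1 minus the cosine similarity, with each norm taken over its own vector rather than over the dot product. Below is a minimal standalone sketch of the same computation; CosineDistance is an illustrative name, not the project API, and plain std::vector is used in place of kqir::NumericArray.

// Standalone illustration of the corrected cosine branch:
// distance = 1 - (sum(l[i]*r[i]) / sqrt(sum(l[i]^2) * sum(r[i]^2))).
#include <cmath>
#include <cstdio>
#include <vector>

double CosineDistance(const std::vector<double>& left, const std::vector<double>& right) {
  double dot = 0.0, norm_left = 0.0, norm_right = 0.0;
  for (size_t i = 0; i < left.size(); i++) {
    dot += left[i] * right[i];
    norm_left += left[i] * left[i];     // norm of the left vector only
    norm_right += right[i] * right[i];  // norm of the right vector only
  }
  return 1.0 - dot / std::sqrt(norm_left * norm_right);
}

int main() {
  // Identical vectors give distance 0; orthogonal vectors give distance 1.
  printf("%.3f\n", CosineDistance({1.0, 0.0}, {1.0, 0.0}));  // 0.000
  printf("%.3f\n", CosineDistance({1.0, 0.0}, {0.0, 1.0}));  // 1.000
}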
9 changes: 8 additions & 1 deletion src/search/hnsw_indexer.h
@@ -45,6 +45,13 @@ struct Node {
void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage);
Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status UpdateNeighbours(std::vector<NodeKey>& neighbours, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
std::unordered_set<NodeKey>& deleted_neighbours);

friend class HnswIndex;
};
@@ -62,7 +69,7 @@ struct VectorItem {
bool operator<(const VectorItem& other) const;
};

StatusOr<double> ComputeDistance(const VectorItem& left, const VectorItem& right);
StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right);

class HnswIndex {
public:
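The core of the new Node::UpdateNeighbours declared above is a set reconciliation: existing edges missing from the new neighbour list are deleted (and reported back so the reverse edges can also be dropped), while new-list entries with no existing edge are inserted. Below is a minimal standalone sketch of that reconciliation using plain std::string keys in place of the RocksDB iterator and write batch; ReconcileNeighbours and NeighbourDiff are illustrative names, not project types.

#include <string>
#include <unordered_set>
#include <vector>

struct NeighbourDiff {
  std::vector<std::string> to_delete;  // existing edges absent from the new list
  std::vector<std::string> to_add;     // new-list entries with no existing edge
};

NeighbourDiff ReconcileNeighbours(const std::vector<std::string>& existing,
                                  const std::vector<std::string>& updated) {
  NeighbourDiff diff;
  std::unordered_set<std::string> to_be_added(updated.begin(), updated.end());
  for (const auto& key : existing) {
    if (to_be_added.count(key) == 0) {
      diff.to_delete.push_back(key);  // analogous to batch->Delete on the edge key
    } else {
      to_be_added.erase(key);         // edge already present, nothing to write
    }
  }
  // Whatever remains has no existing edge yet, analogous to batch->Put.
  diff.to_add.assign(to_be_added.begin(), to_be_added.end());
  return diff;
}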
2 changes: 1 addition & 1 deletion src/search/indexer.cc
@@ -276,7 +276,7 @@ Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const kqir::Val
auto batch = storage->GetWriteBatchBase();

if (!original.IsNull()) {
// TODO: delete
// TODO(Beihao): implement vector deletion
}

if (!current.IsNull()) {
10 changes: 5 additions & 5 deletions src/search/search_encoding.h
@@ -82,7 +82,6 @@ enum class IndexFieldType : uint8_t {
};

enum class VectorType : uint8_t {
FLOAT32 = 0,
FLOAT64 = 1,
};

@@ -418,9 +417,9 @@ struct HnswNodeFieldMetadata {
: num_neighbours(num_neighbours), vector(std::move(vector)) {}

void Encode(std::string *dst) const {
PutFixed16(dst, uint16_t(num_neighbours));
PutFixed16(dst, uint16_t(vector.size()));
for (auto element : vector) {
PutFixed16(dst, num_neighbours);
PutFixed16(dst, static_cast<uint16_t>(vector.size()));
for (double element : vector) {
PutDouble(dst, element);
}
}
@@ -437,8 +436,9 @@ struct HnswNodeFieldMetadata {
if (input->size() != dim * sizeof(double)) {
return rocksdb::Status::Corruption(kErrorIncorrectLength);
}
vector.resize(dim);

for (size_t i = 0; i < dim; ++i) {
for (auto i = 0; i < dim; ++i) {
GetDouble(input, &vector[i]);
}
return rocksdb::Status::OK();
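The Encode/Decode pair above implies a simple byte layout: a fixed 16-bit neighbour count, a fixed 16-bit dimension, then dim raw 8-byte doubles. Below is a minimal standalone sketch of that layout for orientation only, assuming host byte order for the fixed-width fields; the project's PutFixed16/PutDouble helpers may encode bytes differently, and EncodeNodeMeta/DecodeNodeMeta are hypothetical names, not the on-disk format.

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Layout sketch: uint16 num_neighbours, uint16 dim, then dim 8-byte doubles.
void EncodeNodeMeta(uint16_t num_neighbours, const std::vector<double>& vec, std::string* dst) {
  auto put16 = [dst](uint16_t v) { dst->append(reinterpret_cast<const char*>(&v), sizeof(v)); };
  put16(num_neighbours);
  put16(static_cast<uint16_t>(vec.size()));
  for (double element : vec) {
    dst->append(reinterpret_cast<const char*>(&element), sizeof(element));
  }
}

bool DecodeNodeMeta(const std::string& src, uint16_t* num_neighbours, std::vector<double>* vec) {
  size_t pos = 0;
  auto get16 = [&](uint16_t* v) {
    if (src.size() - pos < sizeof(*v)) return false;
    std::memcpy(v, src.data() + pos, sizeof(*v));
    pos += sizeof(*v);
    return true;
  };
  uint16_t dim = 0;
  if (!get16(num_neighbours) || !get16(&dim)) return false;
  // Mirrors the corruption check in Decode(): the remainder must be exactly dim doubles.
  if (src.size() - pos != dim * sizeof(double)) return false;
  vec->resize(dim);
  for (uint16_t i = 0; i < dim; ++i) {
    std::memcpy(&(*vec)[i], src.data() + pos, sizeof(double));
    pos += sizeof(double);
  }
  return true;
}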