Skip to content

Commit d5e9d5e

Browse files
committed
Global HNSW vector index
1 parent 6480e41 commit d5e9d5e

File tree

14 files changed

+735
-141
lines changed

14 files changed

+735
-141
lines changed

src/core/search/ast_expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) {
7373
this->filter = make_unique<AstNode>(std::move(filter));
7474
}
7575

76+
bool AstKnnNode::PreFilter() const {
77+
return filter == nullptr;
78+
}
79+
7680
} // namespace dfly::search
7781

7882
namespace std {

src/core/search/ast_expr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ struct AstKnnNode {
114114
OwnedFtVector vec;
115115
std::string score_alias;
116116
std::optional<float> ef_runtime;
117+
118+
bool PreFilter() const;
117119
};
118120

119121
using NodeVariants =

src/core/search/base.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,24 @@
1616
#include "absl/container/flat_hash_set.h"
1717
#include "base/pmr/memory_resource.h"
1818
#include "core/string_map.h"
19+
#include "server/tx_base.h"
1920

2021
namespace dfly::search {
2122

2223
using DocId = uint32_t;
24+
using GlobalDocId = uint64_t;
25+
26+
// Packs a shard id and a shard-local document id into a single 64-bit id:
// the shard id occupies the upper 32 bits, the local id the lower 32 bits.
inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) {
  const auto shard_bits = static_cast<uint64_t>(shard_id) << 32;
  return shard_bits | local_doc_id;
}
29+
30+
// Extracts the shard id stored in the upper 32 bits of a global document id.
inline ShardId GlobalDocIdShardId(GlobalDocId id) {
  const GlobalDocId upper_half = id >> 32;
  return static_cast<ShardId>(upper_half);
}
33+
34+
// Extracts the shard-local document id stored in the lower 32 bits of a global
// document id.
//
// DocId is 32 bits wide and CreateGlobalDocId() stores it unshifted in the low
// half, so the mask must cover all 32 bits. The previous 0xFFFF mask kept only
// 16 bits and silently corrupted every local id >= 65536.
inline search::DocId GlobalDocIdLocalId(GlobalDocId id) {
  constexpr GlobalDocId kLocalIdMask = 0xFFFFFFFFULL;
  return static_cast<search::DocId>(id & kLocalIdMask);
}
2337

2438
enum class VectorSimilarity { L2, IP, COSINE };
2539

src/core/search/indices.cc

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -492,40 +492,36 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492492
return true;
493493
}
494494

495-
template <typename T>
496-
FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497-
PMR_NS::memory_resource* mr)
498-
: BaseVectorIndex<T>{params.dim, params.sim}, entries_{mr} {
495+
// Constructs a flat vector index. `params` must not request HNSW (checked in
// debug builds). Storage for `params.capacity * params.dim` floats is reserved
// up front in `entries_`, which allocates from `mr`.
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
                                 PMR_NS::memory_resource* mr)
    : BaseVectorIndex<DocId>{params.dim, params.sim}, entries_{mr} {
  DCHECK(!params.use_hnsw);

  const auto reserved_floats = params.capacity * params.dim;
  entries_.reserve(reserved_floats);
}
502501

503-
template <typename T>
504-
void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505-
DCHECK_LE(id * BaseVectorIndex<T>::dim_, entries_.size());
506-
if (id * BaseVectorIndex<T>::dim_ == entries_.size())
507-
entries_.resize((id + 1) * BaseVectorIndex<T>::dim_);
502+
void FlatVectorIndex::AddVector(DocId id, const typename BaseVectorIndex::VectorPtr& vector) {
503+
DCHECK_LE(id * dim_, entries_.size());
504+
if (id * dim_ == entries_.size())
505+
entries_.resize((id + 1) * dim_);
508506

509507
// TODO: Let get vector write to buf itself
510508
if (vector) {
511-
memcpy(&entries_[id * BaseVectorIndex<T>::dim_], vector.get(),
512-
BaseVectorIndex<T>::dim_ * sizeof(float));
509+
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
513510
}
514511
}
515512

516-
template <typename T>
517-
void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
513+
// Removal does nothing for the flat index: the arguments are ignored and the
// stored vector data is left untouched.
void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
  // no-op
}
520516

521-
template <typename T> const float* FlatVectorIndex<T>::Get(T doc) const {
517+
// Returns a pointer to the first component of the vector stored for `doc`.
// Vectors are laid out back to back in `entries_`, `dim_` floats each.
const float* FlatVectorIndex::Get(DocId doc) const {
  const auto first_component = doc * dim_;
  return &entries_[first_component];
}
524520

525-
template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullValues() const {
521+
std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
526522
std::vector<DocId> result;
527523

528-
size_t num_vectors = entries_.size() / BaseVectorIndex<T>::dim_;
524+
size_t num_vectors = entries_.size() / dim_;
529525
result.reserve(num_vectors);
530526

531527
for (DocId id = 0; id < num_vectors; ++id) {
@@ -535,7 +531,7 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
535531
bool is_zero_vector = true;
536532

537533
// TODO: Consider don't use check for zero vector
538-
for (size_t i = 0; i < BaseVectorIndex<T>::dim_; ++i) {
534+
for (size_t i = 0; i < dim_; ++i) {
539535
if (vec[i] != 0.0f) { // TODO: Consider using a threshold for float comparison
540536
is_zero_vector = false;
541537
break;
@@ -552,9 +548,7 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
552548
return result;
553549
}
554550

555-
template struct FlatVectorIndex<DocId>;
556-
557-
struct HnswlibAdapter {
551+
template <typename T> struct HnswlibAdapter {
558552
// Default setting of hnswlib/hnswalg
559553
constexpr static size_t kDefaultEfRuntime = 10;
560554

@@ -564,34 +558,45 @@ struct HnswlibAdapter {
564558
100 /* seed*/} {
565559
}
566560

567-
void Add(const float* data, DocId id) {
568-
if (world_.cur_element_count + 1 >= world_.max_elements_)
569-
world_.resizeIndex(world_.cur_element_count * 2);
570-
world_.addPoint(data, id);
561+
void Add(const float* data, T id) {
562+
while (true) {
563+
try {
564+
absl::ReaderMutexLock lock(&resize_mutex_);
565+
world_.addPoint(data, id);
566+
return;
567+
} catch (const std::exception& e) {
568+
std::string error_msg = e.what();
569+
if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
570+
ResizeIfFull();
571+
continue;
572+
}
573+
throw e;
574+
}
575+
}
571576
}
572577

573-
void Remove(DocId id) {
578+
// Marks the element labelled `id` as deleted in the index. Any std::exception
// raised by markDelete() is caught and ignored, so removal is best-effort and
// never throws a std::exception to the caller.
void Remove(T id) {
  try {
    world_.markDelete(id);
  } catch (const std::exception&) {
    // Ignored: failures to mark the element deleted are not reported.
  }
}
579584

580-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) {
585+
// Runs an unfiltered k-nearest-neighbour search for `target`.
// `ef` sets the search-time ef parameter; kDefaultEfRuntime is used when it is
// not provided. The resulting queue is converted to a vector by QueueToVec().
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) {
  const size_t ef_runtime = ef.value_or(kDefaultEfRuntime);
  world_.setEf(ef_runtime);

  auto nearest = world_.searchKnn(target, k);
  return QueueToVec(std::move(nearest));
}
584589

585-
vector<pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
586-
const vector<DocId>& allowed) {
590+
vector<pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
591+
const vector<T>& allowed) {
587592
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588593
virtual bool operator()(hnswlib::labeltype id) {
589594
return binary_search(allowed->begin(), allowed->end(), id);
590595
}
591596

592-
BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
597+
BinsearchFilter(const vector<T>* allowed) : allowed{allowed} {
593598
}
594-
const vector<DocId>* allowed;
599+
const vector<T>* allowed;
595600
};
596601

597602
world_.setEf(ef.value_or(kDefaultEfRuntime));
@@ -613,8 +618,30 @@ struct HnswlibAdapter {
613618
return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
614619
}
615620

616-
template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
617-
vector<pair<float, DocId>> out(queue.size());
621+
void ResizeIfFull() {
622+
{
623+
absl::ReaderMutexLock lock(&resize_mutex_);
624+
if (world_.getCurrentElementCount() < world_.getMaxElements() ||
625+
(world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) {
626+
return;
627+
}
628+
}
629+
try {
630+
absl::WriterMutexLock lock(&resize_mutex_);
631+
if (world_.getCurrentElementCount() == world_.getMaxElements() &&
632+
(!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) {
633+
auto max_elements = world_.getMaxElements();
634+
world_.resizeIndex(max_elements * 2);
635+
LOG(INFO) << "Resizing HNSW Index, current size: " << max_elements
636+
<< ", expand by: " << max_elements * 2;
637+
}
638+
} catch (const std::exception& e) {
639+
throw e;
640+
}
641+
}
642+
643+
template <typename Q> static vector<pair<float, GlobalDocId>> QueueToVec(Q queue) {
644+
vector<pair<float, GlobalDocId>> out(queue.size());
618645
size_t idx = out.size();
619646
while (!queue.empty()) {
620647
out[--idx] = queue.top();
@@ -625,45 +652,46 @@ struct HnswlibAdapter {
625652

626653
SpaceUnion space_;
627654
hnswlib::HierarchicalNSW<float> world_;
655+
absl::Mutex resize_mutex_;
628656
};
629657

630-
template <typename T>
631-
HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632-
PMR_NS::memory_resource*)
633-
: BaseVectorIndex<T>{params.dim, params.sim}, adapter_{make_unique<HnswlibAdapter>(params)} {
658+
// Constructs the placeholder by forwarding the vector dimension and similarity
// metric from `params` to the base vector index. No other state is set up.
HnswVectorIndexShardPlaceholder::HnswVectorIndexShardPlaceholder(
    const SchemaField::VectorParams& params)
    : BaseVectorIndex<DocId>{params.dim, params.sim} {
  // Nothing else to initialise.
}
662+
663+
HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
664+
: BaseVectorIndex<GlobalDocId>{params.dim, params.sim},
665+
adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
634666
DCHECK(params.use_hnsw);
635667
// TODO: Patch hnsw to use MR
636668
}
637-
template <typename T> HnswVectorIndex<T>::~HnswVectorIndex() {
669+
670+
// Defined out of line, presumably so that `adapter_` is destroyed where the
// HnswlibAdapter type is complete.
HnswVectorIndex::~HnswVectorIndex() = default;
639672

640-
template <typename T>
641-
void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
673+
// Adds `vector` to the HNSW adapter under `id`. A null `vector` is skipped and
// nothing is inserted.
void HnswVectorIndex::AddVector(GlobalDocId id,
                                const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) {
  if (!vector)
    return;

  adapter_->Add(vector.get(), id);
}
646679

647-
template <typename T>
648-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
649-
std::optional<size_t> ef) const {
680+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(float* target, size_t k,
681+
std::optional<size_t> ef) const {
650682
return adapter_->Knn(target, k, ef);
651683
}
652684

653-
template <typename T>
654-
std::vector<std::pair<float, T>> HnswVectorIndex<T>::Knn(float* target, size_t k,
655-
std::optional<size_t> ef,
656-
const std::vector<T>& allowed) const {
685+
std::vector<std::pair<float, GlobalDocId>> HnswVectorIndex::Knn(
686+
float* target, size_t k, std::optional<size_t> ef,
687+
const std::vector<GlobalDocId>& allowed) const {
657688
return adapter_->Knn(target, k, ef, allowed);
658689
}
659690

660-
template <typename T>
661-
void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
691+
// Removes `id` from the HNSW adapter. The `doc` and `field` arguments are not
// used by this implementation.
void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) {
  auto& hnsw = *adapter_;
  hnsw.Remove(id);
}
664694

665-
template struct HnswVectorIndex<DocId>;
666-
667695
// Constructs the geo index with a freshly allocated, empty R-tree. The memory
// resource argument is not used here.
GeoIndex::GeoIndex(PMR_NS::memory_resource* mr)
    : rtree_(make_unique<rtree>()) {
}
669697

src/core/search/indices.h

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -175,53 +175,70 @@ template <typename T> struct BaseVectorIndex : public BaseIndex<T> {
175175
};
176176

177177
// Index for vector fields.
178-
// Only supports lookup by id.
179-
template <typename T> struct FlatVectorIndex : public BaseVectorIndex<T> {
178+
// Only supports lookup by id.
179+
struct FlatVectorIndex : public BaseVectorIndex<DocId> {
180180
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
181181

182-
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
182+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
183183

184-
const float* Get(T doc) const;
184+
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k) const;
185185

186-
// Return all documents that have vectors in this index
187-
std::vector<T> GetAllDocsWithNonNullValues() const override;
186+
const float* Get(DocId doc) const;
187+
188+
std::vector<DocId> GetAllDocsWithNonNullValues() const override;
188189

189190
protected:
190-
using BaseVectorIndex<T>::dim_;
191-
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
191+
using BaseVectorIndex<DocId>::dim_;
192+
void AddVector(DocId id, const typename BaseVectorIndex<DocId>::VectorPtr& vector) override;
192193

193194
private:
194195
PMR_NS::vector<float> entries_;
195196
};
196197

197-
extern template struct FlatVectorIndex<DocId>;
198+
struct HnswVectorIndexShardPlaceholder : public BaseVectorIndex<DocId> {
199+
explicit HnswVectorIndexShardPlaceholder(const SchemaField::VectorParams& params);
198200

199-
struct HnswlibAdapter;
201+
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override {
202+
// noop
203+
}
200204

201-
template <typename T> struct HnswVectorIndex : public BaseVectorIndex<T> {
205+
// Return all documents that have vectors in this index
206+
std::vector<DocId> GetAllDocsWithNonNullValues() const override {
207+
return {};
208+
}
209+
210+
protected:
211+
using BaseVectorIndex<DocId>::dim_;
212+
void AddVector(DocId id, const typename BaseVectorIndex<DocId>::VectorPtr& vector) override {
213+
// noop
214+
}
215+
};
216+
217+
template <typename T> struct HnswlibAdapter;
218+
struct HnswVectorIndex : public BaseVectorIndex<GlobalDocId> {
202219
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
203220
~HnswVectorIndex();
204221

205-
void Remove(T id, const DocumentAccessor& doc, std::string_view field) override;
222+
void Remove(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) override;
206223

207-
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
208-
std::vector<std::pair<float, T>> Knn(float* target, size_t k, std::optional<size_t> ef,
209-
const std::vector<T>& allowed) const;
224+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k,
225+
std::optional<size_t> ef) const;
226+
std::vector<std::pair<float, GlobalDocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
227+
const std::vector<GlobalDocId>& allowed) const;
210228

211229
// TODO: Implement if needed
212-
std::vector<T> GetAllDocsWithNonNullValues() const override {
213-
return std::vector<T>{};
230+
std::vector<GlobalDocId> GetAllDocsWithNonNullValues() const override {
231+
return std::vector<GlobalDocId>{};
214232
}
215233

216234
protected:
217-
void AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) override;
235+
void AddVector(GlobalDocId id,
236+
const typename BaseVectorIndex<GlobalDocId>::VectorPtr& vector) override;
218237

219238
private:
220-
std::unique_ptr<HnswlibAdapter> adapter_;
239+
std::unique_ptr<HnswlibAdapter<GlobalDocId>> adapter_;
221240
};
222241

223-
extern template struct HnswVectorIndex<DocId>;
224-
225242
struct GeoIndex : public BaseIndex<DocId> {
226243
using point =
227244
boost::geometry::model::point<double, 2,

0 commit comments

Comments
 (0)