@@ -550,6 +550,10 @@ std::vector<DocId> FlatVectorIndex::GetAllDocsWithNonNullValues() const {
550550 return result;
551551}
552552
553+ ShardHnswVectorIndex::ShardHnswVectorIndex (const SchemaField::VectorParams& params)
554+ : BaseVectorIndex{params.dim , params.sim } {
555+ }
556+
553557struct HnswlibAdapter {
554558 // Default setting of hnswlib/hnswalg
555559 constexpr static size_t kDefaultEfRuntime = 10 ;
@@ -560,34 +564,45 @@ struct HnswlibAdapter {
560564 100 /* seed*/ } {
561565 }
562566
563- void Add (const float * data, DocId id) {
564- if (world_.cur_element_count + 1 >= world_.max_elements_ )
565- world_.resizeIndex (world_.cur_element_count * 2 );
566- world_.addPoint (data, id);
567+ void Add (const float * data, GlobalDocId id) {
568+ while (true ) {
569+ try {
570+ absl::ReaderMutexLock lock (&resize_mutex_);
571+ world_.addPoint (data, id);
572+ return ;
573+ } catch (const std::exception& e) {
574+ std::string error_msg = e.what ();
575+ if (absl::StrContains (error_msg, " The number of elements exceeds the specified limit" )) {
576+ ResizeIfFull ();
577+ continue ;
578+ }
579+ throw e;
580+ }
581+ }
567582 }
568583
569- void Remove (DocId id) {
584+ void Remove (GlobalDocId id) {
570585 try {
571586 world_.markDelete (id);
572587 } catch (const std::exception& e) {
573588 }
574589 }
575590
576- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef) {
591+ vector<pair<float , GlobalDocId >> Knn (float * target, size_t k, std::optional<size_t > ef) {
577592 world_.setEf (ef.value_or (kDefaultEfRuntime ));
578593 return QueueToVec (world_.searchKnn (target, k));
579594 }
580595
581- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef,
582- const vector<DocId >& allowed) {
596+ vector<pair<float , GlobalDocId >> Knn (float * target, size_t k, std::optional<size_t > ef,
597+ const vector<GlobalDocId >& allowed) {
583598 struct BinsearchFilter : hnswlib::BaseFilterFunctor {
584599 virtual bool operator ()(hnswlib::labeltype id) {
585600 return binary_search (allowed->begin (), allowed->end (), id);
586601 }
587602
588- BinsearchFilter (const vector<DocId >* allowed) : allowed{allowed} {
603+ BinsearchFilter (const vector<GlobalDocId >* allowed) : allowed{allowed} {
589604 }
590- const vector<DocId >* allowed;
605+ const vector<GlobalDocId >* allowed;
591606 };
592607
593608 world_.setEf (ef.value_or (kDefaultEfRuntime ));
@@ -609,8 +624,32 @@ struct HnswlibAdapter {
609624 return visit ([](auto & space) -> hnswlib::SpaceInterface<float >* { return &space; }, space_);
610625 }
611626
612- template <typename Q> static vector<pair<float , DocId>> QueueToVec (Q queue) {
613- vector<pair<float , DocId>> out (queue.size ());
627+ // Function requires that we hold mutex while resizing index. resizeIndex is not thread safe with
628+ // insertion (https://github.com/nmslib/hnswlib/issues/267)
629+ void ResizeIfFull () {
630+ {
631+ absl::ReaderMutexLock lock (&resize_mutex_);
632+ if (world_.getCurrentElementCount () < world_.getMaxElements () ||
633+ (world_.allow_replace_deleted_ && world_.getDeletedCount () > 0 )) {
634+ return ;
635+ }
636+ }
637+ try {
638+ absl::WriterMutexLock lock (&resize_mutex_);
639+ if (world_.getCurrentElementCount () == world_.getMaxElements () &&
640+ (!world_.allow_replace_deleted_ || world_.getDeletedCount () == 0 )) {
641+ auto max_elements = world_.getMaxElements ();
642+ world_.resizeIndex (max_elements * 2 );
643+ LOG (INFO) << " Resizing HNSW Index, current size: " << max_elements
644+ << " , expand by: " << max_elements * 2 ;
645+ }
646+ } catch (const std::exception& e) {
647+ throw e;
648+ }
649+ }
650+
651+ template <typename Q> static vector<pair<float , GlobalDocId>> QueueToVec (Q queue) {
652+ vector<pair<float , GlobalDocId>> out (queue.size ());
614653 size_t idx = out.size ();
615654 while (!queue.empty ()) {
616655 out[--idx] = queue.top ();
@@ -621,34 +660,50 @@ struct HnswlibAdapter {
621660
622661 SpaceUnion space_;
623662 hnswlib::HierarchicalNSW<float > world_;
663+ absl::Mutex resize_mutex_;
624664};
625665
626666HnswVectorIndex::HnswVectorIndex (const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
627- : BaseVectorIndex {params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
667+ : dim_ {params.dim }, sim_{ params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
628668 DCHECK (params.use_hnsw );
629669 // TODO: Patch hnsw to use MR
630670}
631671
632672HnswVectorIndex::~HnswVectorIndex () {
633673}
634674
635- void HnswVectorIndex::AddVector (DocId id, const VectorPtr& vector) {
636- if (vector) {
637- adapter_->Add (vector.get (), id);
675+ bool HnswVectorIndex::Add (GlobalDocId id, const DocumentAccessor& doc, std::string_view field) {
676+ auto vector = doc.GetVector (field);
677+
678+ if (!vector) {
679+ return false ;
680+ }
681+
682+ auto & [ptr, size] = vector.value ();
683+
684+ if (ptr && size != dim_) {
685+ return false ;
686+ }
687+
688+ if (ptr) {
689+ adapter_->Add (ptr.get (), id);
638690 }
691+
692+ return true ;
639693}
640694
641- std::vector<std::pair<float , DocId >> HnswVectorIndex::Knn (float * target, size_t k,
642- std::optional<size_t > ef) const {
695+ std::vector<std::pair<float , GlobalDocId >> HnswVectorIndex::Knn (float * target, size_t k,
696+ std::optional<size_t > ef) const {
643697 return adapter_->Knn (target, k, ef);
644698}
645- std::vector<std::pair<float , DocId>> HnswVectorIndex::Knn (float * target, size_t k,
646- std::optional<size_t > ef,
647- const std::vector<DocId>& allowed) const {
699+
700+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (
701+ float * target, size_t k, std::optional<size_t > ef,
702+ const std::vector<GlobalDocId>& allowed) const {
648703 return adapter_->Knn (target, k, ef, allowed);
649704}
650705
651- void HnswVectorIndex::Remove (DocId id, const DocumentAccessor& doc, string_view field) {
706+ void HnswVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
652707 adapter_->Remove (id);
653708}
654709
0 commit comments