@@ -492,40 +492,36 @@ bool BaseVectorIndex<T>::Add(T id, const DocumentAccessor& doc, std::string_view
492492 return true ;
493493}
494494
495- template <typename T>
496- FlatVectorIndex<T>::FlatVectorIndex(const SchemaField::VectorParams& params,
497- PMR_NS::memory_resource* mr)
498- : BaseVectorIndex<T>{params.dim , params.sim }, entries_{mr} {
495+ FlatVectorIndex::FlatVectorIndex (const SchemaField::VectorParams& params,
496+ PMR_NS::memory_resource* mr)
497+ : BaseVectorIndex<DocId>{params.dim , params.sim }, entries_{mr} {
499498 DCHECK (!params.use_hnsw );
500499 entries_.reserve (params.capacity * params.dim );
501500}
502501
503- template <typename T>
504- void FlatVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T>::VectorPtr& vector) {
505- DCHECK_LE (id * BaseVectorIndex<T>::dim_, entries_.size ());
506- if (id * BaseVectorIndex<T>::dim_ == entries_.size ())
507- entries_.resize ((id + 1 ) * BaseVectorIndex<T>::dim_);
502+ void FlatVectorIndex::AddVector (DocId id, const typename BaseVectorIndex::VectorPtr& vector) {
503+ DCHECK_LE (id * dim_, entries_.size ());
504+ if (id * dim_ == entries_.size ())
505+ entries_.resize ((id + 1 ) * dim_);
508506
509507 // TODO: Let get vector write to buf itself
510508 if (vector) {
511- memcpy (&entries_[id * BaseVectorIndex<T>::dim_], vector.get (),
512- BaseVectorIndex<T>::dim_ * sizeof (float ));
509+ memcpy (&entries_[id * dim_], vector.get (), dim_ * sizeof (float ));
513510 }
514511}
515512
516- template <typename T>
517- void FlatVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
513+ void FlatVectorIndex::Remove (DocId id, const DocumentAccessor& doc, string_view field) {
518514 // noop
519515}
520516
521- template < typename T> const float * FlatVectorIndex<T> ::Get(T doc) const {
517+ const float * FlatVectorIndex::Get (DocId doc) const {
522518 return &entries_[doc * dim_];
523519}
524520
525- template < typename T> std::vector<T > FlatVectorIndex<T> ::GetAllDocsWithNonNullValues() const {
521+ std::vector<DocId > FlatVectorIndex::GetAllDocsWithNonNullValues () const {
526522 std::vector<DocId> result;
527523
528- size_t num_vectors = entries_.size () / BaseVectorIndex<T>:: dim_;
524+ size_t num_vectors = entries_.size () / dim_;
529525 result.reserve (num_vectors);
530526
531527 for (DocId id = 0 ; id < num_vectors; ++id) {
@@ -535,7 +531,7 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
535531 bool is_zero_vector = true ;
536532
537533 // TODO: Consider don't use check for zero vector
538- for (size_t i = 0 ; i < BaseVectorIndex<T>:: dim_; ++i) {
534+ for (size_t i = 0 ; i < dim_; ++i) {
539535 if (vec[i] != 0 .0f ) { // TODO: Consider using a threshold for float comparison
540536 is_zero_vector = false ;
541537 break ;
@@ -552,9 +548,7 @@ template <typename T> std::vector<T> FlatVectorIndex<T>::GetAllDocsWithNonNullVa
552548 return result;
553549}
554550
555- template struct FlatVectorIndex <DocId>;
556-
557- struct HnswlibAdapter {
551+ template <typename T> struct HnswlibAdapter {
558552 // Default setting of hnswlib/hnswalg
559553 constexpr static size_t kDefaultEfRuntime = 10 ;
560554
@@ -564,34 +558,45 @@ struct HnswlibAdapter {
564558 100 /* seed*/ } {
565559 }
566560
567- void Add (const float * data, DocId id) {
568- if (world_.cur_element_count + 1 >= world_.max_elements_ )
569- world_.resizeIndex (world_.cur_element_count * 2 );
570- world_.addPoint (data, id);
561+ void Add (const float * data, T id) {
562+ while (true ) {
563+ try {
564+ absl::ReaderMutexLock lock (&resize_mutex_);
565+ world_.addPoint (data, id);
566+ return ;
567+ } catch (const std::exception& e) {
568+ std::string error_msg = e.what ();
569+ if (absl::StrContains (error_msg, " The number of elements exceeds the specified limit" )) {
570+ ResizeIfFull ();
571+ continue ;
572+ }
573+ throw e;
574+ }
575+ }
571576 }
572577
573- void Remove (DocId id) {
578+ void Remove (T id) {
574579 try {
575580 world_.markDelete (id);
576581 } catch (const std::exception& e) {
577582 }
578583 }
579584
580- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef) {
585+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef) {
581586 world_.setEf (ef.value_or (kDefaultEfRuntime ));
582587 return QueueToVec (world_.searchKnn (target, k));
583588 }
584589
585- vector<pair<float , DocId >> Knn (float * target, size_t k, std::optional<size_t > ef,
586- const vector<DocId >& allowed) {
590+ vector<pair<float , T >> Knn (float * target, size_t k, std::optional<size_t > ef,
591+ const vector<T >& allowed) {
587592 struct BinsearchFilter : hnswlib::BaseFilterFunctor {
588593 virtual bool operator ()(hnswlib::labeltype id) {
589594 return binary_search (allowed->begin (), allowed->end (), id);
590595 }
591596
592- BinsearchFilter (const vector<DocId >* allowed) : allowed{allowed} {
597+ BinsearchFilter (const vector<T >* allowed) : allowed{allowed} {
593598 }
594- const vector<DocId >* allowed;
599+ const vector<T >* allowed;
595600 };
596601
597602 world_.setEf (ef.value_or (kDefaultEfRuntime ));
@@ -613,8 +618,30 @@ struct HnswlibAdapter {
613618 return visit ([](auto & space) -> hnswlib::SpaceInterface<float >* { return &space; }, space_);
614619 }
615620
616- template <typename Q> static vector<pair<float , DocId>> QueueToVec (Q queue) {
617- vector<pair<float , DocId>> out (queue.size ());
621+ void ResizeIfFull () {
622+ {
623+ absl::ReaderMutexLock lock (&resize_mutex_);
624+ if (world_.getCurrentElementCount () < world_.getMaxElements () ||
625+ (world_.allow_replace_deleted_ && world_.getDeletedCount () > 0 )) {
626+ return ;
627+ }
628+ }
629+ try {
630+ absl::WriterMutexLock lock (&resize_mutex_);
631+ if (world_.getCurrentElementCount () == world_.getMaxElements () &&
632+ (!world_.allow_replace_deleted_ || world_.getDeletedCount () == 0 )) {
633+ auto max_elements = world_.getMaxElements ();
634+ world_.resizeIndex (max_elements * 2 );
635+ LOG (INFO) << " Resizing HNSW Index, current size: " << max_elements
636+ << " , expand by: " << max_elements * 2 ;
637+ }
638+ } catch (const std::exception& e) {
639+ throw e;
640+ }
641+ }
642+
643+ template <typename Q> static vector<pair<float , GlobalDocId>> QueueToVec (Q queue) {
644+ vector<pair<float , GlobalDocId>> out (queue.size ());
618645 size_t idx = out.size ();
619646 while (!queue.empty ()) {
620647 out[--idx] = queue.top ();
@@ -625,45 +652,46 @@ struct HnswlibAdapter {
625652
626653 SpaceUnion space_;
627654 hnswlib::HierarchicalNSW<float > world_;
655+ absl::Mutex resize_mutex_;
628656};
629657
630- template <typename T>
631- HnswVectorIndex<T>::HnswVectorIndex(const SchemaField::VectorParams& params,
632- PMR_NS::memory_resource*)
633- : BaseVectorIndex<T>{params.dim , params.sim }, adapter_{make_unique<HnswlibAdapter>(params)} {
658+ HnswVectorIndexShardPlaceholder::HnswVectorIndexShardPlaceholder (
659+ const SchemaField::VectorParams& params)
660+ : BaseVectorIndex<DocId>{params.dim , params.sim } {
661+ }
662+
663+ HnswVectorIndex::HnswVectorIndex (const SchemaField::VectorParams& params, PMR_NS::memory_resource*)
664+ : BaseVectorIndex<GlobalDocId>{params.dim , params.sim },
665+ adapter_{make_unique<HnswlibAdapter<GlobalDocId>>(params)} {
634666 DCHECK (params.use_hnsw );
635667 // TODO: Patch hnsw to use MR
636668}
637- template <typename T> HnswVectorIndex<T>::~HnswVectorIndex () {
669+
670+ HnswVectorIndex::~HnswVectorIndex () {
638671}
639672
640- template < typename T>
641- void HnswVectorIndex<T>::AddVector(T id, const typename BaseVectorIndex<T >::VectorPtr& vector) {
673+ void HnswVectorIndex::AddVector (GlobalDocId id,
674+ const typename BaseVectorIndex<GlobalDocId >::VectorPtr& vector) {
642675 if (vector) {
643676 adapter_->Add (vector.get (), id);
644677 }
645678}
646679
647- template <typename T>
648- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
649- std::optional<size_t > ef) const {
680+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (float * target, size_t k,
681+ std::optional<size_t > ef) const {
650682 return adapter_->Knn (target, k, ef);
651683}
652684
653- template <typename T>
654- std::vector<std::pair<float , T>> HnswVectorIndex<T>::Knn(float * target, size_t k,
655- std::optional<size_t > ef,
656- const std::vector<T>& allowed) const {
685+ std::vector<std::pair<float , GlobalDocId>> HnswVectorIndex::Knn (
686+ float * target, size_t k, std::optional<size_t > ef,
687+ const std::vector<GlobalDocId>& allowed) const {
657688 return adapter_->Knn (target, k, ef, allowed);
658689}
659690
660- template <typename T>
661- void HnswVectorIndex<T>::Remove(T id, const DocumentAccessor& doc, string_view field) {
691+ void HnswVectorIndex::Remove (GlobalDocId id, const DocumentAccessor& doc, string_view field) {
662692 adapter_->Remove (id);
663693}
664694
665- template struct HnswVectorIndex <DocId>;
666-
667695GeoIndex::GeoIndex (PMR_NS::memory_resource* mr) : rtree_(make_unique<rtree>()) {
668696}
669697
0 commit comments