diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index eb8dfe20..a04b2ed4 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -2,16 +2,21 @@ #include "visited_list_pool.h" #include "hnswlib.h" -#include -#include -#include + #include -#include +#include + +#include +#include #include #include +#include +#include +#include namespace hnswlib { typedef unsigned int tableint; +constexpr tableint kInvalidInternalId = std::numeric_limits::max(); typedef unsigned int linklistsizeint; template @@ -195,6 +200,17 @@ class HierarchicalNSW : public AlgorithmInterface { } + tableint getInternalIdByLabel(labeltype label) const { + std::lock_guard lock_table(label_lookup_lock); + auto label_lookup_result = label_lookup_.find(label); + if (label_lookup_result == label_lookup_.end() || + isMarkedDeleted(label_lookup_result->second)) { + return kInvalidInternalId; + } + return label_lookup_result->second; + } + + inline void setExternalLabel(tableint internal_id, labeltype label) const { memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); } @@ -870,13 +886,10 @@ class HierarchicalNSW : public AlgorithmInterface { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); - std::unique_lock lock_table(label_lookup_lock); - auto search = label_lookup_.find(label); - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + tableint internalId = getInternalIdByLabel(label); + if (internalId == kInvalidInternalId) { return Status("Label not found"); } - tableint internalId = search->second; - lock_table.unlock(); char* data_ptrv = getDataByInternalId(internalId); size_t dim = *((size_t *) dist_func_param_); @@ -1190,7 +1203,7 @@ class HierarchicalNSW : public AlgorithmInterface { } - // This internal function adds a point at a specific level. If level is + // This internal function adds a point at a specific level. StatusOr addPointWithLevel(const void *data_point, labeltype label, int level) { tableint cur_c = 0; {