diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 90735a14..77204700 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -312,17 +312,19 @@ class HierarchicalNSW : public AlgorithmInterface { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr, size_t k = 0) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; + const bool use_filter = isIdAllowed && k > 0; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), (labeltype)-1, getExternalLabel(ep_id), dist_func_param_); + labeltype ep_label = getExternalLabel(ep_id); + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(ep_label))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), (labeltype)-1, ep_label, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); candidate_set.emplace(-dist, ep_id); @@ -336,8 +338,19 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + bool should_break = false; + if (use_filter) { + // Continue until we have k filtered candidates + if (top_candidates.size() >= k && (-current_node_pair.first) > lowerBound) { + should_break = true; + } + } else { + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + should_break = true; + } + } + if (should_break) { break; } candidate_set.pop(); @@ -370,9 +383,12 @@ class HierarchicalNSW : public AlgorithmInterface { visited_array[candidate_id] = visited_array_tag; char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, (labeltype)-1, getExternalLabel(candidate_id), dist_func_param_); + labeltype cand_label = getExternalLabel(candidate_id); + dist_t dist = fstdistfunc_(data_point, currObj1, (labeltype)-1, cand_label, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool should_explore = (use_filter && top_candidates.size() < k) || + (top_candidates.size() < ef || lowerBound > dist); + if (should_explore) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + @@ -380,8 +396,10 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + bool is_filtered = (!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(cand_label)); + if (is_filtered) { top_candidates.emplace(dist, candidate_id); + } if (top_candidates.size() > ef) top_candidates.pop(); @@ -1367,13 +1385,14 @@ class HierarchicalNSW : public AlgorithmInterface { size_t searchEf = std::max(ef_, k); if ( ef ) searchEf = std::max(searchEf, *ef); + std::priority_queue, std::vector>, CompareByFirst> top_candidates; if (num_deleted_) { top_candidates = searchBaseLayerST( - currObj, query_data, searchEf, isIdAllowed); + currObj, query_data, searchEf, isIdAllowed, isIdAllowed ? k : 0); } else { top_candidates = searchBaseLayerST( - currObj, query_data, searchEf, isIdAllowed); + currObj, query_data, searchEf, isIdAllowed, isIdAllowed ? k : 0); } while (top_candidates.size() > k) {