From 39b297bdfdcd97cf58a232d9f6416f8ea1186fd0 Mon Sep 17 00:00:00 2001 From: Sergey Nikolaev Date: Mon, 12 Jan 2026 18:04:40 +0700 Subject: [PATCH] feat: add on-the-fly filtering support to HNSW search algorithm This commit implements on-the-fly filtering for KNN search, allowing the search algorithm to continue exploring until k filtered candidates are found, rather than finding k total candidates and then filtering them. Key changes: - Added `k` parameter to `searchBaseLayerST()` method to specify the target number of filtered results when filtering is enabled - Added `use_filter` flag to determine when filtering mode is active - Modified termination condition to continue searching until `k` filtered candidates are found when `use_filter` is true, instead of stopping after `ef` total candidates - Updated exploration logic (`should_explore`) to continue exploring neighbors when filtering is enabled and fewer than `k` filtered candidates have been found, even if `ef` candidates have already been explored - Added filtering check (`is_filtered`) before adding candidates to `top_candidates` priority queue, ensuring only filtered candidates are stored - The filter callback (`BaseFilterFunctor* isIdAllowed`) is invoked for each candidate node via `(*isIdAllowed)(label)` to determine if it passes the filter - Updated `searchKnn()` to pass the `k` parameter to `searchBaseLayerST()` when a filter is provided The changes ensure that when filtering is enabled: 1. The search continues until `k` filtered candidates are found (if they exist) 2. Only candidates that pass the filter are added to `top_candidates` 3. The exploration continues even if `ef` candidates have been explored, as long as fewer than `k` filtered candidates have been found This enables more accurate KNN search results when combined with attribute filters, as the algorithm actively searches for filtered candidates rather than relying on post-filtering which may return fewer than `k` results. --- hnswlib/hnswalg.h | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 90735a14..77204700 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -312,17 +312,19 @@ class HierarchicalNSW : public AlgorithmInterface { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr, size_t k = 0) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; + const bool use_filter = isIdAllowed && k > 0; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), (labeltype)-1, getExternalLabel(ep_id), dist_func_param_); + labeltype ep_label = getExternalLabel(ep_id); + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(ep_label))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), (labeltype)-1, ep_label, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); candidate_set.emplace(-dist, ep_id); @@ -336,8 +338,19 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + bool should_break = false; + if (use_filter) { + // Continue until we have k filtered candidates + if (top_candidates.size() >= k && (-current_node_pair.first) > lowerBound) { + should_break = true; + } + } else { + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + should_break = true; + } + } + if (should_break) { break; } candidate_set.pop(); @@ -370,9 +383,12 @@ class HierarchicalNSW : public AlgorithmInterface { visited_array[candidate_id] = visited_array_tag; char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, (labeltype)-1, getExternalLabel(candidate_id), dist_func_param_); + labeltype cand_label = getExternalLabel(candidate_id); + dist_t dist = fstdistfunc_(data_point, currObj1, (labeltype)-1, cand_label, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool should_explore = (use_filter && top_candidates.size() < k) || + (top_candidates.size() < ef || lowerBound > dist); + if (should_explore) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + @@ -380,8 +396,10 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + bool is_filtered = (!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(cand_label)); + if (is_filtered) { top_candidates.emplace(dist, candidate_id); + } if (top_candidates.size() > ef) top_candidates.pop(); @@ -1367,13 +1385,14 @@ class HierarchicalNSW : public AlgorithmInterface { size_t searchEf = std::max(ef_, k); if ( ef ) searchEf = std::max(searchEf, *ef); + std::priority_queue, std::vector>, CompareByFirst> top_candidates; if (num_deleted_) { top_candidates = searchBaseLayerST( - currObj, query_data, searchEf, isIdAllowed); + currObj, query_data, searchEf, isIdAllowed, isIdAllowed ? k : 0); } else { top_candidates = searchBaseLayerST( - currObj, query_data, searchEf, isIdAllowed); + currObj, query_data, searchEf, isIdAllowed, isIdAllowed ? k : 0); } while (top_candidates.size() > k) {