From 7300b1cb1ab96e5adb18f4d795706a2db43ec3dd Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 9 Jan 2025 17:58:12 -0600 Subject: [PATCH] Use ScoreTracker to avoid wasteful searching for very large k (#384) * clarify * use scoreTracker to short circuit new edge evaluation once we hit a local maximum --- .../jbellis/jvector/graph/GraphSearcher.java | 20 ++++++++----------- .../jbellis/jvector/graph/ScoreTracker.java | 8 +++++--- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 848103b8f..f64514584 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -264,11 +264,8 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr rerankedResults.setMaxSize(topK); int numVisited = initialVisited; - // A bound that holds the minimum similarity to the query vector that a candidate vector must - // have to be considered -- will be set to the lowest score in the results queue once the queue is full. - var minAcceptedSimilarity = Float.NEGATIVE_INFINITY; // track scores to predict when we are done with threshold queries - var scoreTracker = threshold > 0 ? new ScoreTracker.TwoPhaseTracker(threshold) : ScoreTracker.NO_OP; + var scoreTracker = threshold > 0 ? new ScoreTracker.TwoPhaseTracker(threshold) : new ScoreTracker.TwoPhaseTracker(1.0); VectorFloat similarities = null; // add evicted results from the last call back to the candidates @@ -283,11 +280,11 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr while (candidates.size() > 0) { // we're done when we have K results and the best candidate is worse than the worst result so far float topCandidateScore = candidates.topScore(); - if (topCandidateScore < minAcceptedSimilarity) { + if (approximateResults.size() >= rerankK && topCandidateScore < approximateResults.topScore()) { break; } // when querying by threshold, also stop when we are probabilistically unlikely to find more qualifying results - if (scoreTracker.shouldStop()) { + if (threshold > 0 && scoreTracker.shouldStop()) { break; } @@ -295,11 +292,6 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr int topCandidateNode = candidates.pop(); if (acceptOrds.get(topCandidateNode) && topCandidateScore >= threshold) { addTopCandidate(topCandidateNode, topCandidateScore, rerankK); - - // update minAcceptedSimilarity if we've found K results - if (approximateResults.size() >= rerankK) { - minAcceptedSimilarity = approximateResults.topScore(); - } } // if this candidate came from evictedResults, we don't need to evaluate its neighbors again @@ -307,13 +299,17 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr continue; } + // skip edge loading if we've found a local maximum and we have enough results + if (scoreTracker.shouldStop() && candidates.size() >= rerankK - approximateResults.size()) { + continue; + } + // score the neighbors of the top candidate and add them to the queue var scoreFunction = scoreProvider.scoreFunction(); var useEdgeLoading = scoreFunction.supportsEdgeLoadingSimilarity(); if (useEdgeLoading) { similarities = scoreFunction.edgeLoadingSimilarityTo(topCandidateNode); } - var it = view.getNeighborsIterator(topCandidateNode); for (int i = 0; i < it.size(); i++) { var friendOrd = it.nextInt(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java index 3a3a3d549..9aff97721 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java @@ -55,7 +55,7 @@ class TwoPhaseTracker implements ScoreTracker { private int recentEntryIndex; // Heap of the best scores seen so far - AbstractLongHeap bestScores; + BoundedLongHeap bestScores; // observation count private int observationCount; @@ -87,8 +87,10 @@ public boolean shouldStop() { return false; } - // we're in phase 2 if the 99th percentile of the recent scores is worse than the best score - // (paper suggests median, but experimentally that is too prone to false positives. + // We're in phase 2 if the 99th percentile of the recent scores evaluated is lower + // than the worst of the best scores seen. + // + // (paper suggests using the median of recent scores, but experimentally that is too prone to false positives. // 90th does seem to be enough, but 99th doesn't result in much extra work, so we'll be conservative) double windowMedian = StatUtils.percentile(recentScores, 99); double worstBest = sortableIntToFloat((int) bestScores.top());