Skip to content

Commit

Permalink
cache reranked scores (#341)
Browse files Browse the repository at this point in the history
* cache reranked scores to avoid redoing expensive work when resuming
* add SearchResult.getRerankedCount
* Merge Reranker into ExactScoreFunction
* remove multiscore methods
  • Loading branch information
jbellis authored Jul 1, 2024
1 parent b5d8247 commit edd396d
Show file tree
Hide file tree
Showing 21 changed files with 138 additions and 398 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ default int getIdUpperBound() {
* except for OnHeapGraphIndex.ConcurrentGraphIndexView.)
*/
interface ScoringView extends View {
ScoreFunction.Reranker rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf);
ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf);
ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.Experimental;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.BoundedLongHeap;
import io.github.jbellis.jvector.util.GrowableLongHeap;
import io.github.jbellis.jvector.util.SparseBits;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.Int2ObjectHashMap;
import org.agrona.collections.IntHashSet;

import java.io.Closeable;
Expand All @@ -56,6 +58,7 @@ public class GraphSearcher implements Closeable {
// Search parameters that we save here for use by resume()
private Bits acceptOrds;
private SearchScoreProvider scoreProvider;
private CachingReranker cachingReranker;

/**
* Creates a new graph searcher from the given GraphIndex
Expand All @@ -73,6 +76,16 @@ private GraphSearcher(GraphIndex.View view) {
this.visited = new IntHashSet();
}

/**
 * Records the score provider for the current search (it is kept for later use by resume())
 * and rebuilds the reranking cache to match it.
 *
 * @param scoreProvider the provider for the search being started; its reranker may be null
 */
private void initializeScoreProvider(SearchScoreProvider scoreProvider) {
    this.scoreProvider = scoreProvider;
    // A new CachingReranker is built on every call, so scores cached for a previous provider are dropped.
    // With no reranker there is nothing to cache, so the field is reset to null.
    boolean hasReranker = scoreProvider.reranker() != null;
    this.cachingReranker = hasReranker ? new CachingReranker(scoreProvider) : null;
}

/**
 * @return the GraphIndex.View held by this searcher
 */
public GraphIndex.View getView() {
return view;
}
Expand Down Expand Up @@ -120,10 +133,11 @@ public GraphSearcher build() {
* If threshold > 0 then the search will stop when it is probabilistically unlikely
* to find more nodes above the threshold, even if `topK` results have not yet been found.
* @param rerankFloor (Experimental!) Candidates whose approximate similarity is at least this value
* will not be reranked with the exact score (which requires loading the raw vector)
* will be reranked with the exact score (which requires loading a high-res vector from disk)
* and included in the final results. (Potentially leaving fewer than topK entries
* in the results.) Other candidates will be discarded. This is intended for use
* when combining results from multiple indexes.
* in the results.) Other candidates will be discarded, but will be potentially
* resurfaced if `resume` is called. This is intended for use when combining results
* from multiple indexes.
* @param acceptOrds a Bits instance indicating which nodes are acceptable results.
* If {@link Bits#ALL}, all nodes are acceptable.
* It is caller's responsibility to ensure that there are enough acceptable nodes
Expand Down Expand Up @@ -198,7 +212,7 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider,
}

// save search parameters for potential later resume
this.scoreProvider = scoreProvider;
initializeScoreProvider(scoreProvider);
this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes());

// reset the scratch data structures
Expand All @@ -208,7 +222,7 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider,

// no entry point -> empty results
if (ep < 0) {
return new SearchResult(new SearchResult.NodeScore[0], 0, Float.POSITIVE_INFINITY);
return new SearchResult(new SearchResult.NodeScore[0], 0, 0, Float.POSITIVE_INFINITY);
}

// kick off the actual search at the entry point
Expand Down Expand Up @@ -337,18 +351,22 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr
assert approximateResults.size() <= rerankK;
NodeQueue popFromQueue;
float worstApproximateInTopK;
if (scoreProvider.reranker() == null) {
int reranked;
if (cachingReranker == null) {
// save the worst candidates in evictedResults for potential resume()
while (approximateResults.size() > topK) {
var nScore = approximateResults.topScore();
var n = approximateResults.pop();
evictedResults.add(n, nScore);
}

reranked = 0;
worstApproximateInTopK = Float.POSITIVE_INFINITY;
popFromQueue = approximateResults;
} else {
worstApproximateInTopK = approximateResults.rerank(topK, scoreProvider.reranker(), rerankFloor, rerankedResults, evictedResults);
int oldReranked = cachingReranker.getRerankCalls();
worstApproximateInTopK = approximateResults.rerank(topK, cachingReranker, rerankFloor, rerankedResults, evictedResults);
reranked = cachingReranker.getRerankCalls() - oldReranked;
approximateResults.clear();
popFromQueue = rerankedResults;
}
Expand All @@ -363,7 +381,7 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr
// that should be everything
assert popFromQueue.size() == 0;

return new SearchResult(nodes, numVisited, worstApproximateInTopK);
return new SearchResult(nodes, numVisited, reranked, worstApproximateInTopK);
} catch (Throwable t) {
// clear scratch structures if terminated via throwable, as they may not have been drained
approximateResults.clear();
Expand All @@ -390,4 +408,33 @@ public SearchResult resume(int additionalK, int rerankK) {
public void close() throws IOException {
// closing the searcher closes the view it holds; any IOException from the view propagates to the caller
view.close();
}

/**
 * ExactScoreFunction wrapper that remembers the exact score computed for each node, so that
 * asking for the same node again (e.g. when a search is resumed) does not repeat the
 * expensive call to the underlying reranker. Also counts how many scores were actually
 * computed rather than served from the cache.
 * <p>
 * Uses a plain map and a plain int counter with no synchronization.
 */
private static class CachingReranker implements ScoreFunction.ExactScoreFunction {
    // this cache never gets cleared out (until a new search reinitializes it),
    // but we expect resume() to be called at most a few times so it's fine
    private final Int2ObjectHashMap<Float> cachedScores;
    private final SearchScoreProvider scoreProvider;
    // number of cache misses, i.e. calls forwarded to the underlying reranker
    private int rerankCalls;

    /**
     * @param scoreProvider provider whose reranker() supplies the exact scores to cache
     */
    public CachingReranker(SearchScoreProvider scoreProvider) {
        this.scoreProvider = scoreProvider;
        cachedScores = new Int2ObjectHashMap<>();
        rerankCalls = 0;
    }

    /**
     * @param node2 the node to score
     * @return the exact score for node2, computed on first request and cached afterwards
     */
    @Override
    public float similarityTo(int node2) {
        // single lookup: null values are never stored, so null means "not cached yet"
        Float cached = cachedScores.get(node2);
        if (cached != null) {
            return cached;
        }
        rerankCalls++;
        float score = scoreProvider.reranker().similarityTo(node2);
        cachedScores.put(node2, Float.valueOf(score));
        return score;
    }

    /**
     * @return how many scores have been computed by the underlying reranker so far
     * (cache hits are not counted)
     */
    public int getRerankCalls() {
        return rerankCalls;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public int[] nodesCopy() {
* <p>
* Only the best result or results whose approximate score is at least `rerankFloor` will be reranked.
*/
public float rerank(int topK, ScoreFunction.Reranker reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) {
public float rerank(int topK, ScoreFunction.ExactScoreFunction reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) {
// Rescore the nodes whose approximate score meets the floor. Nodes that do not will be marked as -1
int[] ids = new int[size()];
float[] exactScores = new float[size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@

package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.util.function.Supplier;
Expand Down Expand Up @@ -105,4 +107,19 @@ default Supplier<RandomAccessVectorValues> threadLocalSupplier() {
var tl = ExplicitThreadLocal.withInitial(this::copy);
return tl::get;
}

/**
 * Convenience factory for an ExactScoreFunction used for reranking: each node is scored by
 * loading its vector through getVectorInto and comparing it to queryVector with vsf.
 * All calls to the returned function write into one shared scratch vector, so the function
 * is NOT thread-safe.
 */
default ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
    // allocated once per reranker and overwritten for every node that gets scored
    final VectorFloat<?> buffer = vts.createFloatVector(dimension());
    return ordinal -> {
        getVectorInto(ordinal, buffer, 0);
        return vsf.compare(queryVector, buffer);
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
public final class SearchResult {
private final NodeScore[] nodes;
private final int visitedCount;
private final int rerankedCount;
private final float worstApproximateScoreInTopK;

public SearchResult(NodeScore[] nodes, int visitedCount, float worstApproximateScoreInTopK) {
public SearchResult(NodeScore[] nodes, int visitedCount, int rerankedCount, float worstApproximateScoreInTopK) {
this.nodes = nodes;
this.visitedCount = visitedCount;
this.rerankedCount = rerankedCount;
this.worstApproximateScoreInTopK = worstApproximateScoreInTopK;
}

Expand All @@ -44,6 +46,13 @@ public int getVisitedCount() {
return visitedCount;
}

/**
 * @return the number of nodes that were reranked during the search. As populated by
 * GraphSearcher, this is the number of exact scores computed while producing this result:
 * scores served from its reranking cache are not counted, and the value is 0 when no
 * reranker is in use.
 */
public int getRerankedCount() {
return rerankedCount;
}

/**
* @return the worst approximate score of the top K nodes in the search result. Useful
* for passing to rerankFloor during search across multiple indexes. Will be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void close() throws IOException {
}

@Override
public ScoreFunction.Reranker rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
return view.rerankerFor(queryVector, vsf);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,17 +260,12 @@ public void close() throws IOException {
reader.close();
}

public ScoreFunction.Reranker rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, Set<FeatureId> permissibleFeatures) {
if (permissibleFeatures.contains(FeatureId.INLINE_VECTORS) && features.containsKey(FeatureId.INLINE_VECTORS)) {
return ScoreFunction.Reranker.from(queryVector, vsf, this);
} else {
throw new UnsupportedOperationException("No reranker available for this graph");
}
}

@Override
public ScoreFunction.Reranker rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
return rerankerFor(queryVector, vsf, FeatureId.ALL);
public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
if (!features.containsKey(FeatureId.INLINE_VECTORS)) {
throw new UnsupportedOperationException("No inline vectors in this graph");
}
return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,6 @@ default boolean isExact() {
}
}

/**
 * An ExactScoreFunction that can additionally score a whole batch of nodes in one call,
 * intended for reranking results after an approximate-scored search.
 */
interface Reranker extends ExactScoreFunction {
    /**
     * Scores every node in `nodes` against the query.
     *
     * @return a vector of size `nodes.length` whose i-th entry is the score of `nodes[i]`
     */
    VectorFloat<?> similarityTo(int[] nodes);

    /**
     * Builds a Reranker that reads vectors from `vp` and scores them against `queryVector` with `vsf`.
     */
    static Reranker from(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, RandomAccessVectorValues vp) {
        return new Reranker() {
            @Override
            public VectorFloat<?> similarityTo(int[] nodes) {
                var scores = vts.createFloatVector(nodes.length);
                var dim = queryVector.length();
                // lay the candidate vectors out back to back so they can be scored in a single call
                var packed = vts.createFloatVector(nodes.length * dim);
                var offset = 0;
                for (int node : nodes) {
                    vp.getVectorInto(node, packed, offset);
                    offset += dim;
                }
                vsf.compareMulti(queryVector, packed, scores);
                return scores;
            }

            @Override
            public float similarityTo(int node2) {
                var candidate = vp.getVector(node2);
                return vsf.compare(queryVector, candidate);
            }
        };
    }
}

interface ApproximateScoreFunction extends ScoreFunction {
default boolean isExact() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
/** Encapsulates comparing node distances to a specific vector for GraphSearcher. */
public final class SearchScoreProvider {
private final ScoreFunction scoreFunction;
private final ScoreFunction.Reranker reranker;
private final ScoreFunction.ExactScoreFunction reranker;

/**
* @param scoreFunction the primary, fast scoring function
Expand All @@ -40,10 +40,10 @@ public SearchScoreProvider(ScoreFunction scoreFunction) {
* Generally, reranker will be null iff scoreFunction is an ExactScoreFunction. However,
* it is allowed, and sometimes useful, to only perform approximate scoring without reranking.
* <p>
* Most often it will be convenient to get the Reranker either using `Reranker.from`
* Most often it will be convenient to get the reranker either using `RandomAccessVectorValues.rerankerFor`
* or `ScoringView.rerankerFor`.
*/
public SearchScoreProvider(ScoreFunction scoreFunction, ScoreFunction.Reranker reranker) {
public SearchScoreProvider(ScoreFunction scoreFunction, ScoreFunction.ExactScoreFunction reranker) {
assert scoreFunction != null;
this.scoreFunction = scoreFunction;
this.reranker = reranker;
Expand All @@ -53,7 +53,7 @@ public ScoreFunction scoreFunction() {
return scoreFunction;
}

public ScoreFunction.Reranker reranker() {
public ScoreFunction.ExactScoreFunction reranker() {
return reranker;
}

Expand All @@ -69,7 +69,13 @@ public ScoreFunction.ExactScoreFunction exactScoreFunction() {
* e.g. during construction.
*/
public static SearchScoreProvider exact(VectorFloat<?> v, VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) {
var sf = ScoreFunction.Reranker.from(v, vsf, ravv);
// don't use RandomAccessVectorValues.rerankerFor here: it reuses one scratch vector and is NOT thread-safe
var sf = new ScoreFunction.ExactScoreFunction() {
@Override
public float similarityTo(int node2) {
return vsf.compare(v, ravv.getVector(node2));
}
};
return new SearchScoreProvider(sf);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@ public enum VectorSimilarityFunction {
public float compare(VectorFloat<?> v1, VectorFloat<?> v2) {
// 1 / (1 + squared L2 distance): exactly 1 when the distance is 0 and shrinking toward 0
// as the distance grows, so a higher score means a closer vector
return 1 / (1 + VectorUtil.squareL2Distance(v1, v2));
}

@Override
public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, VectorFloat<?> results) {
VectorUtil.euclideanMultiScore(v1, packedVectors, results);
}
},

/**
Expand All @@ -58,11 +53,6 @@ public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, Vector
public float compare(VectorFloat<?> v1, VectorFloat<?> v2) {
// shifts and halves the dot product: (1 + dot) / 2
// NOTE(review): the result lands in [0, 1] only if both vectors are unit length — presumably
// a caller requirement; confirm against the enum constant's documentation
return (1 + VectorUtil.dotProduct(v1, v2)) / 2;
}

@Override
public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, VectorFloat<?> results) {
VectorUtil.dotProductMultiScore(v1, packedVectors, results);
}
},

/**
Expand All @@ -76,11 +66,6 @@ public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, Vector
public float compare(VectorFloat<?> v1, VectorFloat<?> v2) {
// shifts and halves the cosine: (1 + cos) / 2; since a cosine lies in [-1, 1],
// this rescales the score to [0, 1] with higher meaning more similar
return (1 + VectorUtil.cosine(v1, v2)) / 2;
}

@Override
public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, VectorFloat<?> results) {
VectorUtil.cosineMultiScore(v1, packedVectors, results);
}
};

/**
Expand All @@ -92,13 +77,4 @@ public void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, Vector
* @return the value of the similarity function applied to the two vectors
*/
public abstract float compare(VectorFloat<?> v1, VectorFloat<?> v2);

/**
* Calculates similarity scores between a query vector and multiple vectors with a specified function. Higher
* similarity scores correspond to closer vectors.
*
* @param v1 a vector
* @param packedVectors N vectors packed into a single vector, of N * v1.length() dimension
*/
public abstract void compareMulti(VectorFloat<?> v1, VectorFloat<?> packedVectors, VectorFloat<?> results);
}
Loading

0 comments on commit edd396d

Please sign in to comment.