From edd396d2ffd32135ae8ff97b24db7b3a08dc94fd Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Mon, 1 Jul 2024 09:45:14 -0500 Subject: [PATCH] cache reranked scores (#341) * cache reranked scores to avoid redoing expensive work when resuming * add SearchResult.getRerankedCount * Merge Reranker into ExactScoreFunction * remove multiscore methods --- .../jbellis/jvector/graph/GraphIndex.java | 2 +- .../jbellis/jvector/graph/GraphSearcher.java | 63 +++++++++-- .../jbellis/jvector/graph/NodeQueue.java | 2 +- .../graph/RandomAccessVectorValues.java | 17 +++ .../jbellis/jvector/graph/SearchResult.java | 11 +- .../jvector/graph/disk/CachingGraphIndex.java | 2 +- .../jvector/graph/disk/OnDiskGraphIndex.java | 15 +-- .../graph/similarity/ScoreFunction.java | 35 ------ .../graph/similarity/SearchScoreProvider.java | 16 ++- .../vector/VectorSimilarityFunction.java | 24 ----- .../jbellis/jvector/vector/VectorUtil.java | 33 ------ .../jvector/vector/VectorUtilSupport.java | 17 --- .../jbellis/jvector/example/IPCService.java | 2 +- .../jbellis/jvector/example/SiftSmall.java | 6 +- jvector-native/src/main/c/jvector_simd.c | 84 --------------- jvector-native/src/main/c/jvector_simd.h | 2 - .../vector/NativeVectorUtilSupport.java | 10 -- .../jvector/vector/cnative/NativeSimdOps.java | 100 ------------------ .../jvector/graph/Test2DThreshold.java | 2 +- .../jvector/graph/TestVectorGraph.java | 34 +++++- .../graph/TestVectorSimilarityFunction.java | 59 ----------- 21 files changed, 138 insertions(+), 398 deletions(-) delete mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorSimilarityFunction.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java index 7774a8ce7..3462163f1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java @@ -130,7 +130,7 @@ default int getIdUpperBound() { * except for OnHeapGraphIndex.ConcurrentGraphIndexView.) */ interface ScoringView extends View { - ScoreFunction.Reranker rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf); + ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf); ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 3bb1fce08..3435fd5c2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -25,6 +25,7 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.annotations.Experimental; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.BoundedLongHeap; @@ -32,6 +33,7 @@ import io.github.jbellis.jvector.util.SparseBits; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; +import org.agrona.collections.Int2ObjectHashMap; import org.agrona.collections.IntHashSet; import java.io.Closeable; @@ -56,6 +58,7 @@ public class GraphSearcher implements Closeable { // Search parameters that we save here for use by resume() private Bits acceptOrds; private SearchScoreProvider scoreProvider; + private CachingReranker cachingReranker; /** * Creates a new graph searcher from the given GraphIndex @@ -73,6 +76,16 @@ private GraphSearcher(GraphIndex.View view) { this.visited = new IntHashSet(); } + private void initializeScoreProvider(SearchScoreProvider scoreProvider) { + this.scoreProvider = scoreProvider; + if (scoreProvider.reranker() == null) { + cachingReranker = null; + return; + } + + cachingReranker = new CachingReranker(scoreProvider); + } + public GraphIndex.View getView() { return view; } @@ -120,10 +133,11 @@ public GraphSearcher build() { * If threshold > 0 then the search will stop when it is probabilistically unlikely * to find more nodes above the threshold, even if `topK` results have not yet been found. * @param rerankFloor (Experimental!) Candidates whose approximate similarity is at least this value - * will not be reranked with the exact score (which requires loading the raw vector) + * will be reranked with the exact score (which requires loading a high-res vector from disk) * and included in the final results. (Potentially leaving fewer than topK entries - * in the results.) Other candidates will be discarded. This is intended for use - * when combining results from multiple indexes. + * in the results.) Other candidates will be discarded, but will be potentially + * resurfaced if `resume` is called. This is intended for use when combining results + * from multiple indexes. * @param acceptOrds a Bits instance indicating which nodes are acceptable results. * If {@link Bits#ALL}, all nodes are acceptable. * It is caller's responsibility to ensure that there are enough acceptable nodes @@ -198,7 +212,7 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider, } // save search parameters for potential later resume - this.scoreProvider = scoreProvider; + initializeScoreProvider(scoreProvider); this.acceptOrds = Bits.intersectionOf(rawAcceptOrds, view.liveNodes()); // reset the scratch data structures @@ -208,7 +222,7 @@ SearchResult searchInternal(SearchScoreProvider scoreProvider, // no entry point -> empty results if (ep < 0) { - return new SearchResult(new SearchResult.NodeScore[0], 0, Float.POSITIVE_INFINITY); + return new SearchResult(new SearchResult.NodeScore[0], 0, 0, Float.POSITIVE_INFINITY); } // kick off the actual search at the entry point @@ -337,7 +351,8 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr assert approximateResults.size() <= rerankK; NodeQueue popFromQueue; float worstApproximateInTopK; - if (scoreProvider.reranker() == null) { + int reranked; + if (cachingReranker == null) { // save the worst candidates in evictedResults for potential resume() while (approximateResults.size() > topK) { var nScore = approximateResults.topScore(); @@ -345,10 +360,13 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr evictedResults.add(n, nScore); } + reranked = 0; worstApproximateInTopK = Float.POSITIVE_INFINITY; popFromQueue = approximateResults; } else { - worstApproximateInTopK = approximateResults.rerank(topK, scoreProvider.reranker(), rerankFloor, rerankedResults, evictedResults); + int oldReranked = cachingReranker.getRerankCalls(); + worstApproximateInTopK = approximateResults.rerank(topK, cachingReranker, rerankFloor, rerankedResults, evictedResults); + reranked = cachingReranker.getRerankCalls() - oldReranked; approximateResults.clear(); popFromQueue = rerankedResults; } @@ -363,7 +381,7 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr // that should be everything assert popFromQueue.size() == 0; - return new SearchResult(nodes, numVisited, worstApproximateInTopK); + return new SearchResult(nodes, numVisited, reranked, worstApproximateInTopK); } catch (Throwable t) { // clear scratch structures if terminated via throwable, as they may not have been drained approximateResults.clear(); @@ -390,4 +408,33 @@ public SearchResult resume(int additionalK, int rerankK) { public void close() throws IOException { view.close(); } + + private static class CachingReranker implements ScoreFunction.ExactScoreFunction { + // this cache never gets cleared out (until a new search reinitializes it), + // but we expect resume() to be called at most a few times so it's fine + private final Int2ObjectHashMap cachedScores; + private final SearchScoreProvider scoreProvider; + private int rerankCalls; + + public CachingReranker(SearchScoreProvider scoreProvider) { + this.scoreProvider = scoreProvider; + cachedScores = new Int2ObjectHashMap<>(); + rerankCalls = 0; + } + + @Override + public float similarityTo(int node2) { + if (cachedScores.containsKey(node2)) { + return cachedScores.get(node2); + } + rerankCalls++; + float score = scoreProvider.reranker().similarityTo(node2); + cachedScores.put(node2, Float.valueOf(score)); + return score; + } + + public int getRerankCalls() { + return rerankCalls; + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 2eed24473..0bf3cd602 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -144,7 +144,7 @@ public int[] nodesCopy() { *

* Only the best result or results whose approximate score is at least `rerankFloor` will be reranked. */ - public float rerank(int topK, ScoreFunction.Reranker reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) { + public float rerank(int topK, ScoreFunction.ExactScoreFunction reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) { // Rescore the nodes whose approximate score meets the floor. Nodes that do not will be marked as -1 int[] ids = new int[size()]; float[] exactScores = new float[size()]; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java index ea7fb7069..eb8f6df24 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java @@ -24,7 +24,9 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.ExplicitThreadLocal; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import java.util.function.Supplier; @@ -105,4 +107,19 @@ default Supplier threadLocalSupplier() { var tl = ExplicitThreadLocal.withInitial(this::copy); return tl::get; } + + /** + * Convenience method to create an ExactScoreFunction for reranking. The resulting function is NOT thread-safe. + */ + default ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { + return new ScoreFunction.ExactScoreFunction() { + private final VectorFloat scratch = vts.createFloatVector(dimension()); + + @Override + public float similarityTo(int node2) { + getVectorInto(node2, scratch, 0); + return vsf.compare(queryVector, scratch); + } + }; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java index 711ca4ea9..9505463fd 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/SearchResult.java @@ -22,11 +22,13 @@ public final class SearchResult { private final NodeScore[] nodes; private final int visitedCount; + private final int rerankedCount; private final float worstApproximateScoreInTopK; - public SearchResult(NodeScore[] nodes, int visitedCount, float worstApproximateScoreInTopK) { + public SearchResult(NodeScore[] nodes, int visitedCount, int rerankedCount, float worstApproximateScoreInTopK) { this.nodes = nodes; this.visitedCount = visitedCount; + this.rerankedCount = rerankedCount; this.worstApproximateScoreInTopK = worstApproximateScoreInTopK; } @@ -44,6 +46,13 @@ public int getVisitedCount() { return visitedCount; } + /** + * @return the number of nodes that were reranked during the search + */ + public int getRerankedCount() { + return rerankedCount; + } + /** * @return the worst approximate score of the top K nodes in the search result. Useful * for passing to rerankFloor during search across multiple indexes. Will be diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java index 4184669e4..59d551213 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/CachingGraphIndex.java @@ -118,7 +118,7 @@ public void close() throws IOException { } @Override - public ScoreFunction.Reranker rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { + public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { return view.rerankerFor(queryVector, vsf); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index a791d69eb..219813102 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -260,17 +260,12 @@ public void close() throws IOException { reader.close(); } - public ScoreFunction.Reranker rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf, Set permissibleFeatures) { - if (permissibleFeatures.contains(FeatureId.INLINE_VECTORS) && features.containsKey(FeatureId.INLINE_VECTORS)) { - return ScoreFunction.Reranker.from(queryVector, vsf, this); - } else { - throw new UnsupportedOperationException("No reranker available for this graph"); - } - } - @Override - public ScoreFunction.Reranker rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { - return rerankerFor(queryVector, vsf, FeatureId.ALL); + public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { + if (!features.containsKey(FeatureId.INLINE_VECTORS)) { + throw new UnsupportedOperationException("No inline vectors in this graph"); + } + return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf); } @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java index fe5ba2337..2dad2d689 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java @@ -64,41 +64,6 @@ default boolean isExact() { } } - /** - * An ExactScoreFunction with an optimized batch `similarityTo` method for reranking - * results after an approximate-scored search. - */ - interface Reranker extends ExactScoreFunction { - /** - * @return a vector of size `nodes.length` containing the corresponding score for - * each of the nodes in `nodes`. Used when reranking search results. - */ - VectorFloat similarityTo(int[] nodes); - - static Reranker from(VectorFloat queryVector, VectorSimilarityFunction vsf, RandomAccessVectorValues vp) { - return new Reranker() { - @Override - public VectorFloat similarityTo(int[] nodes) { - var results = vts.createFloatVector(nodes.length); - var nodeCount = nodes.length; - var dimension = queryVector.length(); - var packedVectors = vts.createFloatVector(nodeCount * dimension); - for (int i1 = 0; i1 < nodeCount; i1++) { - var node = nodes[i1]; - vp.getVectorInto(node, packedVectors, i1 * dimension); - } - vsf.compareMulti(queryVector, packedVectors, results); - return results; - } - - @Override - public float similarityTo(int node2) { - return vsf.compare(queryVector, vp.getVector(node2)); - } - }; - } - } - interface ApproximateScoreFunction extends ScoreFunction { default boolean isExact() { return false; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java index 4a30bdfbd..4d6d22d5c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/SearchScoreProvider.java @@ -23,7 +23,7 @@ /** Encapsulates comparing node distances to a specific vector for GraphSearcher. */ public final class SearchScoreProvider { private final ScoreFunction scoreFunction; - private final ScoreFunction.Reranker reranker; + private final ScoreFunction.ExactScoreFunction reranker; /** * @param scoreFunction the primary, fast scoring function @@ -40,10 +40,10 @@ public SearchScoreProvider(ScoreFunction scoreFunction) { * Generally, reranker will be null iff scoreFunction is an ExactScoreFunction. However, * it is allowed, and sometimes useful, to only perform approximate scoring without reranking. *

- * Most often it will be convenient to get the Reranker either using `Reranker.from` + * Most often it will be convenient to get the reranker either using `ExactScoreFunction.from` * or `ScoringView.rerankerFor`. */ - public SearchScoreProvider(ScoreFunction scoreFunction, ScoreFunction.Reranker reranker) { + public SearchScoreProvider(ScoreFunction scoreFunction, ScoreFunction.ExactScoreFunction reranker) { assert scoreFunction != null; this.scoreFunction = scoreFunction; this.reranker = reranker; @@ -53,7 +53,7 @@ public ScoreFunction scoreFunction() { return scoreFunction; } - public ScoreFunction.Reranker reranker() { + public ScoreFunction.ExactScoreFunction reranker() { return reranker; } @@ -69,7 +69,13 @@ public ScoreFunction.ExactScoreFunction exactScoreFunction() { * e.g. during construction. */ public static SearchScoreProvider exact(VectorFloat v, VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) { - var sf = ScoreFunction.Reranker.from(v, vsf, ravv); + // don't use ESF.reranker, we need thread safety here + var sf = new ScoreFunction.ExactScoreFunction() { + @Override + public float similarityTo(int node2) { + return vsf.compare(v, ravv.getVector(node2)); + } + }; return new SearchScoreProvider(sf); } } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorSimilarityFunction.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorSimilarityFunction.java index 3993ee059..bdef5aefb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorSimilarityFunction.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorSimilarityFunction.java @@ -39,11 +39,6 @@ public enum VectorSimilarityFunction { public float compare(VectorFloat v1, VectorFloat v2) { return 1 / (1 + VectorUtil.squareL2Distance(v1, v2)); } - - @Override - public void compareMulti(VectorFloat v1, VectorFloat packedVectors, VectorFloat results) { - VectorUtil.euclideanMultiScore(v1, packedVectors, results); - } }, /** @@ -58,11 +53,6 @@ public void compareMulti(VectorFloat v1, VectorFloat packedVectors, Vector public float compare(VectorFloat v1, VectorFloat v2) { return (1 + VectorUtil.dotProduct(v1, v2)) / 2; } - - @Override - public void compareMulti(VectorFloat v1, VectorFloat packedVectors, VectorFloat results) { - VectorUtil.dotProductMultiScore(v1, packedVectors, results); - } }, /** @@ -76,11 +66,6 @@ public void compareMulti(VectorFloat v1, VectorFloat packedVectors, Vector public float compare(VectorFloat v1, VectorFloat v2) { return (1 + VectorUtil.cosine(v1, v2)) / 2; } - - @Override - public void compareMulti(VectorFloat v1, VectorFloat packedVectors, VectorFloat results) { - VectorUtil.cosineMultiScore(v1, packedVectors, results); - } }; /** @@ -92,13 +77,4 @@ public void compareMulti(VectorFloat v1, VectorFloat packedVectors, Vector * @return the value of the similarity function applied to the two vectors */ public abstract float compare(VectorFloat v1, VectorFloat v2); - - /** - * Calculates similarity scores between a query vector and multiple vectors with a specified function. Higher - * similarity scores correspond to closer vectors. - * - * @param v1 a vector - * @param packedVectors N vectors packed into a single vector, of N * v1.length() dimension - */ - public abstract void compareMulti(VectorFloat v1, VectorFloat packedVectors, VectorFloat results); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 97e2888c2..a6d87807e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -177,39 +177,6 @@ public static void quantizePartials(float delta, VectorFloat partials, Vector impl.quantizePartials(delta, partials, partialBase, quantizedPartials); } - /** - * Calculates the dot product similarity scores between v1 and multiple vectors packed into v2. - * Note that unlike the dotProduct, this method puts similarity scores into results, taking this responsibility from VectorSimilarityFunction. - * @param v1 the query vector - * @param v2 multiple vectors to compare against - * @param results the output vector to store the similarity scores. This should be pre-allocated to the same size as the number of vectors in v2. - */ - public static void dotProductMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - impl.dotProductMultiScore(v1, v2, results); - } - - /** - * Calculates the Euclidean similarity scores between v1 and multiple vectors packed into v2. - * Note that unlike the squareDistance, this method puts similarity scores into results, taking this responsibility from VectorSimilarityFunction. - * @param v1 the query vector - * @param v2 multiple vectors to compare against - * @param results the output vector to store the similarity scores. This should be pre-allocated to the same size as the number of vectors in v2. - */ - public static void euclideanMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - impl.squareL2DistanceMultiScore(v1, v2, results); - } - - /** - * Calculates the cosine similarity scores between v1 and multiple vectors packed into v2. - * Note that unlike the cosine, this method puts similarity scores into results, taking this responsibility from VectorSimilarityFunction. - * @param v1 the query vector - * @param v2 multiple vectors to compare against - * @param results the output vector to store the similarity scores. This should be pre-allocated to the same size as the number of vectors in v2. - */ - public static void cosineMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - impl.cosineMultiScore(v1, v2, results); - } - /** * Calculates the maximum value in the vector. * @param v vector diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index caba2b97e..46cb4f18a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -199,21 +199,4 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); - default void dotProductMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - for (int i = 0; i < results.length(); i++) { - results.set(i, (1 + dotProduct(v1, 0, v2, i * v1.length(), v1.length())) / 2); - } - } - - default void squareL2DistanceMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - for (int i = 0; i < results.length(); i++) { - results.set(i, 1 / (1 + squareDistance(v1, 0, v2, i * v1.length(), v1.length()))); - } - } - - default void cosineMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - for (int i = 0; i < results.length(); i++) { - results.set(i, (1 + cosine(v1, 0, v2, i * v1.length(), v1.length())) / 2); - } - } } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java index 424d2edca..3944b9633 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java @@ -271,7 +271,7 @@ String search(String input, SessionContext ctx) { try (var view = ctx.index.getView()) { var rr = view instanceof GraphIndex.ScoringView ? ((GraphIndex.ScoringView) view).rerankerFor(queryVector, ctx.similarityFunction) - : ScoreFunction.Reranker.from(queryVector, ctx.similarityFunction, ctx.ravv); + : ctx.ravv.rerankerFor(queryVector, ctx.similarityFunction); var ssp = new SearchScoreProvider(sf, rr); r = new GraphSearcher(ctx.index).search(ssp, searchEf, Bits.ALL); } catch (Exception e) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index 0fec3420c..55e74beef 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -35,7 +35,7 @@ import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction.ApproximateScoreFunction; -import io.github.jbellis.jvector.graph.similarity.ScoreFunction.Reranker; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction.ExactScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.pq.ProductQuantization; @@ -197,7 +197,7 @@ public static void siftDiskAnn(List> baseVectors, List, SearchScoreProvider> sspFactory = q -> { ApproximateScoreFunction asf = pqv.precomputedScoreFunctionFor(q, VectorSimilarityFunction.EUCLIDEAN); - Reranker reranker = index.getView().rerankerFor(q, VectorSimilarityFunction.EUCLIDEAN); + ExactScoreFunction reranker = index.getView().rerankerFor(q, VectorSimilarityFunction.EUCLIDEAN); return new SearchScoreProvider(asf, reranker); }; // measure our recall against the (exactly computed) ground truth @@ -258,7 +258,7 @@ public static void siftDiskAnnLTM(List> baseVectors, List, SearchScoreProvider> sspFactory = q -> { ApproximateScoreFunction asf = pqvSearch.precomputedScoreFunctionFor(q, VectorSimilarityFunction.EUCLIDEAN); - Reranker reranker = index.getView().rerankerFor(q, VectorSimilarityFunction.EUCLIDEAN); + ExactScoreFunction reranker = index.getView().rerankerFor(q, VectorSimilarityFunction.EUCLIDEAN); return new SearchScoreProvider(asf, reranker); }; testRecall(index, queryVectors, groundTruth, sspFactory); diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 24a024afc..886186fab 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -333,90 +333,6 @@ void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codeboo } } -void dot_product_multi_f32_512(const float* v1, const float* packedv2, int v1Length, int resultsLength, float* results) { - int ao = 0; - int simd_length = v1Length - (v1Length % 16); - - - if (v1Length >= 16) { - __m512 sums[resultsLength]; // Array of sums for each subvector in c - for (int k = 0; k < resultsLength; ++k) { - sums[k] = _mm512_setzero_ps(); - } - - for (; ao < simd_length; ao += 16) { - __m512 va = _mm512_loadu_ps(v1 + ao); - - for (int k = 0; k < resultsLength; ++k) { - // Load float32 from the k-th subvector of c - __m512 vc = _mm512_loadu_ps(packedv2 + ao + (k * v1Length)); - // Multiply and accumulate for the k-th subvector - sums[k] = _mm512_fmadd_ps(va, vc, sums[k]); - } - } - - // Horizontal sum of the vectors to get K dot products - for (int k = 0; k < resultsLength; ++k) { - results[k] = _mm512_reduce_add_ps(sums[k]); - } - } - - // Scalar computation for remaining elements - for (; ao < v1Length; ao++) { - for (int k = 0; k < resultsLength; ++k) { - results[k] += v1[ao] * packedv2[ao + (k * v1Length)]; - } - } - - // convert to scores - for (int k = 0; k < resultsLength; ++k) { - results[k] = (1.0f + results[k] ) / 2; - } -} - -void square_distance_multi_f32_512(const float* v1, const float* packedv2, int v1Length, int resultsLength, float* results) { - int ao = 0; - int simd_length = v1Length - (v1Length % 16); - - - if (v1Length >= 16) { - __m512 sums[resultsLength]; // Array of sums for each subvector in c - for (int k = 0; k < resultsLength; ++k) { - sums[k] = _mm512_setzero_ps(); - } - - for (; ao < simd_length; ao += 16) { - __m512 va = _mm512_loadu_ps(v1 + ao); - - for (int k = 0; k < resultsLength; ++k) { - // Load float32 from the k-th subvector of c - __m512 vc = _mm512_loadu_ps(packedv2 + ao + (k * v1Length)); - // Multiply and accumulate for the k-th subvector - __m512 diff = _mm512_sub_ps(va, vc); - sums[k] = _mm512_fmadd_ps(diff, diff, sums[k]); - } - } - - // Horizontal sum of the vectors to get K dot products - for (int k = 0; k < resultsLength; ++k) { - results[k] = _mm512_reduce_add_ps(sums[k]); - } - } - - // Scalar computation for remaining elements - for (; ao < v1Length; ao++) { - for (int k = 0; k < resultsLength; ++k) { - float diff = v1[ao] - packedv2[ao + (k * v1Length)]; - results[k] += diff * diff; - } - } - - // convert to scores - for (int k = 0; k < resultsLength; ++k) { - results[k] = 1.0f / (1 + results[k]); - } -} - /* Bulk shuffles for Fused ADC * These shuffles take an array of transposed PQ neighbors (in shuffles) and an of quantized partial distances to shuffle. * Partial distance quantization depends on the best distance and delta used to quantize. diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index 493aa184c..a5410ef5f 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -33,6 +33,4 @@ void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -void dot_product_multi_f32_512(const float* v1, const float* packedv2, int v1Length, int resultsLength, float* results); -void square_distance_multi_f32_512(const float* v1, const float* packedv2, int v1Length, int resultsLength, float* results); #endif diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 1dd729e01..3fe12e719 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -133,16 +133,6 @@ public void calculatePartialSums(VectorFloat codebook, int codebookBase, int } } - @Override - public void dotProductMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - NativeSimdOps.dot_product_multi_f32_512(((MemorySegmentVectorFloat)v1).get(), ((MemorySegmentVectorFloat)v2).get(), v1.length(), results.length(), ((MemorySegmentVectorFloat)results).get()); - } - - @Override - public void squareL2DistanceMultiScore(VectorFloat v1, VectorFloat v2, VectorFloat results) { - NativeSimdOps.square_distance_multi_f32_512(((MemorySegmentVectorFloat)v1).get(), ((MemorySegmentVectorFloat)v2).get(), v1.length(), results.length(), ((MemorySegmentVectorFloat)results).get()); - } - @Override public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) { VectorSimdOps.quantizePartials(delta, (MemorySegmentVectorFloat) partials, (MemorySegmentVectorFloat) partialBases, (MemorySegmentByteSequence) quantizedPartials); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 52d82f380..e148b1be9 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -661,105 +661,5 @@ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment c throw new AssertionError("should not reach here", ex$); } } - - private static class dot_product_multi_f32_512 { - public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER - ); - - public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( - NativeSimdOps.findOrThrow("dot_product_multi_f32_512"), - DESC, Linker.Option.critical(true)); - } - - /** - * Function descriptor for: - * {@snippet lang=c : - * void dot_product_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static FunctionDescriptor dot_product_multi_f32_512$descriptor() { - return dot_product_multi_f32_512.DESC; - } - - /** - * Downcall method handle for: - * {@snippet lang=c : - * void dot_product_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static MethodHandle dot_product_multi_f32_512$handle() { - return dot_product_multi_f32_512.HANDLE; - } - /** - * {@snippet lang=c : - * void dot_product_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static void dot_product_multi_f32_512(MemorySegment v1, MemorySegment packedv2, int v1Length, int resultsLength, MemorySegment results) { - var mh$ = dot_product_multi_f32_512.HANDLE; - try { - if (TRACE_DOWNCALLS) { - traceDowncall("dot_product_multi_f32_512", v1, packedv2, v1Length, resultsLength, results); - } - mh$.invokeExact(v1, packedv2, v1Length, resultsLength, results); - } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); - } - } - - private static class square_distance_multi_f32_512 { - public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER - ); - - public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( - NativeSimdOps.findOrThrow("square_distance_multi_f32_512"), - DESC, Linker.Option.critical(true)); - } - - /** - * Function descriptor for: - * {@snippet lang=c : - * void square_distance_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static FunctionDescriptor square_distance_multi_f32_512$descriptor() { - return square_distance_multi_f32_512.DESC; - } - - /** - * Downcall method handle for: - * {@snippet lang=c : - * void square_distance_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static MethodHandle square_distance_multi_f32_512$handle() { - return square_distance_multi_f32_512.HANDLE; - } - /** - * {@snippet lang=c : - * void square_distance_multi_f32_512(const float *v1, const float *packedv2, int v1Length, int resultsLength, float *results) - * } - */ - public static void square_distance_multi_f32_512(MemorySegment v1, MemorySegment packedv2, int v1Length, int resultsLength, MemorySegment results) { - var mh$ = square_distance_multi_f32_512.HANDLE; - try { - if (TRACE_DOWNCALLS) { - traceDowncall("square_distance_multi_f32_512", v1, packedv2, v1Length, resultsLength, results); - } - mh$.invokeExact(v1, packedv2, v1Length, resultsLength, results); - } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); - } - } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java index f9fa7abbe..807ef2937 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java @@ -60,7 +60,7 @@ public void testThreshold(int graphSize, int maxDegree) throws IOException { for (int i = 0; i < 10; i++) { TestParams tp = createTestParams(vectors); - var sf = ScoreFunction.Reranker.from(tp.q, VectorSimilarityFunction.EUCLIDEAN, ravv); + var sf = ravv.rerankerFor(tp.q, VectorSimilarityFunction.EUCLIDEAN); var result = searcher.search(new SearchScoreProvider(sf), vectors.length, tp.th, Bits.ALL); assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index 4b361044f..bc72dcf51 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -30,6 +30,8 @@ import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.pq.PQVectors; +import io.github.jbellis.jvector.pq.ProductQuantization; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.BoundedLongHeap; import io.github.jbellis.jvector.util.FixedBitSet; @@ -136,7 +138,7 @@ public void testResume() { var query = randomVector(dim); var searcher = new GraphSearcher(graph); - var ssp = new SearchScoreProvider(ScoreFunction.Reranker.from(query, similarityFunction, vectors)); + var ssp = new SearchScoreProvider(vectors.rerankerFor(query, similarityFunction)); var initial = searcher.search(ssp, initialTopK, acceptOrds); assertEquals(initialTopK, initial.getNodes().length); @@ -155,6 +157,34 @@ public void testResume() { } } + @Test + // resuming a search should not need to rerank the nodes that were already evaluated + public void testRerankCaching() { + int size = 1000; + int dim = 2; + var vectors = vectorValues(size, dim); + var builder = new GraphIndexBuilder(vectors, similarityFunction, 20, 30, 1.0f, 1.4f); + var graph = builder.build(vectors); + + var pq = ProductQuantization.compute(vectors, 2, 256, false); + var encoded = pq.encodeAll(vectors); + var pqv = new PQVectors(pq, encoded); + + int topK = 10; + int rerankK = 30; + var query = randomVector(dim); + var searcher = new GraphSearcher(graph); + + var ssp = new SearchScoreProvider(pqv.scoreFunctionFor(query, similarityFunction), + vectors.rerankerFor(query, similarityFunction)); + var initial = searcher.search(ssp, topK, rerankK, 0.0f, 0.0f, Bits.ALL); + assertEquals(topK, initial.getNodes().length); + assertEquals(rerankK, initial.getRerankedCount()); + + var resumed = searcher.resume(topK, rerankK); + assert resumed.getRerankedCount() < rerankK; + } + // If an exception is thrown during search, the next search should still function @Test public void testExceptionalTermination() { @@ -199,7 +229,7 @@ public int size() { }; var searcher = new GraphSearcher(graph); - var ssp = new SearchScoreProvider(ScoreFunction.Reranker.from(getTargetVector(), similarityFunction, wrappedVectors)); + var ssp = new SearchScoreProvider(wrappedVectors.rerankerFor(getTargetVector(), similarityFunction)); assertThrows(RuntimeException.class, () -> { searcher.search(ssp, 10, Bits.ALL); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorSimilarityFunction.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorSimilarityFunction.java deleted file mode 100644 index 9ce0765a3..000000000 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorSimilarityFunction.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.github.jbellis.jvector.graph; - -import com.carrotsearch.randomizedtesting.RandomizedTest; -import io.github.jbellis.jvector.TestUtil; -import io.github.jbellis.jvector.vector.VectorSimilarityFunction; -import io.github.jbellis.jvector.vector.VectorizationProvider; -import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import org.junit.Assert; -import org.junit.Test; - -import java.util.Collections; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -public class TestVectorSimilarityFunction extends RandomizedTest { - private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); - @Test - public void testCompareMulti() { - var random = getRandom(); - var dimension = random.nextInt(1535) + 1; - var vectors = TestUtil.createRandomVectors(100, dimension); - var q = TestUtil.randomVector(random, dimension); - var length = random.nextInt(100) + 1; - var indexes = IntStream.range(0, 100).boxed().collect(Collectors.toList()); - Collections.shuffle(indexes, random); - var ids = indexes.subList(0, length).toArray(new Integer[0]); - var results = vectorTypeSupport.createFloatVector(length); - var packedVectors = vectorTypeSupport.createFloatVector(length * dimension); - - for (int i = 0; i < length; i++) { - var v = vectors.get(ids[i]); - packedVectors.copyFrom(v, 0, i * dimension, dimension); - } - - for (VectorSimilarityFunction vsf : VectorSimilarityFunction.values()) { - results.zero(); - vsf.compareMulti(q, packedVectors, results); - for (int i = 0; i < length; i++) { - Assert.assertEquals(vsf.compare(q, vectors.get(ids[i])), results.get(i), 0.01f); - } - } - } -}