diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java index b71591f33..6cabad59e 100644 --- a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java @@ -247,11 +247,8 @@ private double calculateRecall(Set predicted, int[] groundTruth, int k) int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length)); for (int i = 0; i < actualK; i++) { - for (int j = 0; j < actualK; j++) { - if (predicted.contains(groundTruth[j])) { - hits++; - break; - } + if (predicted.contains(groundTruth[i])) { + hits++; } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index b99b71fe4..dfad371d0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -451,7 +451,6 @@ public ImmutableGraphIndex build(RandomAccessVectorValues ravv) { cleanup(); return graph; } - /** * Validates that the current entry node has been completely added. */ @@ -979,7 +978,6 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph * @param buildScoreProvider the provider responsible for calculating build scores. * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start - * @param graphToRavvOrdMap a mapping from the old graph's node ids to the newVectors RAVV node ids * @param beamWidth the width of the beam used during the graph building process. * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node. * @param alpha the weight factor for balancing score computations. @@ -990,15 +988,47 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { */ @Experimental public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, - RandomAccessVectorValues newVectors, + RemappedRandomAccessVectorValues newVectors, BuildScoreProvider buildScoreProvider, int startingNodeOffset, - int[] graphToRavvOrdMap, int beamWidth, float overflowRatio, float alpha, boolean addHierarchy) throws IOException { + return buildAndMergeNewNodes(in, newVectors, buildScoreProvider, startingNodeOffset, beamWidth, overflowRatio, alpha, addHierarchy, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); + } + + /** + * Convenience method to build a new graph from an existing one, with the addition of new nodes. + * This is useful when we want to merge a new set of vectors into an existing graph that is already on disk. + * + * @param in a reader from which to read the on-heap graph. + * @param newVectors a super set RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph + * @param buildScoreProvider the provider responsible for calculating build scores. + * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start + * @param beamWidth the width of the beam used during the graph building process. + * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node. + * @param alpha the weight factor for balancing score computations. + * @param addHierarchy whether to add hierarchical structures while building the graph. + * @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building. + * @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building. + * + * @return the in-memory representation of the graph index. + * @throws IOException if an I/O error occurs during the graph loading or conversion process. + */ + @Experimental + public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, + RemappedRandomAccessVectorValues newVectors, + BuildScoreProvider buildScoreProvider, + int startingNodeOffset, + int beamWidth, + float overflowRatio, + float alpha, + boolean addHierarchy, + ForkJoinPool simdExecutor, + ForkJoinPool parallelExecutor) throws IOException { + var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha); try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, newVectors.dimension(), overflowRatio, diversityProvider);) { @@ -1012,15 +1042,15 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, alpha, addHierarchy, true, - PhysicalCoreExecutor.pool(), - ForkJoinPool.commonPool() + simdExecutor, + parallelExecutor ); var vv = newVectors.threadLocalSupplier(); // parallel graph construction from the merge documents Ids - PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { - builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord])); + simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { + builder.addGraphNode(ord, vv.get().getVector(ord)); })).join(); builder.cleanup(); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java index 65d14f91b..13706a482 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -45,22 +45,25 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.apache.commons.lang3.ArrayUtils.shuffle; import static org.junit.Assert.assertEquals; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class OnHeapGraphIndexTest extends RandomizedTest { private final static Logger log = org.apache.logging.log4j.LogManager.getLogger(OnHeapGraphIndexTest.class); private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); - private static final int NUM_BASE_VECTORS = 100; - private static final int NUM_NEW_VECTORS = 100; + private static final int NUM_BASE_VECTORS = 1000; + private static final int NUM_NEW_VECTORS = 1000; private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS; private static final int DIMENSION = 16; private static final int M = 8; - private static final int BEAM_WIDTH = 100; + private static final int BEAM_WIDTH = 200; private static final float ALPHA = 1.2f; private static final float NEIGHBOR_OVERFLOW = 1.2f; private static final boolean ADD_HIERARCHY = false; private static final int TOP_K = 10; + private static final int NUM_QUERY_VECTORS = 100; + private static VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction.EUCLIDEAN; private Path testDirectory; @@ -70,8 +73,9 @@ public class OnHeapGraphIndexTest extends RandomizedTest { private RandomAccessVectorValues baseVectorsRavv; private RandomAccessVectorValues newVectorsRavv; private RandomAccessVectorValues allVectorsRavv; - private VectorFloat queryVector; - private int[] groundTruthAllVectors; + private ArrayList> queryVectors; + private ArrayList groundTruthBaseVectors; + private ArrayList groundTruthAllVectors; private BuildScoreProvider baseBuildScoreProvider; private BuildScoreProvider allBuildScoreProvider; private ImmutableGraphIndex baseGraphIndex; @@ -99,12 +103,20 @@ public void setup() throws IOException { newVectorsRavv = new ListRandomAccessVectorValues(newVectors, DIMENSION); allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION); - queryVector = createRandomVector(DIMENSION); - groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + // Create multiple query vectors for more stable recall measurements + queryVectors = new ArrayList<>(NUM_QUERY_VECTORS); + groundTruthBaseVectors = new ArrayList<>(NUM_QUERY_VECTORS); + groundTruthAllVectors = new ArrayList<>(NUM_QUERY_VECTORS); + for (int i = 0; i < NUM_QUERY_VECTORS; i++) { + VectorFloat queryVector = createRandomVector(DIMENSION); + queryVectors.add(queryVector); + groundTruthBaseVectors.add(getGroundTruth(baseVectorsRavv, queryVector, TOP_K, SIMILARITY_FUNCTION)); + groundTruthAllVectors.add(getGroundTruth(allVectorsRavv, queryVector, TOP_K, SIMILARITY_FUNCTION)); + } // score provider using the raw, in-memory vectors - baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); - allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, SIMILARITY_FUNCTION); + allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, SIMILARITY_FUNCTION); var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider, baseVectorsRavv.dimension(), M, // graph degree @@ -129,14 +141,39 @@ public void tearDown() { TestUtil.deleteQuietly(testDirectory); } + /** + * Test that we can build a graph with a non-identity mapping from graph node id to ravv ordinal + * and that the recall is the same as the identity mapping (meaning the graphs are equivalent) + * @throws IOException exception + */ + @Test + public void testGraphConstructionWithNonIdentityOrdinalMapping() throws IOException { + // create reversed mapping from graph node id to ravv ordinal + int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray(); + final RemappedRandomAccessVectorValues remappedBaseVectorsRavv = new RemappedRandomAccessVectorValues(baseVectorsRavv, graphToRavvOrdMap); + var bsp = BuildScoreProvider.randomAccessScoreProvider(remappedBaseVectorsRavv, SIMILARITY_FUNCTION); + try (var baseGraphIndexBuilder = new GraphIndexBuilder(bsp, + baseVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy) { + var baseGraphIndexFromShuffledVectors = baseGraphIndexBuilder.build(remappedBaseVectorsRavv)) { + float recallFromBaseGraphIndexFromShuffledVectors = calculateAverageRecall(baseGraphIndexFromShuffledVectors, bsp, queryVectors, groundTruthBaseVectors, TOP_K, graphToRavvOrdMap); + float recallFromBaseGraphIndex = calculateAverageRecall(baseGraphIndex, baseBuildScoreProvider, queryVectors, groundTruthBaseVectors, TOP_K, null); + Assert.assertEquals(recallFromBaseGraphIndex, recallFromBaseGraphIndexFromShuffledVectors, 0.11f); + } + } /** * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex} + * Using identity mapping from graph node id to ravv ordinal * Make sure that both graphs are equivalent * @throws IOException */ @Test - public void testReconstructionOfOnHeapGraphIndex() throws IOException { + public void testReconstructionOfOnHeapGraphIndex_withIdentityOrdinalMapping() throws IOException { var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); @@ -166,12 +203,51 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException { } } + /** + * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex} + * Using random mapping from graph node id to ravv ordinal + * Make sure that both graphs are equivalent + * @throws IOException + */ + @Test + public void testReconstructionOfOnHeapGraphIndex_withNonIdentityOrdinalMapping() throws IOException { + var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); + + // create reversed mapping from graph node id to ravv ordinal + int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray(); + final RemappedRandomAccessVectorValues remmappedRavv = new RemappedRandomAccessVectorValues(baseVectorsRavv, graphToRavvOrdMap); + var bsp = BuildScoreProvider.randomAccessScoreProvider(remmappedRavv, SIMILARITY_FUNCTION); + try (var baseGraphIndexBuilder = new GraphIndexBuilder(bsp, + baseVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy) { + var baseGraphIndex = baseGraphIndexBuilder.build(remmappedRavv)) { + log.info("Writing graph to {}", graphOutputPath); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); + + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndex).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { + MutableGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(bsp, ALPHA)); + TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); + } + } + } + /** * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes} * Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex} */ @Test - public void testIncrementalInsertionFromOnDiskIndex() throws IOException { + public void testIncrementalInsertionFromOnDiskIndex_withIdentityOrdinalMapping() throws IOException { var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); @@ -187,12 +263,65 @@ public void testIncrementalInsertionFromOnDiskIndex() throws IOException { try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { // We will create a trivial 1:1 mapping between the new graph and the ravv final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); - ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(readerSupplier.get(), allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + final RemappedRandomAccessVectorValues remappedAllVectorsRavv = new RemappedRandomAccessVectorValues(allVectorsRavv, graphToRavvOrdMap); + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(readerSupplier.get(), remappedAllVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + + // Verify that the recall is similar across multiple queries + // Note: Incremental insertion can have slightly different recall than bulk indexing due to the order of insertions + float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateAverageRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVectors, groundTruthAllVectors, TOP_K, null); + float recallFromAllGraphIndex = calculateAverageRecall(allGraphIndex, allBuildScoreProvider, queryVectors, groundTruthAllVectors, TOP_K, null); + Assert.assertEquals(String.format("Recall mismatch, recallFromReconstructedAllNodeOnHeapGraphIndex: %f != recallFromAllGraphIndex: %f", recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex), recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.05f); + } + } + + /** + * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes} + * Using non-identity (reversed) mapping from graph node id to ravv ordinal + * Verify that the resulting OnHeapGraphIndex has similar recall to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex} + */ + @Test + public void testIncrementalInsertionFromOnDiskIndex_withNonIdentityOrdinalMapping() throws IOException { + var outputPath = testDirectory.resolve("testIncrementalInsertionFromOnDiskIndex_nonIdentity_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testIncrementalInsertionFromOnDiskIndex_nonIdentity_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); - // Verify that the recall is similar - float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); - float recallFromAllGraphIndex = calculateRecall(allGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); - Assert.assertEquals(recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.01f); + // Create reversed mapping from graph node id to ravv ordinal for base vectors + int[] baseGraphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray(); + final RemappedRandomAccessVectorValues remappedBaseVectorsRavv = new RemappedRandomAccessVectorValues(baseVectorsRavv, baseGraphToRavvOrdMap); + var baseBsp = BuildScoreProvider.randomAccessScoreProvider(remappedBaseVectorsRavv, SIMILARITY_FUNCTION); + + // Build base graph with non-identity mapping + try (var baseGraphIndexBuilder = new GraphIndexBuilder(baseBsp, + baseVectorsRavv.dimension(), + M, + BEAM_WIDTH, + NEIGHBOR_OVERFLOW, + ALPHA, + ADD_HIERARCHY); + var baseGraphIndexWithMapping = baseGraphIndexBuilder.build(remappedBaseVectorsRavv)) { + + log.info("Writing graph to {}", outputPath); + TestUtil.writeGraph(baseGraphIndexWithMapping, baseVectorsRavv, outputPath); + + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndexWithMapping).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { + // Create reversed mapping for all vectors (base + new) + final int[] allGraphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).map(i -> allVectorsRavv.size() - 1 - i).toArray(); + final RemappedRandomAccessVectorValues remappedAllVectorsRavv = new RemappedRandomAccessVectorValues(allVectorsRavv, allGraphToRavvOrdMap); + var allBsp = BuildScoreProvider.randomAccessScoreProvider(remappedAllVectorsRavv, SIMILARITY_FUNCTION); + + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(readerSupplier.get(), remappedAllVectorsRavv, allBsp, NUM_BASE_VECTORS, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + + // Verify that the recall is similar across multiple queries + // Note: Non-identity mapping can have slightly lower recall due to the complexity of merging with remapped ordinals + float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateAverageRecall(reconstructedAllNodeOnHeapGraphIndex, allBsp, queryVectors, groundTruthAllVectors, TOP_K, allGraphToRavvOrdMap); + float recallFromAllGraphIndex = calculateAverageRecall(allGraphIndex, allBuildScoreProvider, queryVectors, groundTruthAllVectors, TOP_K, null); + Assert.assertEquals(String.format("Recall mismatch, recallFromReconstructedAllNodeOnHeapGraphIndex: %f != recallFromAllGraphIndex: %f", recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex), recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.20f); + } } } @@ -229,11 +358,48 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat nodeScore.node).toArray(); } + /** + * Calculate average recall across multiple query vectors for more stable measurements + * @param graphIndex the graph index to search + * @param buildScoreProvider the score provider + * @param queryVectors the list of query vectors + * @param groundTruths the list of ground truth results for each query + * @param k the number of results to consider + * @param graphToRavvOrdMap optional mapping from graph node IDs to RAVV ordinals + * @return the average recall across all queries + */ + private static float calculateAverageRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, + ArrayList> queryVectors, ArrayList groundTruths, + int k, int[] graphToRavvOrdMap) throws IOException { + float totalRecall = 0.0f; + for (int i = 0; i < queryVectors.size(); i++) { + totalRecall += calculateRecall(graphIndex, buildScoreProvider, queryVectors.get(i), groundTruths.get(i), k, graphToRavvOrdMap); + } + return totalRecall / queryVectors.size(); + } + private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat queryVector, int[] groundTruth, int k) throws IOException { + return calculateRecall(graphIndex, buildScoreProvider, queryVector, groundTruth, k, null); + } + + private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat queryVector, int[] groundTruth, int k, int[] graphToRavvOrdMap) throws IOException { try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)){ SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector); var searchResults = graphSearcher.search(ssp, k, Bits.ALL); - var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet()); + Set predicted; + if (graphToRavvOrdMap != null) { + // Convert graph node IDs to RAVV ordinals for comparison with ground truth + predicted = Arrays.stream(searchResults.getNodes()) + .mapToInt(nodeScore -> graphToRavvOrdMap[nodeScore.node]) + .boxed() + .collect(Collectors.toSet()); + } else { + // Identity mapping: graph node IDs == RAVV ordinals + predicted = Arrays.stream(searchResults.getNodes()) + .mapToInt(nodeScore -> nodeScore.node) + .boxed() + .collect(Collectors.toSet()); + } return calculateRecall(predicted, groundTruth, k); } } @@ -249,11 +415,8 @@ private static float calculateRecall(Set predicted, int[] groundTruth, int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length)); for (int i = 0; i < actualK; i++) { - for (int j = 0; j < actualK; j++) { - if (predicted.contains(groundTruth[j])) { - hits++; - break; - } + if (predicted.contains(groundTruth[i])) { + hits++; } }