4545import java .util .stream .Collectors ;
4646import java .util .stream .IntStream ;
4747
48+ import static org .apache .commons .lang3 .ArrayUtils .shuffle ;
4849import static org .junit .Assert .assertEquals ;
4950
5051@ ThreadLeakScope (ThreadLeakScope .Scope .NONE )
@@ -61,6 +62,7 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
6162 private static final float NEIGHBOR_OVERFLOW = 1.2f ;
6263 private static final boolean ADD_HIERARCHY = false ;
6364 private static final int TOP_K = 10 ;
65+ private static VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction .EUCLIDEAN ;
6466
6567 private Path testDirectory ;
6668
@@ -71,6 +73,7 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
7173 private RandomAccessVectorValues newVectorsRavv ;
7274 private RandomAccessVectorValues allVectorsRavv ;
7375 private VectorFloat <?> queryVector ;
76+ private int [] groundTruthBaseVectors ;
7477 private int [] groundTruthAllVectors ;
7578 private BuildScoreProvider baseBuildScoreProvider ;
7679 private BuildScoreProvider allBuildScoreProvider ;
@@ -100,11 +103,12 @@ public void setup() throws IOException {
100103 allVectorsRavv = new ListRandomAccessVectorValues (allVectors , DIMENSION );
101104
102105 queryVector = createRandomVector (DIMENSION );
103- groundTruthAllVectors = getGroundTruth (allVectorsRavv , queryVector , TOP_K , VectorSimilarityFunction .EUCLIDEAN );
106+ groundTruthBaseVectors = getGroundTruth (baseVectorsRavv , queryVector , TOP_K , SIMILARITY_FUNCTION );
107+ groundTruthAllVectors = getGroundTruth (allVectorsRavv , queryVector , TOP_K , SIMILARITY_FUNCTION );
104108
105109 // score provider using the raw, in-memory vectors
106- baseBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (baseVectorsRavv , VectorSimilarityFunction . EUCLIDEAN );
107- allBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (allVectorsRavv , VectorSimilarityFunction . EUCLIDEAN );
110+ baseBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (baseVectorsRavv , SIMILARITY_FUNCTION );
111+ allBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (allVectorsRavv , SIMILARITY_FUNCTION );
108112 var baseGraphIndexBuilder = new GraphIndexBuilder (baseBuildScoreProvider ,
109113 baseVectorsRavv .dimension (),
110114 M , // graph degree
@@ -129,14 +133,38 @@ public void tearDown() {
129133 TestUtil .deleteQuietly (testDirectory );
130134 }
131135
136+ /**
137+ * Test that we can build a graph with a non-identity mapping from graph node id to ravv ordinal
138+ * and that the recall is the same as the identity mapping (meaning the graphs are equivalent)
139+ * @throws IOException exception
140+ */
141+ @ Test
142+ public void testGraphConstructionWithNonIdentityOrdinalMapping () throws IOException {
143+ // create reversed mapping from graph node id to ravv ordinal
144+ int [] graphToRavvOrdMap = IntStream .range (0 , baseVectorsRavv .size ()).map (i -> baseVectorsRavv .size () - 1 - i ).toArray ();
145+ var bsp = BuildScoreProvider .randomAccessScoreProvider (baseVectorsRavv , graphToRavvOrdMap , SIMILARITY_FUNCTION );
146+ try (var baseGraphIndexBuilder = new GraphIndexBuilder (bsp ,
147+ baseVectorsRavv .dimension (),
148+ M , // graph degree
149+ BEAM_WIDTH , // construction search depth
150+ NEIGHBOR_OVERFLOW , // allow degree overflow during construction by this factor
151+ ALPHA , // relax neighbor diversity requirement by this factor
152+ ADD_HIERARCHY ); // add the hierarchy) {
153+ var baseGraphIndexFromShuffledVectors = baseGraphIndexBuilder .build (baseVectorsRavv , graphToRavvOrdMap )) {
154+ float recallFromBaseGraphIndexFromShuffledVectors = calculateRecall (baseGraphIndexFromShuffledVectors , bsp , queryVector , groundTruthBaseVectors , TOP_K , graphToRavvOrdMap );
155+ float recallFromBaseGraphIndex = calculateRecall (baseGraphIndex , baseBuildScoreProvider , queryVector , groundTruthBaseVectors , TOP_K );
156+ Assert .assertEquals (recallFromBaseGraphIndex , recallFromBaseGraphIndexFromShuffledVectors , 0.01f );
157+ }
158+ }
132159
133160 /**
134161 * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex}
162+ * Using identity mapping from graph node id to ravv ordinal
135163 * Make sure that both graphs are equivalent
136164 * @throws IOException
137165 */
138166 @ Test
139- public void testReconstructionOfOnHeapGraphIndex () throws IOException {
167+ public void testReconstructionOfOnHeapGraphIndex_withIdentityOrdinalMapping () throws IOException {
140168 var graphOutputPath = testDirectory .resolve ("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex .getClass ().getSimpleName ());
141169 var heapGraphOutputPath = testDirectory .resolve ("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex .getClass ().getSimpleName () + "_onHeap" );
142170
@@ -166,6 +194,44 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException {
166194 }
167195 }
168196
197+ /**
198+ * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex}
199+ * Using random mapping from graph node id to ravv ordinal
200+ * Make sure that both graphs are equivalent
201+ * @throws IOException
202+ */
203+ @ Test
204+ public void testReconstructionOfOnHeapGraphIndex_withNonIdentityOrdinalMapping () throws IOException {
205+ var graphOutputPath = testDirectory .resolve ("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex .getClass ().getSimpleName ());
206+ var heapGraphOutputPath = testDirectory .resolve ("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex .getClass ().getSimpleName () + "_onHeap" );
207+
208+ // create reversed mapping from graph node id to ravv ordinal
209+ int [] graphToRavvOrdMap = IntStream .range (0 , baseVectorsRavv .size ()).map (i -> baseVectorsRavv .size () - 1 - i ).toArray ();
210+ var bsp = BuildScoreProvider .randomAccessScoreProvider (baseVectorsRavv , graphToRavvOrdMap , SIMILARITY_FUNCTION );
211+ try (var baseGraphIndexBuilder = new GraphIndexBuilder (bsp ,
212+ baseVectorsRavv .dimension (),
213+ M , // graph degree
214+ BEAM_WIDTH , // construction search depth
215+ NEIGHBOR_OVERFLOW , // allow degree overflow during construction by this factor
216+ ALPHA , // relax neighbor diversity requirement by this factor
217+ ADD_HIERARCHY ); // add the hierarchy) {
218+ var baseGraphIndex = baseGraphIndexBuilder .build (baseVectorsRavv , graphToRavvOrdMap )) {
219+ log .info ("Writing graph to {}" , graphOutputPath );
220+ TestUtil .writeGraph (baseGraphIndex , baseVectorsRavv , graphOutputPath );
221+
222+ log .info ("Writing on-heap graph to {}" , heapGraphOutputPath );
223+ try (SimpleWriter writer = new SimpleWriter (heapGraphOutputPath .toAbsolutePath ())) {
224+ ((OnHeapGraphIndex ) baseGraphIndex ).save (writer );
225+ }
226+
227+ log .info ("Reading on-heap graph from {}" , heapGraphOutputPath );
228+ try (var readerSupplier = new SimpleMappedReader .Supplier (heapGraphOutputPath .toAbsolutePath ())) {
229+ MutableGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex .load (readerSupplier .get (), baseVectorsRavv .dimension (), NEIGHBOR_OVERFLOW , new VamanaDiversityProvider (bsp , ALPHA ));
230+ TestUtil .assertGraphEquals (baseGraphIndex , reconstructedOnHeapGraphIndex );
231+ }
232+ }
233+ }
234+
169235 /**
170236 * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes}
171237 * Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex}
@@ -230,10 +296,27 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?
230296 }
231297
232298 private static float calculateRecall (ImmutableGraphIndex graphIndex , BuildScoreProvider buildScoreProvider , VectorFloat <?> queryVector , int [] groundTruth , int k ) throws IOException {
299+ return calculateRecall (graphIndex , buildScoreProvider , queryVector , groundTruth , k , null );
300+ }
301+
302+ private static float calculateRecall (ImmutableGraphIndex graphIndex , BuildScoreProvider buildScoreProvider , VectorFloat <?> queryVector , int [] groundTruth , int k , int [] graphToRavvOrdMap ) throws IOException {
233303 try (GraphSearcher graphSearcher = new GraphSearcher (graphIndex )){
234304 SearchScoreProvider ssp = buildScoreProvider .searchProviderFor (queryVector );
235305 var searchResults = graphSearcher .search (ssp , k , Bits .ALL );
236- var predicted = Arrays .stream (searchResults .getNodes ()).mapToInt (nodeScore -> nodeScore .node ).boxed ().collect (Collectors .toSet ());
306+ Set <Integer > predicted ;
307+ if (graphToRavvOrdMap != null ) {
308+ // Convert graph node IDs to RAVV ordinals for comparison with ground truth
309+ predicted = Arrays .stream (searchResults .getNodes ())
310+ .mapToInt (nodeScore -> graphToRavvOrdMap [nodeScore .node ])
311+ .boxed ()
312+ .collect (Collectors .toSet ());
313+ } else {
314+ // Identity mapping: graph node IDs == RAVV ordinals
315+ predicted = Arrays .stream (searchResults .getNodes ())
316+ .mapToInt (nodeScore -> nodeScore .node )
317+ .boxed ()
318+ .collect (Collectors .toSet ());
319+ }
237320 return calculateRecall (predicted , groundTruth , k );
238321 }
239322 }
@@ -249,11 +332,8 @@ private static float calculateRecall(Set<Integer> predicted, int[] groundTruth,
249332 int actualK = Math .min (k , Math .min (predicted .size (), groundTruth .length ));
250333
251334 for (int i = 0 ; i < actualK ; i ++) {
252- for (int j = 0 ; j < actualK ; j ++) {
253- if (predicted .contains (groundTruth [j ])) {
254- hits ++;
255- break ;
256- }
335+ if (predicted .contains (groundTruth [i ])) {
336+ hits ++;
257337 }
258338 }
259339
0 commit comments