Skip to content

Commit 98507a5

Browse files
committed
harden tests for heap graph reconstruction
fix minor bug in construction Signed-off-by: Samuel Herman <sherman8915@gmail.com>
1 parent 21e4a22 commit 98507a5

File tree

3 files changed

+98
-16
lines changed

3 files changed

+98
-16
lines changed

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RecallWithRandomVectorsBenchmark.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,8 @@ private double calculateRecall(Set<Integer> predicted, int[] groundTruth, int k)
247247
int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length));
248248

249249
for (int i = 0; i < actualK; i++) {
250-
for (int j = 0; j < actualK; j++) {
251-
if (predicted.contains(groundTruth[j])) {
252-
hits++;
253-
break;
254-
}
250+
if (predicted.contains(groundTruth[i])) {
251+
hits++;
255252
}
256253
}
257254

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,17 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi
439439
}
440440

441441
public ImmutableGraphIndex build(RandomAccessVectorValues ravv) {
442+
return build(ravv, null);
443+
}
444+
445+
public ImmutableGraphIndex build(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) {
442446
var vv = ravv.threadLocalSupplier();
443447
int size = ravv.size();
444448

445449
simdExecutor.submit(() -> {
446450
IntStream.range(0, size).parallel().forEach(node -> {
447-
addGraphNode(node, vv.get().getVector(node));
451+
int ravvOrdinal = (graphToRavvOrdMap != null) ? graphToRavvOrdMap[node] : node;
452+
addGraphNode(node, vv.get().getVector(ravvOrdinal));
448453
});
449454
}).join();
450455

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import java.util.stream.Collectors;
4646
import java.util.stream.IntStream;
4747

48+
import static org.apache.commons.lang3.ArrayUtils.shuffle;
4849
import static org.junit.Assert.assertEquals;
4950

5051
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
@@ -61,6 +62,7 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
6162
private static final float NEIGHBOR_OVERFLOW = 1.2f;
6263
private static final boolean ADD_HIERARCHY = false;
6364
private static final int TOP_K = 10;
65+
private static VectorSimilarityFunction SIMILARITY_FUNCTION = VectorSimilarityFunction.EUCLIDEAN;
6466

6567
private Path testDirectory;
6668

@@ -71,6 +73,7 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
7173
private RandomAccessVectorValues newVectorsRavv;
7274
private RandomAccessVectorValues allVectorsRavv;
7375
private VectorFloat<?> queryVector;
76+
private int[] groundTruthBaseVectors;
7477
private int[] groundTruthAllVectors;
7578
private BuildScoreProvider baseBuildScoreProvider;
7679
private BuildScoreProvider allBuildScoreProvider;
@@ -100,11 +103,12 @@ public void setup() throws IOException {
100103
allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION);
101104

102105
queryVector = createRandomVector(DIMENSION);
103-
groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN);
106+
groundTruthBaseVectors = getGroundTruth(baseVectorsRavv, queryVector, TOP_K, SIMILARITY_FUNCTION);
107+
groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, SIMILARITY_FUNCTION);
104108

105109
// score provider using the raw, in-memory vectors
106-
baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN);
107-
allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN);
110+
baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, SIMILARITY_FUNCTION);
111+
allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, SIMILARITY_FUNCTION);
108112
var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider,
109113
baseVectorsRavv.dimension(),
110114
M, // graph degree
@@ -129,14 +133,38 @@ public void tearDown() {
129133
TestUtil.deleteQuietly(testDirectory);
130134
}
131135

136+
/**
137+
* Test that we can build a graph with a non-identity mapping from graph node id to ravv ordinal
138+
* and that the recall is the same as the identity mapping (meaning the graphs are equivalent)
139+
* @throws IOException exception
140+
*/
141+
@Test
142+
public void testGraphConstructionWithNonIdentityOrdinalMapping() throws IOException {
143+
// create reversed mapping from graph node id to ravv ordinal
144+
int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray();
145+
var bsp = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, graphToRavvOrdMap, SIMILARITY_FUNCTION);
146+
try (var baseGraphIndexBuilder = new GraphIndexBuilder(bsp,
147+
baseVectorsRavv.dimension(),
148+
M, // graph degree
149+
BEAM_WIDTH, // construction search depth
150+
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
151+
ALPHA, // relax neighbor diversity requirement by this factor
152+
ADD_HIERARCHY); // add the hierarchy) {
153+
var baseGraphIndexFromShuffledVectors = baseGraphIndexBuilder.build(baseVectorsRavv, graphToRavvOrdMap)) {
154+
float recallFromBaseGraphIndexFromShuffledVectors = calculateRecall(baseGraphIndexFromShuffledVectors, bsp, queryVector, groundTruthBaseVectors, TOP_K, graphToRavvOrdMap);
155+
float recallFromBaseGraphIndex = calculateRecall(baseGraphIndex, baseBuildScoreProvider, queryVector, groundTruthBaseVectors, TOP_K);
156+
Assert.assertEquals(recallFromBaseGraphIndex, recallFromBaseGraphIndexFromShuffledVectors, 0.01f);
157+
}
158+
}
132159

133160
/**
134161
* Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex}
162+
* Using identity mapping from graph node id to ravv ordinal
135163
* Make sure that both graphs are equivalent
136164
* @throws IOException
137165
*/
138166
@Test
139-
public void testReconstructionOfOnHeapGraphIndex() throws IOException {
167+
public void testReconstructionOfOnHeapGraphIndex_withIdentityOrdinalMapping() throws IOException {
140168
var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName());
141169
var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap");
142170

@@ -166,6 +194,44 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException {
166194
}
167195
}
168196

197+
/**
198+
* Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex}
199+
* Using random mapping from graph node id to ravv ordinal
200+
* Make sure that both graphs are equivalent
201+
* @throws IOException
202+
*/
203+
@Test
204+
public void testReconstructionOfOnHeapGraphIndex_withNonIdentityOrdinalMapping() throws IOException {
205+
var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName());
206+
var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap");
207+
208+
// create reversed mapping from graph node id to ravv ordinal
209+
int[] graphToRavvOrdMap = IntStream.range(0, baseVectorsRavv.size()).map(i -> baseVectorsRavv.size() - 1 - i).toArray();
210+
var bsp = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, graphToRavvOrdMap, SIMILARITY_FUNCTION);
211+
try (var baseGraphIndexBuilder = new GraphIndexBuilder(bsp,
212+
baseVectorsRavv.dimension(),
213+
M, // graph degree
214+
BEAM_WIDTH, // construction search depth
215+
NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor
216+
ALPHA, // relax neighbor diversity requirement by this factor
217+
ADD_HIERARCHY); // add the hierarchy) {
218+
var baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv, graphToRavvOrdMap)) {
219+
log.info("Writing graph to {}", graphOutputPath);
220+
TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath);
221+
222+
log.info("Writing on-heap graph to {}", heapGraphOutputPath);
223+
try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) {
224+
((OnHeapGraphIndex) baseGraphIndex).save(writer);
225+
}
226+
227+
log.info("Reading on-heap graph from {}", heapGraphOutputPath);
228+
try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) {
229+
MutableGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), baseVectorsRavv.dimension(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(bsp, ALPHA));
230+
TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex);
231+
}
232+
}
233+
}
234+
169235
/**
170236
* Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes}
171237
* Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex}
@@ -230,10 +296,27 @@ private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?
230296
}
231297

232298
private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat<?> queryVector, int[] groundTruth, int k) throws IOException {
299+
return calculateRecall(graphIndex, buildScoreProvider, queryVector, groundTruth, k, null);
300+
}
301+
302+
private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat<?> queryVector, int[] groundTruth, int k, int[] graphToRavvOrdMap) throws IOException {
233303
try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)){
234304
SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector);
235305
var searchResults = graphSearcher.search(ssp, k, Bits.ALL);
236-
var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet());
306+
Set<Integer> predicted;
307+
if (graphToRavvOrdMap != null) {
308+
// Convert graph node IDs to RAVV ordinals for comparison with ground truth
309+
predicted = Arrays.stream(searchResults.getNodes())
310+
.mapToInt(nodeScore -> graphToRavvOrdMap[nodeScore.node])
311+
.boxed()
312+
.collect(Collectors.toSet());
313+
} else {
314+
// Identity mapping: graph node IDs == RAVV ordinals
315+
predicted = Arrays.stream(searchResults.getNodes())
316+
.mapToInt(nodeScore -> nodeScore.node)
317+
.boxed()
318+
.collect(Collectors.toSet());
319+
}
237320
return calculateRecall(predicted, groundTruth, k);
238321
}
239322
}
@@ -249,11 +332,8 @@ private static float calculateRecall(Set<Integer> predicted, int[] groundTruth,
249332
int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length));
250333

251334
for (int i = 0; i < actualK; i++) {
252-
for (int j = 0; j < actualK; j++) {
253-
if (predicted.contains(groundTruth[j])) {
254-
hits++;
255-
break;
256-
}
335+
if (predicted.contains(groundTruth[i])) {
336+
hits++;
257337
}
258338
}
259339

0 commit comments

Comments
 (0)