Skip to content

Commit 3b6bfe6

Browse files
authored
vectorsEncountered not always in sync with resultsQueue, causing NPE when breaking out of loop due to threshold probability (#150)
1 parent 0b93065 commit 3b6bfe6

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -189,20 +189,15 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
189189
visited.set(ep);
190190
numVisited++;
191191
candidates.push(ep, score);
192-
if (acceptOrds.get(ep) && score >= threshold) {
193-
resultsQueue.push(ep, score);
194-
}
195192

196193
// A bound that holds the minimum similarity to the query vector that a candidate vector must
197194
// have to be considered.
198195
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
199-
if (resultsQueue.size() >= topK) {
200-
minAcceptedSimilarity = resultsQueue.topScore();
201-
}
202196

203197
while (candidates.size() > 0 && !resultsQueue.incomplete()) {
204-
// get the best candidate (closest or best scoring)
205-
if (candidates.topScore() < minAcceptedSimilarity) {
198+
// done when best candidate is worse than the worst result so far
199+
float topCandidateScore = candidates.topScore();
200+
if (topCandidateScore < minAcceptedSimilarity) {
206201
break;
207202
}
208203

@@ -211,10 +206,21 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
211206
break;
212207
}
213208

209+
// add the top candidate to the result set if it is accepted and meets the threshold
214210
int topCandidateNode = candidates.pop();
215-
if (!scoreFunction.isExact()) {
216-
vectorsEncountered.put(topCandidateNode, view.getVector(topCandidateNode));
211+
if (acceptOrds.get(topCandidateNode)
212+
&& topCandidateScore >= threshold
213+
&& resultsQueue.push(topCandidateNode, topCandidateScore))
214+
{
215+
if (resultsQueue.size() >= topK) {
216+
minAcceptedSimilarity = resultsQueue.topScore();
217+
}
218+
if (!scoreFunction.isExact()) {
219+
vectorsEncountered.put(topCandidateNode, view.getVector(topCandidateNode));
220+
}
217221
}
222+
223+
// add its neighbors to the candidates queue
218224
for (var it = view.getNeighborsIterator(topCandidateNode); it.hasNext(); ) {
219225
int friendOrd = it.nextInt();
220226
if (visited.getAndSet(friendOrd)) {
@@ -224,14 +230,8 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
224230

225231
float friendSimilarity = scoreFunction.similarityTo(friendOrd);
226232
scoreTracker.track(friendSimilarity);
227-
228233
if (friendSimilarity >= minAcceptedSimilarity) {
229234
candidates.push(friendOrd, friendSimilarity);
230-
if (acceptOrds.get(friendOrd) && friendSimilarity >= threshold) {
231-
if (resultsQueue.push(friendOrd, friendSimilarity) && resultsQueue.size() >= topK) {
232-
minAcceptedSimilarity = resultsQueue.topScore();
233-
}
234-
}
235235
}
236236
}
237237
}

jvector-tests/src/test/java/io/github/jbellis/jvector/graph/Test2DThreshold.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,26 @@
1818

1919
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
2020
import io.github.jbellis.jvector.LuceneTestCase;
21+
import io.github.jbellis.jvector.TestUtil;
22+
import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
23+
import io.github.jbellis.jvector.disk.SimpleMappedReader;
24+
import io.github.jbellis.jvector.pq.PQVectors;
25+
import io.github.jbellis.jvector.pq.ProductQuantization;
2126
import io.github.jbellis.jvector.util.Bits;
2227
import io.github.jbellis.jvector.vector.VectorEncoding;
2328
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
2429
import org.junit.Test;
2530

31+
import java.io.IOException;
32+
import java.nio.file.Files;
33+
import java.nio.file.Path;
2634
import java.util.Arrays;
2735
import java.util.List;
2836

2937
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
3038
public class Test2DThreshold extends LuceneTestCase {
3139
@Test
32-
public void testThreshold() {
40+
public void testThreshold() throws IOException {
3341
var R = getRandom();
3442
// generate 2D vectors
3543
float[][] vectors = new float[10000][2];
@@ -40,9 +48,10 @@ public void testThreshold() {
4048

4149
var ravv = new ListRandomAccessVectorValues(List.of(vectors), 2);
4250
var builder = new GraphIndexBuilder<>(ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.EUCLIDEAN, 6, 32, 1.2f, 1.4f);
43-
var graph = builder.build();
44-
var searcher = new GraphSearcher.Builder<>(graph.getView()).build();
51+
var onHeapGraph = builder.build();
4552

53+
// test raw vectors
54+
var searcher = new GraphSearcher.Builder<>(onHeapGraph.getView()).build();
4655
for (int i = 0; i < 10; i++) {
4756
TestParams tp = createTestParams(vectors);
4857

@@ -52,6 +61,27 @@ public void testThreshold() {
5261
assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th;
5362
assert result.getNodes().length >= 0.9 * tp.exactCount : "returned " + result.getNodes().length + " nodes for threshold " + tp.th + " but should have returned at least " + tp.exactCount;
5463
}
64+
65+
// test compressed
66+
Path outputPath = Files.createTempFile("graph", ".jvector");
67+
TestUtil.writeGraph(onHeapGraph, ravv, outputPath);
68+
var pq = ProductQuantization.compute(ravv, ravv.dimension(), false);
69+
var cv = new PQVectors(pq, pq.encodeAll(List.of(vectors)));
70+
71+
try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString());
72+
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0))
73+
{
74+
for (int i = 0; i < 10; i++) {
75+
TestParams tp = createTestParams(vectors);
76+
searcher = new GraphSearcher.Builder<>(onDiskGraph.getView()).build();
77+
NodeSimilarity.ReRanker<float[]> reranker = (j, map) -> VectorSimilarityFunction.EUCLIDEAN.compare(tp.q, map.get(j));
78+
var asf = cv.approximateScoreFunctionFor(tp.q, VectorSimilarityFunction.EUCLIDEAN);
79+
var result = searcher.search(asf, reranker, vectors.length, tp.th, Bits.ALL);
80+
81+
assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th;
82+
}
83+
}
84+
5585
}
5686

5787
// it's not an interesting test if all the vectors are within the threshold

0 commit comments

Comments
 (0)