18
18
19
19
import com .carrotsearch .randomizedtesting .annotations .ThreadLeakScope ;
20
20
import io .github .jbellis .jvector .LuceneTestCase ;
21
+ import io .github .jbellis .jvector .TestUtil ;
22
+ import io .github .jbellis .jvector .disk .OnDiskGraphIndex ;
23
+ import io .github .jbellis .jvector .disk .SimpleMappedReader ;
24
+ import io .github .jbellis .jvector .pq .PQVectors ;
25
+ import io .github .jbellis .jvector .pq .ProductQuantization ;
21
26
import io .github .jbellis .jvector .util .Bits ;
22
27
import io .github .jbellis .jvector .vector .VectorEncoding ;
23
28
import io .github .jbellis .jvector .vector .VectorSimilarityFunction ;
24
29
import org .junit .Test ;
25
30
31
+ import java .io .IOException ;
32
+ import java .nio .file .Files ;
33
+ import java .nio .file .Path ;
26
34
import java .util .Arrays ;
27
35
import java .util .List ;
28
36
29
37
@ ThreadLeakScope (ThreadLeakScope .Scope .NONE )
30
38
public class Test2DThreshold extends LuceneTestCase {
31
39
@ Test
32
- public void testThreshold () {
40
+ public void testThreshold () throws IOException {
33
41
var R = getRandom ();
34
42
// generate 2D vectors
35
43
float [][] vectors = new float [10000 ][2 ];
@@ -40,9 +48,10 @@ public void testThreshold() {
40
48
41
49
var ravv = new ListRandomAccessVectorValues (List .of (vectors ), 2 );
42
50
var builder = new GraphIndexBuilder <>(ravv , VectorEncoding .FLOAT32 , VectorSimilarityFunction .EUCLIDEAN , 6 , 32 , 1.2f , 1.4f );
43
- var graph = builder .build ();
44
- var searcher = new GraphSearcher .Builder <>(graph .getView ()).build ();
51
+ var onHeapGraph = builder .build ();
45
52
53
+ // test raw vectors
54
+ var searcher = new GraphSearcher .Builder <>(onHeapGraph .getView ()).build ();
46
55
for (int i = 0 ; i < 10 ; i ++) {
47
56
TestParams tp = createTestParams (vectors );
48
57
@@ -52,6 +61,27 @@ public void testThreshold() {
52
61
assert result .getVisitedCount () < vectors .length : "visited all vectors for threshold " + tp .th ;
53
62
assert result .getNodes ().length >= 0.9 * tp .exactCount : "returned " + result .getNodes ().length + " nodes for threshold " + tp .th + " but should have returned at least " + tp .exactCount ;
54
63
}
64
+
65
+ // test compressed
66
+ Path outputPath = Files .createTempFile ("graph" , ".jvector" );
67
+ TestUtil .writeGraph (onHeapGraph , ravv , outputPath );
68
+ var pq = ProductQuantization .compute (ravv , ravv .dimension (), false );
69
+ var cv = new PQVectors (pq , pq .encodeAll (List .of (vectors )));
70
+
71
+ try (var marr = new SimpleMappedReader (outputPath .toAbsolutePath ().toString ());
72
+ var onDiskGraph = new OnDiskGraphIndex <float []>(marr ::duplicate , 0 ))
73
+ {
74
+ for (int i = 0 ; i < 10 ; i ++) {
75
+ TestParams tp = createTestParams (vectors );
76
+ searcher = new GraphSearcher .Builder <>(onDiskGraph .getView ()).build ();
77
+ NodeSimilarity .ReRanker <float []> reranker = (j , map ) -> VectorSimilarityFunction .EUCLIDEAN .compare (tp .q , map .get (j ));
78
+ var asf = cv .approximateScoreFunctionFor (tp .q , VectorSimilarityFunction .EUCLIDEAN );
79
+ var result = searcher .search (asf , reranker , vectors .length , tp .th , Bits .ALL );
80
+
81
+ assert result .getVisitedCount () < vectors .length : "visited all vectors for threshold " + tp .th ;
82
+ }
83
+ }
84
+
55
85
}
56
86
57
87
// it's not an interesting test if all the vectors are within the threshold
0 commit comments