Commit 02feabd

Add binary quantization (#135)
* extract CompressedVectors interface, implemented by PQVectors
* BQ implementation
* introduce VectorCompressor interface so Bench can mix PQ and BQ in the same grid
* remove diskGrid/useDisk so we only do the uncompressed queries once
1 parent a1af7ce commit 02feabd

29 files changed: 773 additions and 211 deletions

README.md

Lines changed: 11 additions & 7 deletions
@@ -62,12 +62,16 @@ JVector implements [DiskANN](https://suhasjs.github.io/files/diskann_neurips19.p
 search, meaning that vectors can be compressed using product quantization so that searches
 can be performed using the compressed representation that is kept in memory. You can enable
 this with the following steps:
-- Create a [`ProductQuantization`](./jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java) object with your vectors using `ProductQuantization.compute`. This will take some time
-to compute the codebooks.
-- Use `ProductQuantization::encode` or `encodeAll` to encode your vectors.
-- Create a [`CompressedVectors`](./jvector-base/src/main/java/io/github/jbellis/jvector/disk/CompressedVectors.java) object from the encoded vectors.
-- Create a [`NeighborSimilarity.ApproximateScoreFunction`](./jvector-base/src/main/java/io/github/jbellis/jvector/graph/NeighborSimilarity.java) for your query that uses the
-`ProductQuantization` object and `CompressedVectors` to compute scores, and pass this
+- Create a [`VectorCompressor`](./jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java) object with your vectors using either `ProductQuantization.compute`
+- or `BinaryQuantization.compute`. PQ is more flexible than BQ and is less lossy: even at the same compressed size,
+in the datasets tested by Bench, only the ada002 vectors in the wikipedia dataset
+are large enough and/or overparameterized enough to benefit from BQ while achieving recall
+competitive with PQ. However, if you are dealing with very large vectors and/or your
+recall requirement is not strict, you may still want to try BQ since it is MUCH faster to both compute and search with.
+- Use `VectorCompressor::encode` or `encodeAll` to encode your vectors, then call
+`VectorCompressor::createCompressedVectors` to create a `CompressedVectors` object.
+- Call `CompressedVectors::approximateScoreFunctionFor` to create a [`NeighborSimilarity.ApproximateScoreFunction`](./jvector-base/src/main/java/io/github/jbellis/jvector/graph/NeighborSimilarity.java) for your query that uses the
+compressed vectors to accelerate search, and pass this
 to the `GraphSearcher.search` method.

 ## Saving and loading indexes
@@ -77,7 +81,7 @@ this with the following steps:
 implementation of [`RandomAccessReader`](./jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java) and the related `ReaderSupplier` to wrap your
 preferred i/o class for best performance. See `SimpleMappedReader` and `SimpleMappedReaderSupplier` for an example.
 - Building a graph does not technically require your RandomAccessVectorValues object
-to live in memory, but it will perform much better if it does. OnDiskGraphIndex,
+to live in memory, but it will perform much better if it does. `OnDiskGraphIndex`,
 by contrast, is designed to live on disk and use minimal memory otherwise.
 - You can optionally wrap `OnDiskGraphIndex` in a [`CachingGraphIndex`](./jvector-base/src/main/java/io/github/jbellis/jvector/disk/CachingGraphIndex.java) to keep the most commonly accessed
 nodes (the ones nearest to the graph entry point) in memory.
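
As a rough size check on the BQ-versus-PQ comparison in the README text of this hunk, the sketch below works out the per-vector footprint at one bit per dimension. The class and method names are hypothetical and not part of JVector; 1536 is the dimension of ada002 embeddings, and the remark about PQ after the block assumes one byte per PQ subspace code.

```java
// Hypothetical helper, not part of JVector: per-vector storage at one bit per dimension.
final class CompressedSizeSketch {
    /** Bytes needed to store one uncompressed float32 vector. */
    public static int rawBytes(int dimension) {
        return dimension * Float.BYTES;
    }

    /** Number of 64-bit words needed for one bit per dimension, rounded up. */
    public static int bqWords(int dimension) {
        return (dimension + 63) / 64;
    }

    /** Bytes needed to store one binary-quantized vector. */
    public static int bqBytes(int dimension) {
        return bqWords(dimension) * Long.BYTES;
    }

    public static void main(String[] args) {
        int dim = 1536; // dimension of ada002 embeddings
        System.out.println("raw bytes: " + rawBytes(dim));
        System.out.println("BQ words:  " + bqWords(dim));
        System.out.println("BQ bytes:  " + bqBytes(dim));
        System.out.println("ratio:     " + rawBytes(dim) / bqBytes(dim));
    }
}
```

For 1536 dimensions this gives 6144 raw bytes against 192 bytes for BQ, a 32x reduction. Under the one-byte-per-code assumption, a PQ configuration with 192 subspaces would occupy the same 192 bytes, which is the "same compressed size" situation the README describes.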

jvector-base/src/main/java/io/github/jbellis/jvector/disk/Io.java

Lines changed: 0 additions & 9 deletions
@@ -16,7 +16,6 @@

 package io.github.jbellis.jvector.disk;

-import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;

@@ -26,12 +25,4 @@ public static void writeFloats(DataOutput out, float[] v) throws IOException {
             out.writeFloat(a);
         }
     }
-
-    public static float[] readFloats(DataInput in, int size) throws IOException {
-        var v = new float[size];
-        for (int i = 0; i < size; i++) {
-            v[i] = in.readFloat();
-        }
-        return v;
-    }
 }

jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ public interface RandomAccessReader extends AutoCloseable {

     void readFully(float[] floats) throws IOException;

+    void readFully(long[] vector) throws IOException;
+
     void read(int[] ints, int offset, int count) throws IOException;

     void close() throws IOException;

jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java

Lines changed: 7 additions & 0 deletions
@@ -83,6 +83,13 @@ public void readFully(byte[] b) {
         mbb.get(b);
     }

+    @Override
+    public void readFully(long[] vector) throws IOException {
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = mbb.getLong();
+        }
+    }
+
     @Override
     public int readInt() {
         return mbb.getInt();
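
The new `readFully(long[])` reads one `long` at a time from the mapped buffer. The sketch below reproduces that loop against a plain `ByteBuffer` (a `MappedByteBuffer` is a `ByteBuffer`, so the same calls apply) and shows a bulk alternative through a `LongBuffer` view. The class name is hypothetical, and whether the bulk form is faster in practice is not measured here.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

// Hypothetical stand-in for the reader; it operates on any ByteBuffer.
final class LongReadSketch {
    /** Element-by-element read, mirroring the loop added in this commit. */
    public static void readFully(ByteBuffer buf, long[] vector) {
        for (int i = 0; i < vector.length; i++) {
            vector[i] = buf.getLong();
        }
    }

    /** Bulk read through a LongBuffer view of the remaining bytes. */
    public static void readFullyBulk(ByteBuffer buf, long[] vector) {
        buf.asLongBuffer().get(vector);
        // The view has its own position, so advance the byte buffer manually.
        buf.position(buf.position() + vector.length * Long.BYTES);
    }

    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(3 * Long.BYTES);
        buf.putLong(1L).putLong(-2L).putLong(Long.MAX_VALUE);
        buf.flip();

        long[] a = new long[3];
        readFully(buf, a);

        buf.rewind();
        long[] b = new long[3];
        readFullyBulk(buf, b);

        System.out.println(Arrays.equals(a, b));
    }
}
```

Both variants leave the buffer positioned just past the values they read, so either could be followed by further reads from the same buffer.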

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java

Lines changed: 1 addition & 2 deletions
@@ -171,8 +171,7 @@ SearchResult searchInternal(
         }
         while (candidates.size() > 0 && !resultsQueue.incomplete()) {
             // get the best candidate (closest or best scoring)
-            float topCandidateSimilarity = candidates.topScore();
-            if (topCandidateSimilarity < minAcceptedSimilarity) {
+            if (candidates.topScore() < minAcceptedSimilarity) {
                 break;
             }


jvector-base/src/main/java/io/github/jbellis/jvector/pq/BQVectors.java

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.pq;
+
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+import io.github.jbellis.jvector.graph.NeighborSimilarity;
+import io.github.jbellis.jvector.util.RamUsageEstimator;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorUtil;
+
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+public class BQVectors implements CompressedVectors {
+    private final BinaryQuantization bq;
+    private final long[][] compressedVectors;
+
+    public BQVectors(BinaryQuantization bq, long[][] compressedVectors) {
+        this.bq = bq;
+        this.compressedVectors = compressedVectors;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+        // BQ centering data
+        bq.write(out);
+
+        // compressed vectors
+        out.writeInt(compressedVectors.length);
+        if (compressedVectors.length <= 0) {
+            return;
+        }
+        out.writeInt(compressedVectors[0].length);
+        for (var v : compressedVectors) {
+            for (int i = 0; i < v.length; i++) {
+                out.writeLong(v[i]);
+            }
+        }
+    }
+
+    public static BQVectors load(RandomAccessReader in, int offset) throws IOException {
+        in.seek(offset);
+
+        // BQ
+        var bq = BinaryQuantization.load(in);
+
+        // check validity of compressed vectors header
+        int size = in.readInt();
+        if (size < 0) {
+            throw new IOException("Invalid compressed vector count " + size);
+        }
+        var compressedVectors = new long[size][];
+        if (size == 0) {
+            return new BQVectors(bq, compressedVectors);
+        }
+        int compressedLength = in.readInt();
+        if (compressedLength < 0) {
+            throw new IOException("Invalid compressed vector dimension " + compressedLength);
+        }
+
+        // read the compressed vectors
+        for (int i = 0; i < size; i++)
+        {
+            long[] vector = new long[compressedLength];
+            in.readFully(vector);
+            compressedVectors[i] = vector;
+        }
+
+        return new BQVectors(bq, compressedVectors);
+    }
+
+    @Override
+    public NeighborSimilarity.ApproximateScoreFunction approximateScoreFunctionFor(float[] q, VectorSimilarityFunction similarityFunction) {
+        var qBQ = bq.encode(q);
+        return node2 -> {
+            var vBQ = compressedVectors[node2];
+            return 1 - (float) VectorUtil.hammingDistance(qBQ, vBQ) / q.length;
+        };
+    }
+
+    public long[] get(int i) {
+        return compressedVectors[i];
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        return compressedVectors.length * RamUsageEstimator.sizeOf(compressedVectors[0]);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        BQVectors bqVectors = (BQVectors) o;
+        return Objects.equals(bq, bqVectors.bq) && Arrays.deepEquals(compressedVectors, bqVectors.compressedVectors);
+    }
+
+    @Override
+    public int hashCode() {
+        int result = Objects.hash(bq);
+        result = 31 * result + Arrays.deepHashCode(compressedVectors);
+        return result;
+    }
+}
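
The on-disk layout that `write` and `load` agree on for the compressed vectors is: an `int` vector count, then (only if the count is non-zero) an `int` word count per vector, then the raw `long` words. A minimal round-trip sketch of that layout follows. It omits the centroid header written by `bq.write`, uses `DataInputStream` as a stand-in for `RandomAccessReader`, and wraps `IOException` in `UncheckedIOException` for brevity; the class name is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

// Hypothetical sketch of the compressed-vector section only (no centroid header).
final class BqLayoutSketch {
    public static byte[] write(long[][] vectors) {
        var bytes = new ByteArrayOutputStream();
        try (var out = new DataOutputStream(bytes)) {
            out.writeInt(vectors.length);            // vector count
            if (vectors.length > 0) {
                out.writeInt(vectors[0].length);     // words per vector
                for (long[] v : vectors) {
                    for (long word : v) {
                        out.writeLong(word);
                    }
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    public static long[][] read(byte[] data) {
        try (var in = new DataInputStream(new ByteArrayInputStream(data))) {
            int size = in.readInt();
            if (size < 0) {
                throw new IOException("Invalid compressed vector count " + size);
            }
            var vectors = new long[size][];
            if (size == 0) {
                return vectors;                      // no word count follows an empty set
            }
            int words = in.readInt();
            if (words < 0) {
                throw new IOException("Invalid compressed vector dimension " + words);
            }
            for (int i = 0; i < size; i++) {
                long[] v = new long[words];
                for (int j = 0; j < words; j++) {
                    v[j] = in.readLong();
                }
                vectors[i] = v;
            }
            return vectors;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        long[][] original = {{1L, -2L}, {3L, 4L}};
        long[][] restored = read(write(original));
        System.out.println(Arrays.deepEquals(original, restored));
    }
}
```

Because the word count is written only when at least one vector exists, the reader has to return before trying to read it for an empty set, as `load` does in the commit.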

jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.pq;
+
+import io.github.jbellis.jvector.disk.Io;
+import io.github.jbellis.jvector.disk.RandomAccessReader;
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
+import io.github.jbellis.jvector.util.PoolingSupport;
+import io.github.jbellis.jvector.vector.VectorUtil;
+
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static java.lang.Math.min;
+
+/**
+ * Binary Quantization of float vectors: each float is compressed to a single bit,
+ * and similarity is computed with a simple Hamming distance.
+ */
+public class BinaryQuantization implements VectorCompressor<long[]> {
+    private final float[] globalCentroid;
+
+    public BinaryQuantization(float[] globalCentroid) {
+        this.globalCentroid = globalCentroid;
+    }
+
+    public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv) {
+        // limit the number of vectors we train on
+        var P = min(1.0f, ProductQuantization.MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
+        var ravvCopy = ravv.isValueShared() ? PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv);
+        var vectors = IntStream.range(0, ravv.size()).parallel()
+                .filter(i -> ThreadLocalRandom.current().nextFloat() < P)
+                .mapToObj(targetOrd -> {
+                    try (var pooledRavv = ravvCopy.get()) {
+                        var localRavv = pooledRavv.get();
+                        float[] v = localRavv.vectorValue(targetOrd);
+                        return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
+                    }
+                })
+                .collect(Collectors.toList());
+
+        // compute the centroid of the training set
+        float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
+        return new BinaryQuantization(globalCentroid);
+    }
+
+    @Override
+    public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
+        return new BQVectors(this, (long[][]) compressedVectors);
+    }
+
+    @Override
+    public long[][] encodeAll(List<float[]> vectors) {
+        return PhysicalCoreExecutor.instance.submit(() -> vectors.stream().parallel().map(this::encode).toArray(long[][]::new));
+    }
+
+    /**
+     * Encodes the input vector
+     *
+     * @return one bit per original f32
+     */
+    @Override
+    public long[] encode(float[] v) {
+        var centered = VectorUtil.sub(v, globalCentroid);
+
+        int M = (int) Math.ceil(centered.length / 64.0);
+        long[] encoded = new long[M];
+        for (int i = 0; i < M; i++) {
+            long bits = 0;
+            for (int j = 0; j < 64; j++) {
+                int idx = i * 64 + j;
+                if (idx >= centered.length) {
+                    break;
+                }
+                if (centered[idx] > 0) {
+                    bits |= 1L << j;
+                }
+            }
+            encoded[i] = bits;
+        }
+        return encoded;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+        out.writeInt(globalCentroid.length);
+        Io.writeFloats(out, globalCentroid);
+    }
+
+    public static BinaryQuantization load(RandomAccessReader in) throws IOException {
+        int length = in.readInt();
+        var centroid = new float[length];
+        in.readFully(centroid);
+        return new BinaryQuantization(centroid);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        BinaryQuantization that = (BinaryQuantization) o;
+        return Arrays.equals(globalCentroid, that.globalCentroid);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(globalCentroid);
+    }
+
+    @Override
+    public String toString() {
+        return "BinaryQuantization";
+    }
+}
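
To make the encoding and scoring concrete, here is a self-contained sketch of the same idea: subtract a centroid, keep one sign bit per dimension packed into `long` words, and score two codes as one minus the normalized Hamming distance. The class name is hypothetical, and `Long.bitCount` over the XOR stands in for `VectorUtil.hammingDistance`, whose implementation is not shown in this diff.

```java
// Hypothetical sketch, not the JVector classes: sign-bit encoding plus Hamming similarity.
final class BqEncodeSketch {
    /** One bit per dimension: bit (i % 64) of word (i / 64) is set when v[i] > centroid[i]. */
    public static long[] encode(float[] v, float[] centroid) {
        long[] encoded = new long[(v.length + 63) / 64];
        for (int i = 0; i < v.length; i++) {
            if (v[i] - centroid[i] > 0) {
                encoded[i >>> 6] |= 1L << (i & 63);
            }
        }
        return encoded;
    }

    /** Number of differing bits; assumes both codes have the same length. */
    public static int hammingDistance(long[] a, long[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Long.bitCount(a[i] ^ b[i]);
        }
        return distance;
    }

    /** Same shape as the score in BQVectors: 1 - hamming / dimension. */
    public static float similarity(long[] a, long[] b, int dimension) {
        return 1 - (float) hammingDistance(a, b) / dimension;
    }

    public static void main(String[] args) {
        float[] centroid = {0f, 0f, 0f, 0f};
        long[] a = encode(new float[] {1f, -1f, 2f, -3f}, centroid);
        long[] b = encode(new float[] {1f, 1f, -2f, -3f}, centroid);
        // The two codes differ in 2 of 4 bits, so this prints 0.5
        System.out.println(similarity(a, b, 4));
    }
}
```

Padding bits in the last word are always zero in both codes, so they never contribute to the distance; dividing by the original dimension rather than the padded bit count keeps the score in [0, 1].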
