diff --git a/README.md b/README.md
index 8fef5996b..5060e1aa0 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ The upper layers of the hierarchy are represented by an in-memory adjacency list
 The bottom layer of the graph is represented by an on-disk adjacency list per node. JVector uses additional data stored inline to support two-pass searches, with the first pass powered by lossily compressed representations of the vectors kept in memory, and the second by a more accurate representation read from disk. The first pass can be performed with
 * Product quantization (PQ), optionally with [anisotropic weighting](https://arxiv.org/abs/1908.10396)
 * [Binary quantization](https://huggingface.co/blog/embedding-quantization) (BQ)
-* Fused ADC, where PQ codebooks are transposed and written inline with the graph adjacency list
+* Fused PQ, where PQ codebooks are written inline with the graph adjacency list
 The second pass can be performed with
 * Full resolution float32 vectors
@@ -265,13 +265,13 @@ Commentary:
 * Embedding models produce output from a consistent distribution of vectors. This means that you can save and re-use ProductQuantization codebooks, even for a different set of vectors, as long as you had a sufficiently large training set to build them the first time around. ProductQuantization.MAX_PQ_TRAINING_SET_SIZE (128,000 vectors) has proven to be sufficiently large.
 * JDK ThreadLocal objects cannot be referenced except from the thread that created them. This makes it difficult to cache Closeable objects like GraphSearcher. JVector provides the ExplicitThreadLocal class to solve this.
-* Fused ADC is only compatible with Product Quantization, not Binary Quantization. This is no great loss since [very few models generate embeddings that are best suited for BQ](https://thenewstack.io/why-vector-size-matters/). That said, BQ continues to be supported with non-Fused indexes.
+* Fused PQ is only compatible with Product Quantization, not Binary Quantization. This is no great loss since [very few models generate embeddings that are best suited for BQ](https://thenewstack.io/why-vector-size-matters/). That said, BQ continues to be supported with non-Fused indexes.
 * JVector makes heavy use of the Panama Vector API (SIMD) for ANN indexing and search. We have seen cases where memory bandwidth is saturated during indexing and product quantization, causing the process to slow down. To avoid this, the batch methods for index and PQ builds use a [PhysicalCoreExecutor](https://javadoc.io/doc/io.github.jbellis/jvector/latest/io/github/jbellis/jvector/util/PhysicalCoreExecutor.html) to limit the number of concurrent operations to the physical core count. The default is 1/2 the processor count seen by Java. This may not be correct in all setups (e.g. no hyperthreading or hybrid architectures), so if you wish to override the default, use the `-Djvector.physical_core_count` property or pass in your own ForkJoinPool instance.
 
 ### Advanced features
 
-* Fused ADC is represented as a Feature that is supported during incremental index construction, like InlineVectors above. [See the Grid class for sample code](https://github.com/jbellis/jvector/blob/main/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java).
+* Fused PQ is represented as a Feature that is supported during incremental index construction, like InlineVectors above. [See the Grid class for sample code](https://github.com/jbellis/jvector/blob/main/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java); a rough sketch follows below.
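A minimal construction sketch, not a definitive recipe: it assumes an already-built `onHeapGraph`, its vectors `ravv`, a 256-cluster ProductQuantization `pq` with codes `pqVectors`, the base-layer degree `M`, and an output `indexPath` (all hypothetical names); the Builder and feature-state pattern follows Grid.java, but check that class for the exact signatures.

```java
import java.util.Map;
import java.util.function.IntFunction;

import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
import io.github.jbellis.jvector.graph.disk.feature.Feature;
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors;

// Write an index whose nodes carry full-resolution vectors (for the exact second pass)
// plus the PQ codes of their neighbors, fused inline (for the approximate first pass).
try (var view = onHeapGraph.getView();
     var writer = new OnDiskGraphIndexWriter.Builder(onHeapGraph, indexPath)
             .with(new InlineVectors(ravv.dimension()))
             .with(new FusedPQ(M, pq))
             .build()) {
    // One state supplier per feature, keyed by FeatureId, invoked per node ordinal
    Map<FeatureId, IntFunction<Feature.State>> suppliers = Map.of(
            FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal)),
            FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, pqVectors, ordinal));
    writer.write(suppliers);
}
```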
* Anisotropic PQ is built into the ProductQuantization class and can improve recall, but nobody knows how to tune it (with the T/threshold parameter) except experimentally on a per-model basis, and choosing the wrong setting can make things worse. From Figure 3 in the paper:
![APQ performance on Glove first improves and then degrades as T increases](https://github.com/jbellis/jvector/assets/42158/fd459222-6929-43ca-a405-ac34dbaf6646)
@@ -285,7 +285,6 @@ Commentary:
 
 * [Anisotropic PQ paper](https://arxiv.org/abs/1908.10396)
 * [Quicker ADC paper](https://arxiv.org/abs/1812.09162)
 
-
 ## Developing and Testing
 
 This project is organized as a [multimodule Maven build](https://maven.apache.org/guides/mini/guide-multiple-modules.html). The intent is to produce a multirelease jar suitable for use as a dependency from any Java 11 code. When run on a Java 20+ JVM with the Vector module enabled, optimized vector
diff --git a/UPGRADING.md b/UPGRADING.md
index 7dfbb317d..c99de606f 100644
--- a/UPGRADING.md
+++ b/UPGRADING.md
@@ -8,7 +8,15 @@
 - Support for hierarchical graph indices. This new type of index blends HNSW and DiskANN in a novel way. An
   HNSW-like hierarchy resides in memory for quickly seeding the search. This also reduces the need for caching the
   DiskANN graph near the entrypoint. The base layer of the hierarchy is a DiskANN-like index and inherits its
-  properties. This hierarchical structure can be disabled, ending up with just the base DiskANN layer.
+  properties. This hierarchical structure can be disabled, leaving just the base DiskANN layer.
+- The feature previously known as Fused ADC has been renamed to Fused PQ. This feature allows offloading the PQ
+  codebooks from memory during search, storing them within the graph in a way that does not slow down the search.
+  Implementation notes: the implementation has been overhauled so that it no longer requires native code
+  acceleration. It explores a design space of packed vector representations fused into the graph, in shapes
+  optimized for approximate score calculation. The feature is opt-in but now fully functional, and the previous
+  graph degree limitations have been lifted. At this time, only 256-cluster ProductQuantization can be used with
+  Fused PQ. Version 6 or greater of the on-disk file format is required to use this feature.
+
 ## API changes
 - MemorySegmentReader.Supplier and SimpleMappedReader.Supplier must now be explicitly closed, instead of being
@@ -20,7 +28,6 @@
   we do early termination of the search. In certain cases, this can accelerate the search at the potential cost of
   some accuracy. It is set to false by default.
 - The constructors of GraphIndexBuilder allow specifying different maximum out-degrees for the graphs in each layer (see the sketch below).
-  However, this feature does not work with FusedADC in this version.
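The per-layer degree option reads roughly like the sketch below. This is a hedged illustration: the exact GraphIndexBuilder constructor overload and parameter order should be verified against the class itself, and `buildScoreProvider` and `dimension` are assumed to already exist.

```java
import java.util.List;

import io.github.jbellis.jvector.graph.GraphIndexBuilder;

// Assumed overload: per-layer maximum out-degrees, listed from the base layer (layer 0) upward.
var builder = new GraphIndexBuilder(
        buildScoreProvider,  // BuildScoreProvider over the vectors being indexed
        dimension,
        List.of(64, 32),     // degree 64 for the DiskANN-style base layer, 32 for the in-memory layers
        100,                 // beam width for construction-time searches
        1.2f,                // neighbor overflow
        1.2f,                // alpha, the DiskANN pruning parameter
        true);               // addHierarchy: false leaves just the base DiskANN layer
```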
### API changes in 3.0.6 diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java index 60a91c9ae..73cc5fbd5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java @@ -396,7 +396,6 @@ void searchOneLayer(SearchScoreProvider scoreProvider, // track scores to predict when we are done with threshold queries var scoreTracker = scoreTrackerFactory.getScoreTracker(pruneSearch, rerankK, threshold); - VectorFloat similarities = null; // the main search loop while (candidates.size() > 0) { @@ -423,25 +422,12 @@ void searchOneLayer(SearchScoreProvider scoreProvider, // score the neighbors of the top candidate and add them to the queue var scoreFunction = scoreProvider.scoreFunction(); - var useEdgeLoading = scoreFunction.supportsEdgeLoadingSimilarity(); - if (useEdgeLoading) { - similarities = scoreFunction.edgeLoadingSimilarityTo(topCandidateNode); - } - int i = 0; - for (var it = view.getNeighborsIterator(level, topCandidateNode); it.hasNext(); ) { - var friendOrd = it.nextInt(); - if (!visited.add(friendOrd)) { - continue; - } + ImmutableGraphIndex.NeighborProcessor neighborProcessor = (node2, score) -> { + scoreTracker.track(score); + candidates.push(node2, score); visitedCount++; - - float friendSimilarity = useEdgeLoading - ? similarities.get(i) - : scoreFunction.similarityTo(friendOrd); - scoreTracker.track(friendSimilarity); - candidates.push(friendOrd, friendSimilarity); - i++; - } + }; + view.processNeighbors(level, topCandidateNode, scoreFunction, visited::add, neighborProcessor); } } catch (Throwable t) { // clear scratch structures if terminated via throwable, as they may not have been drained diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java index 0fc6d27f8..1d7ca4388 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ImmutableGraphIndex.java @@ -35,6 +35,7 @@ import java.io.Closeable; import java.io.IOException; +import java.util.function.Function; /** * Represents a graph-based vector index. Nodes are represented as ints, and edges are @@ -131,6 +132,24 @@ default boolean containsNode(int nodeId) { */ int size(int level); + /** + * The steps needed to process a neighbor during a search. That is, adding it to the priority queue, etc. + */ + interface NeighborProcessor { + void process(int friendOrd, float similarity); + } + + /** + * Serves as an abstract interface for marking nodes as visited + */ + @FunctionalInterface + interface IntMarker { + /** + * Marks the node and returns true if it was not marked previously. Returns false otherwise + */ + boolean mark(int value); + } + /** * Encapsulates the state of a graph for searching. Re-usable across search calls, * but each thread needs its own. @@ -142,6 +161,12 @@ interface View extends Closeable { */ NodesIterator getNeighborsIterator(int level, int node); + /** + * Iterates over the neighbors of a given node if they have not been visited yet. + * For each non-visited neighbor, it computes its similarity and processes it using the given processor. 
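+     *
+     * @param level the graph layer whose adjacency is read
+     * @param node the node whose neighbors are being expanded
+     * @param scoreFunction computes each neighbor's similarity to the query
+     * @param visited marker whose mark() returns true only the first time a node is seen
+     * @param neighborProcessor callback invoked once per unvisited neighbor with its similarity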
+ */ + void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, NeighborProcessor neighborProcessor); + /** * This method is deprecated as most View usages should not need size. * Where they do, they could access the graph. diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 711304f79..835821cf5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -28,6 +28,7 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; @@ -48,6 +49,7 @@ import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.StampedLock; +import java.util.function.Function; import java.util.stream.IntStream; /** @@ -462,6 +464,17 @@ public NodesIterator getNeighborsIterator(int level, int node) { } + @Override + public void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, NeighborProcessor neighborProcessor) { + for (var it = getNeighborsIterator(level, node); it.hasNext(); ) { + var friendOrd = it.nextInt(); + if (visited.mark(friendOrd)) { + float friendSimilarity = scoreFunction.similarityTo(friendOrd); + neighborProcessor.process(friendOrd, friendSimilarity); + } + } + } + @Override public int size() { return OnHeapGraphIndex.this.size(0); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java index 761024ff8..b4b98a7b1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/AbstractGraphIndexWriter.java @@ -18,11 +18,20 @@ import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.graph.ImmutableGraphIndex; -import io.github.jbellis.jvector.graph.disk.feature.*; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedFeature; +import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; +import io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature; +import io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ; +import io.github.jbellis.jvector.graph.disk.feature.SeparatedVectors; + import org.agrona.collections.Int2IntHashMap; import java.io.IOException; import java.util.EnumMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -38,20 +47,18 @@ public abstract class AbstractGraphIndexWriter implements final ImmutableGraphIndex graph; final OrdinalMapper ordinalMapper; final int dimension; - // we don't use Map features but EnumMap is the best way to make sure we don't - // accidentally introduce an ordering bug in the future - 
final EnumMap featureMap; + final Map featureMap; final T out; /* output for graph nodes and inline features */ final int headerSize; volatile int maxOrdinalWritten = -1; final List inlineFeatures; AbstractGraphIndexWriter(T out, - int version, - ImmutableGraphIndex graph, - OrdinalMapper oldToNewOrdinals, - int dimension, - EnumMap features) + int version, + ImmutableGraphIndex graph, + OrdinalMapper oldToNewOrdinals, + int dimension, + EnumMap features) { if (graph.getMaxLevel() > 0 && version < 4) { throw new IllegalArgumentException("Multilayer graphs must be written with version 4 or higher"); @@ -60,8 +67,28 @@ public abstract class AbstractGraphIndexWriter implements this.graph = graph; this.ordinalMapper = oldToNewOrdinals; this.dimension = dimension; - this.featureMap = features; - this.inlineFeatures = features.values().stream().filter(f -> !(f instanceof SeparatedFeature)).collect(Collectors.toList()); + + if (version <= 5) { + // Versions <= 5 use the old feature ordering, simply provided by the FeatureId + this.featureMap = features; + this.inlineFeatures = features.values().stream().filter(f -> !(f instanceof SeparatedFeature)).collect(Collectors.toList()); + } else { + // Version 6 uses the new feature ordering to place fused features last in the list + var sortedFeatures = features.values().stream().sorted().collect(Collectors.toList()); + this.featureMap = new LinkedHashMap<>(); + for (var feature : sortedFeatures) { + this.featureMap.put(feature.id(), feature); + } + this.inlineFeatures = sortedFeatures.stream().filter(f -> !(f instanceof SeparatedFeature)).sorted().collect(Collectors.toList()); + } + + long fusedFeaturesCount = this.inlineFeatures.stream().filter(Feature::isFused).count(); + if (fusedFeaturesCount > 1) { + throw new IllegalArgumentException("At most one fused feature is allowed"); + } + if (fusedFeaturesCount == 1 && version < 6) { + throw new IllegalArgumentException("Fused features require version 6 or higher"); + } this.out = out; // create a mock Header to determine the correct size @@ -164,7 +191,7 @@ public synchronized void writeHeader(ImmutableGraphIndex.View view, long startOf assert out.position() == startOffset + headerSize : String.format("%d != %d", out.position(), startOffset + headerSize); } - void writeSparseLevels(ImmutableGraphIndex.View view) throws IOException { + void writeSparseLevels(ImmutableGraphIndex.View view, Map> featureStateSuppliers) throws IOException { // write sparse levels for (int level = 1; level <= graph.getMaxLevel(); level++) { int layerSize = graph.size(level); @@ -193,6 +220,50 @@ void writeSparseLevels(ImmutableGraphIndex.View view) throws IOException { throw new IllegalStateException("Mismatch between layer size and nodes written"); } } + + // In V6, fused features for the in-memory hierarchy are written in a block after the top layers of the graph. + // Since everything in level 1 is also contained in the higher levels, we only need to write the fused features for level 1. + if (version == 6) { + // There should be only one fused feature per node. This is checked in the class constructor. + // This is the only place where we explicitly need the fused feature. If there are more places in the + // future, it may be worth having fusedFeature as class member. 
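+            // Layout of this block: repeated [int newOrdinal][source feature bytes], one entry per
+            // layer-1 node, or a single entry for the entry node when there are no upper layers.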
+ FusedFeature fusedFeature = null; + for (var feature : inlineFeatures) { + if (feature.isFused()) { + fusedFeature = (FusedFeature) feature; + } + } + if (fusedFeature != null) { + var supplier = featureStateSuppliers.get(fusedFeature.id()); + if (supplier == null) { + throw new IllegalStateException("Supplier for feature " + fusedFeature.id() + " not found"); + } + + if (graph.getMaxLevel() >= 1) { + int level = 1; + int layerSize = graph.size(level); + int nodesWritten = 0; + for (var it = graph.getNodes(level); it.hasNext(); ) { + int originalOrdinal = it.nextInt(); + + // We write the ordinal (node id) so that we can map it to the corresponding feature + final int newOrdinal = ordinalMapper.oldToNew(originalOrdinal); + out.writeInt(newOrdinal); + fusedFeature.writeSourceFeature(out, supplier.apply(originalOrdinal)); + nodesWritten++; + } + if (nodesWritten != layerSize) { + throw new IllegalStateException("Mismatch between layer 1 size and features written"); + } + } else { + // Write the source feature of the entry node + final int originalEntryNode = view.entryNode().node; + final int entryNode = ordinalMapper.oldToNew(originalEntryNode); + out.writeInt(entryNode); + fusedFeature.writeSourceFeature(out, supplier.apply(originalEntryNode)); + } + } + } } void writeSeparatedFeatures(Map> featureStateSuppliers) throws IOException { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java index 476a9f298..d80ba9698 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/Header.java @@ -18,14 +18,14 @@ import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; -import io.github.jbellis.jvector.disk.RandomAccessWriter; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; -import java.io.DataOutput; import java.io.IOException; import java.util.EnumMap; import java.util.EnumSet; +import java.util.LinkedHashMap; +import java.util.Map; /** * Header information for an on-disk graph index, containing both common metadata and feature-specific headers. 
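For orientation, this is the V6 header layout implied by the `write`/`size`/`load` changes in the following hunks (an informal sketch, not a normative spec):

```java
// V6 header layout (sketch):
//   CommonHeader                 (unchanged)
//   int featureCount
//   featureCount times, in map iteration order (fused features sorted last):
//     int featureIdOrdinal       (index into FeatureId.values())
//     feature-specific header    (written by Feature.writeHeader, read by FeatureId.load)
```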
@@ -40,9 +40,12 @@
  */
 class Header {
     final CommonHeader common;
-    final EnumMap<FeatureId, Feature> features;
 
-    Header(CommonHeader common, EnumMap<FeatureId, Feature> features) {
+    // In V6, it is important that the features map is sorted according to the feature order defined in AbstractFeature.
+    // In V5 and older, these maps use the FeatureId as the sorting order.
+    final Map<FeatureId, Feature> features;
+
+    Header(CommonHeader common, Map<FeatureId, Feature> features) {
         this.common = common;
         this.features = features;
     }
@@ -50,20 +53,35 @@ class Header {
     void write(IndexWriter out) throws IOException {
         common.write(out);
 
-        if (common.version >= 3) {
-            out.writeInt(FeatureId.serialize(EnumSet.copyOf(features.keySet())));
-        }
+        if (common.version >= 6) {
+            // Writing the features in order (count, then ordinal + header per feature) instead of
+            // writing a single integer encoding all the features (as done in versions before 6)
+            out.writeInt(features.size());
+            for (Feature writer : features.values()) {
+                out.writeInt(writer.id().ordinal());
+                writer.writeHeader(out);
+            }
+        } else {
+            if (common.version >= 3) {
+                out.writeInt(FeatureId.serialize(EnumSet.copyOf(features.keySet())));
+            }
 
-        // we restrict pre-version-3 writers to INLINE_VECTORS features, so we don't need additional version-handling here
-        for (Feature writer : features.values()) {
-            writer.writeHeader(out);
+            // we restrict pre-version-3 writers to INLINE_VECTORS features, so we don't need additional version-handling here
+            for (Feature writer : features.values()) {
+                writer.writeHeader(out);
+            }
         }
     }
 
     public int size() {
         int size = common.size();
-        if (common.version >= 3) {
+        if (common.version >= 6) {
+            // In V6, this accounts for the number of features and the ordinal of each feature
+            size += Integer.BYTES + features.size() * Integer.BYTES;
+        } else if (common.version >= 3) {
             size += Integer.BYTES;
         }
@@ -75,17 +93,28 @@ public int size() {
     static Header load(RandomAccessReader reader, long offset) throws IOException {
         reader.seek(offset);
-        EnumSet<FeatureId> featureIds;
-        EnumMap<FeatureId, Feature> features = new EnumMap<>(FeatureId.class);
+        Map<FeatureId, Feature> features;
+
         CommonHeader common = CommonHeader.load(reader);
-        if (common.version >= 3) {
-            featureIds = FeatureId.deserialize(reader.readInt());
+        if (common.version >= 6) {
+            features = new LinkedHashMap<>();
+            int nFeatures = reader.readInt();
+            for (int i = 0; i < nFeatures; i++) {
+                FeatureId featureId = FeatureId.values()[reader.readInt()];
+                features.put(featureId, featureId.load(common, reader));
+            }
         } else {
-            featureIds = EnumSet.of(FeatureId.INLINE_VECTORS);
-        }
-
-        for (FeatureId featureId : featureIds) {
-            features.put(featureId, featureId.load(common, reader));
+            EnumSet<FeatureId> featureIds;
+            features = new EnumMap<>(FeatureId.class);
+
+            if (common.version >= 3) {
+                featureIds = FeatureId.deserialize(reader.readInt());
+            } else {
+                featureIds = EnumSet.of(FeatureId.INLINE_VECTORS);
+            }
+            for (FeatureId featureId : featureIds) {
+                features.put(featureId, featureId.load(common, reader));
+            }
         }
         return new Header(common, features);
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
index 8f18ffcf4..3aee94e26 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
@@ -25,7 +25,8 @@
 import io.github.jbellis.jvector.graph.disk.feature.Feature;
 import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
 import io.github.jbellis.jvector.graph.disk.feature.FeatureSource;
-import io.github.jbellis.jvector.graph.disk.feature.FusedADC;
+import io.github.jbellis.jvector.graph.disk.feature.FusedPQ;
+import io.github.jbellis.jvector.graph.disk.feature.FusedFeature;
import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; import io.github.jbellis.jvector.graph.disk.feature.NVQ; import io.github.jbellis.jvector.graph.disk.feature.SeparatedFeature; @@ -48,6 +49,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +69,7 @@ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Accountable { private static final Logger logger = LoggerFactory.getLogger(OnDiskGraphIndex.class); - public static final int CURRENT_VERSION = 5; + public static final int CURRENT_VERSION = 6; static final int MAGIC = 0xFFFF0D61; // FFFF to distinguish from old graphs, which should never start with a negative size "ODGI" static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); final ReaderSupplier readerSupplier; @@ -75,15 +78,18 @@ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Acc final NodeAtLevel entryNode; final int idUpperBound; final int inlineBlockSize; // total size of all inline elements contributed by features - final EnumMap features; + final Map features; final EnumMap inlineOffsets; + private final List layerInfo; // offset of L0 adjacency data private final long neighborsOffset; - /** For layers > 0, store adjacency fully in memory. */ + // For layers > 0, store adjacency fully in memory. private final AtomicReference>> inMemoryNeighbors; + // When using fused features, store the features fully in memory for layers > 0 + private final AtomicReference> inMemoryFeatures; - OnDiskGraphIndex(ReaderSupplier readerSupplier, Header header, long neighborsOffset) + private OnDiskGraphIndex(ReaderSupplier readerSupplier, Header header, long neighborsOffset) { this.readerSupplier = readerSupplier; this.version = header.common.version; @@ -104,6 +110,7 @@ public class OnDiskGraphIndex implements ImmutableGraphIndex, AutoCloseable, Acc } this.inlineBlockSize = inlineBlockSize; inMemoryNeighbors = new AtomicReference<>(null); + inMemoryFeatures = new AtomicReference<>(null); } private List> getInMemoryLayers(RandomAccessReader in) throws IOException { @@ -123,8 +130,7 @@ private List> loadInMemoryLayers(RandomAccessReader in) var imn = new ArrayList>(layerInfo.size()); // For levels > 0, we load adjacency into memory imn.add(null); // L0 placeholder so we don't have to mangle indexing - long L0size = 0; - L0size = idUpperBound * (inlineBlockSize + Integer.BYTES * (1L + 1L + layerInfo.get(0).degree)); + long L0size = idUpperBound * (inlineBlockSize + Integer.BYTES * (1L + 1L + layerInfo.get(0).degree)); in.seek(neighborsOffset + L0size); for (int lvl = 1; lvl < layerInfo.size(); lvl++) { @@ -152,6 +158,70 @@ private List> loadInMemoryLayers(RandomAccessReader in) return imn; } + private Int2ObjectHashMap getInMemoryFeatures(RandomAccessReader in) throws IOException { + return inMemoryFeatures.updateAndGet(current -> { + if (current != null) { + return current; + } + // Only load the in-memory features if the graph is fused + for (var feature : features.values()) { + if (feature.isFused()) { + try { + return loadInMemoryFeatures(in); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + return null; + }); + } + + private Int2ObjectHashMap loadInMemoryFeatures(RandomAccessReader in) throws IOException { + Int2ObjectHashMap 
hierarchyFeatures = new Int2ObjectHashMap<>(); + + long L0size = idUpperBound * (inlineBlockSize + Integer.BYTES * (1L + 1L + layerInfo.get(0).degree)); + long inMemorySize = 0; + for (int lvl = 1; lvl < layerInfo.size(); lvl++) { + CommonHeader.LayerInfo info = layerInfo.get(lvl); + inMemorySize += Integer.BYTES * info.size * (1L + 1L + info.degree); + } + in.seek(neighborsOffset + L0size + inMemorySize); + + // In V6, fused features for the in-memory hierarchy are written in a block after the top layers of the graph. + if (version == 6) { + if (layerInfo.size() >= 2) { + int level = 1; + CommonHeader.LayerInfo info = layerInfo.get(level); + for (int i = 0; i < info.size; i++) { + int nodeId = in.readInt(); + + // There should be only one fused feature per node. This is checked in AbstractGraphIndexWriter. + for (var feature : features.values()) { + if (feature.isFused()) { + var fusedFeature = (FusedFeature) feature; + var inlineSource = fusedFeature.loadSourceFeature(in); + hierarchyFeatures.put(nodeId, inlineSource); + } + } + } + } else { + // read the entry node + int nodeId = in.readInt(); + + // There should be only one fused feature per node. This is checked in AbstractGraphIndexWriter. + for (var feature : features.values()) { + if (feature.isFused()) { + var fusedFeature = (FusedFeature) feature; + var inlineSource = fusedFeature.loadSourceFeature(in); + hierarchyFeatures.put(nodeId, inlineSource); + } + } + } + } + return hierarchyFeatures; + } + /** * Load an index from the given reader supplier where header and graph are located on the same file, * where the index starts at `offset`. @@ -160,6 +230,19 @@ private List> loadInMemoryLayers(RandomAccessReader in) * @param offset the offset in bytes from the start of the file where the index starts. */ public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset) { + return load(readerSupplier, offset, true); + } + + /** + * Load an index from the given reader supplier where header and graph are located on the same file, + * where the index starts at `offset`. + * + * @param readerSupplier the reader supplier to use to read the graph and index. + * @param offset the offset in bytes from the start of the file where the index starts. + * @param useFooter whether to use the footer to load the index. + * @return the loaded index. + */ + public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset, boolean useFooter) { try (var reader = readerSupplier.get()) { logger.debug("Loading OnDiskGraphIndex from offset={}", offset); var header = Header.load(reader, offset); @@ -168,11 +251,14 @@ public static OnDiskGraphIndex load(ReaderSupplier readerSupplier, long offset) header.common.version, header.common.dimension, header.common.entryNode, header.common.layerInfo.size()); logger.debug("Position after reading header={}", reader.getPosition()); - if (header.common.version >= 5) { + if (header.common.version >= 5 && useFooter) { logger.debug("Version 5+ onwards uses a footer instead of header for metadata. 
Loading from footer"); return loadFromFooter(readerSupplier, reader.getPosition()); } else { - return new OnDiskGraphIndex(readerSupplier, header, reader.getPosition()); + var odgi = new OnDiskGraphIndex(readerSupplier, header, reader.getPosition()); + odgi.getInMemoryLayers(reader); + odgi.getInMemoryFeatures(reader); + return odgi; } } catch (Exception e) { throw new RuntimeException("Error initializing OnDiskGraph at offset " + offset, e); @@ -217,7 +303,11 @@ private static OnDiskGraphIndex loadFromFooter(ReaderSupplier readerSupplier, lo header.common.entryNode, header.common.layerInfo.size(), in.getPosition()); - return new OnDiskGraphIndex(readerSupplier, header, neighborsOffset); + var odgi = new OnDiskGraphIndex(readerSupplier, header, neighborsOffset); + odgi.getInMemoryLayers(in); + odgi.getInMemoryFeatures(in); + return odgi; + } catch (Exception e) { throw new RuntimeException("Error initializing OnDiskGraph", e); } @@ -227,6 +317,10 @@ public Set getFeatureSet() { return features.keySet(); } + public Map getFeatures() { + return features; + } + @Override public int getDimension() { return dimension; @@ -300,8 +394,19 @@ public NodesIterator getNodes(int level) { @Override public long ramBytesUsed() { + List> inMemoryNeighborsLocal = inMemoryNeighbors.get(); + + long inMemoryNeighborsBytes = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + for (Int2ObjectHashMap neighbors : inMemoryNeighborsLocal) { + inMemoryNeighborsBytes += neighbors.values().stream().mapToLong(is -> Integer.BYTES * (long) is.length).sum(); + inMemoryNeighborsBytes += RamUsageEstimator.NUM_BYTES_OBJECT_REF; + } + long inMemoryFeaturesBytes = inMemoryFeatures.get().values().stream().mapToLong(is -> Integer.BYTES * is.ramBytesUsed()).sum(); + inMemoryFeaturesBytes += RamUsageEstimator.NUM_BYTES_OBJECT_REF; + return Long.BYTES + 6 * Integer.BYTES + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + (long) 2 * RamUsageEstimator.NUM_BYTES_OBJECT_REF * FeatureId.values().length; + + (long) 2 * RamUsageEstimator.NUM_BYTES_OBJECT_REF * FeatureId.values().length + + inMemoryNeighborsBytes + inMemoryFeaturesBytes; } public void close() throws IOException { @@ -349,6 +454,7 @@ public double getAverageDegree(int level) { public class View implements FeatureSource, ScoringView, RandomAccessVectorValues { protected final RandomAccessReader reader; private final int[] neighbors; + private int nodeDegree; public View(RandomAccessReader reader) { this.reader = reader; @@ -441,26 +547,84 @@ public void getVectorInto(int node, VectorFloat vector, int offset) { public NodesIterator getNeighborsIterator(int level, int node) { try { + int[] stored; + if (level == 0) { // For layer 0, read from disk reader.seek(neighborsOffsetFor(level, node)); - int neighborCount = reader.readInt(); - assert neighborCount <= neighbors.length - : String.format("Node %d neighborCount %d > M %d", node, neighborCount, neighbors.length); - reader.read(neighbors, 0, neighborCount); - return new NodesIterator.ArrayNodesIterator(neighbors, neighborCount); + nodeDegree = reader.readInt(); + assert nodeDegree <= neighbors.length + : String.format("Node %d neighborCount %d > M %d", node, nodeDegree, neighbors.length); + reader.read(neighbors, 0, nodeDegree); + stored = neighbors; } else { // For levels > 0, read from memory var imn = getInMemoryLayers(reader); - int[] stored = imn.get(level).get(node); + stored = imn.get(level).get(node); + nodeDegree = stored.length; assert stored != null : String.format("No neighbors found for node %d at level %d", node, level); - 
return new NodesIterator.ArrayNodesIterator(stored, stored.length); + } + return new NodesIterator.ArrayNodesIterator(stored, nodeDegree); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public void getPackedNeighbors(int node, FeatureId featureId, Consumer featureConsumer) throws IOException { + Feature feature = features.get(featureId); + if (!feature.isFused()) { + throw new UnsupportedOperationException("Only fused features are supported with packed neighbors"); + } + + long offset = offsetFor(node, featureId); + reader.seek(offset); + featureConsumer.accept(reader); + + if (version < 6) { + reader.seek(neighborsOffsetFor(0, node)); + } + + nodeDegree = reader.readInt(); + assert nodeDegree <= neighbors.length + : String.format("Node %d neighborCount %d > M %d", node, nodeDegree, neighbors.length); + reader.read(neighbors, 0, nodeDegree); + + } + + public Int2ObjectHashMap getInlineSourceFeatures() { + try { + return OnDiskGraphIndex.this.getInMemoryFeatures(reader); } catch (IOException e) { throw new UncheckedIOException(e); } } + @Override + public void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, NeighborProcessor neighborProcessor) { + var useEdgeLoading = scoreFunction.supportsSimilarityToNeighbors(); + if (useEdgeLoading && level == 0) { + scoreFunction.enableSimilarityToNeighbors(node); + + for (int i = 0; i < nodeDegree; i++) { + var friendOrd = neighbors[i]; + if (visited.mark(friendOrd)) { + float friendSimilarity = scoreFunction.similarityToNeighbor(node, i); + neighborProcessor.process(friendOrd, friendSimilarity); + } + } + } else { + var it = getNeighborsIterator(level, node); + while (it.hasNext()) { + var friendOrd = it.nextInt(); + if (visited.mark(friendOrd)) { + float friendSimilarity = scoreFunction.similarityTo(friendOrd); + neighborProcessor.process(friendOrd, friendSimilarity); + } + } + } + } + @Override public int size() { // For vector operations we only care about layer 0 @@ -515,8 +679,8 @@ public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, @Override public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { - if (features.containsKey(FeatureId.FUSED_ADC)) { - return ((FusedADC) features.get(FeatureId.FUSED_ADC)).approximateScoreFunctionFor(queryVector, vsf, this, rerankerFor(queryVector, vsf)); + if (features.containsKey(FeatureId.FUSED_PQ)) { + return ((FusedPQ) features.get(FeatureId.FUSED_PQ)).approximateScoreFunctionFor(queryVector, vsf, this, rerankerFor(queryVector, vsf)); } else { throw new UnsupportedOperationException("No approximate score function available for this graph"); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index a8515c191..6b210926e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -68,12 +68,12 @@ public class OnDiskGraphIndexWriter extends AbstractGraphIndexWriter features) + int version, + long startOffset, + ImmutableGraphIndex graph, + OrdinalMapper oldToNewOrdinals, + int dimension, + EnumMap features) { super(randomAccessWriter, version, graph, oldToNewOrdinals, dimension, features); this.startOffset = startOffset; @@ -213,7 +213,7 @@ public synchronized 
void write(Map> featur } // We will use the abstract method because no random access is needed - writeSparseLevels(view); + writeSparseLevels(view, featureStateSuppliers); // We will use the abstract method because no random access is needed writeSeparatedFeatures(featureStateSuppliers); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java index e7dd69476..c8a0bb832 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskSequentialGraphIndexWriter.java @@ -152,7 +152,7 @@ public synchronized void write(Map> featur } } - writeSparseLevels(view); + writeSparseLevels(view, featureStateSuppliers); writeSeparatedFeatures(featureStateSuppliers); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractFeature.java new file mode 100644 index 000000000..e1fed2f54 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractFeature.java @@ -0,0 +1,26 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk.feature; + +public abstract class AbstractFeature implements Feature { + public int compareTo(Feature f) { + if (this.isFused() != f.isFused()) { + return Boolean.compare(this.isFused(), f.isFused()); + } + return this.id().compareTo(f.id()); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractSeparatedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractSeparatedFeature.java new file mode 100644 index 000000000..8e5650b93 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/AbstractSeparatedFeature.java @@ -0,0 +1,26 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.graph.disk.feature; + +public abstract class AbstractSeparatedFeature implements SeparatedFeature { + public int compareTo(Feature f) { + if (this.isFused() != f.isFused()) { + return Boolean.compare(this.isFused(), f.isFused()); + } + return this.id().compareTo(f.id()); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java index a72ff10b6..b8c86581f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java @@ -24,9 +24,13 @@ /** * A feature of an on-disk graph index. Information to use a feature is stored in the header on-disk. */ -public interface Feature { +public interface Feature extends Comparable { FeatureId id(); + default boolean isFused() { + return false; + } + int headerSize(); int featureSize(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java index dd0857834..131c4c8ee 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FeatureId.java @@ -19,9 +19,7 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.disk.CommonHeader; -import java.util.Collections; import java.util.EnumSet; -import java.util.Set; import java.util.function.BiFunction; /** @@ -32,13 +30,11 @@ */ public enum FeatureId { INLINE_VECTORS(InlineVectors::load), - FUSED_ADC(FusedADC::load), + FUSED_PQ(FusedPQ::load), NVQ_VECTORS(NVQ::load), SEPARATED_VECTORS(SeparatedVectors::load), SEPARATED_NVQ(SeparatedNVQ::load); - public static final Set ALL = Collections.unmodifiableSet(EnumSet.allOf(FeatureId.class)); - private final BiFunction loader; FeatureId(BiFunction loader) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java new file mode 100644 index 000000000..deddb7d4f --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java @@ -0,0 +1,41 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk.feature; + +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.util.Accountable; + +import java.io.DataOutput; +import java.io.IOException; + +/** + * A fused feature is one that is computed from the neighbors of a node. 
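+ * Fusing the neighbors' codes into the adjacency data lets a search score all neighbors of a
+ * node from one contiguous read, instead of issuing a separate feature lookup per neighbor.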
+ * - writeInline writes the fused features based on the neighbors of the node + * - writeSource writes the feature of the node itself + * Implements Quick ADC-style scoring by fusing PQ-encoded neighbors into an OnDiskGraphIndex. + */ +public interface FusedFeature extends Feature { + default boolean isFused() { + return true; + } + + void writeSourceFeature(DataOutput out, State state) throws IOException; + + interface InlineSource extends Accountable {} + + InlineSource loadSourceFeature(RandomAccessReader in) throws IOException; +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java similarity index 51% rename from jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java index 59ca11564..5fcf59293 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedADC.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java @@ -21,7 +21,7 @@ import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.quantization.FusedADCPQDecoder; +import io.github.jbellis.jvector.quantization.FusedPQDecoder; import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.ExplicitThreadLocal; @@ -34,34 +34,37 @@ import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.function.IntFunction; /** - * Implements Quick ADC-style scoring by fusing PQ-encoded neighbors into an OnDiskGraphIndex. + * Implements scoring by fusing PQ-encoded neighbors into an OnDiskGraphIndex. */ -public class FusedADC implements Feature { +public class FusedPQ extends AbstractFeature implements FusedFeature { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); private final ProductQuantization pq; private final int maxDegree; private final ThreadLocal> reusableResults; - private final ExplicitThreadLocal> reusableNeighbors; - private ByteSequence compressedNeighbors = null; + private final ExplicitThreadLocal> reusableNeighborCodes; + private final ExplicitThreadLocal> pqCodeScratch; - public FusedADC(int maxDegree, ProductQuantization pq) { - if (maxDegree != 32) { - throw new IllegalArgumentException("maxDegree must be 32 for FusedADC. This limitation may be removed in future releases"); - } + public FusedPQ(int maxDegree, ProductQuantization pq) { if (pq.getClusterCount() != 256) { - throw new IllegalArgumentException("FusedADC requires a 256-cluster PQ. This limitation may be removed in future releases"); + throw new IllegalArgumentException("FusedPQ requires a 256-cluster PQ. 
This limitation may be removed in future releases"); } this.maxDegree = maxDegree; this.pq = pq; this.reusableResults = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(maxDegree)); - this.reusableNeighbors = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree)); + this.reusableNeighborCodes = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree)); + this.pqCodeScratch = ExplicitThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(pq.compressedVectorSize())); } @Override public FeatureId id() { - return FeatureId.FUSED_ADC; + return FeatureId.FUSED_PQ; + } + + public ProductQuantization getPQ() { + return pq; } @Override @@ -74,18 +77,23 @@ public int featureSize() { return pq.compressedVectorSize() * maxDegree; } - static FusedADC load(CommonHeader header, RandomAccessReader reader) { - // TODO doesn't work with different degrees + static FusedPQ load(CommonHeader header, RandomAccessReader reader) { try { - return new FusedADC(header.layerInfo.get(0).degree, ProductQuantization.load(reader)); + return new FusedPQ(header.layerInfo.get(0).degree, ProductQuantization.load(reader)); } catch (IOException e) { throw new UncheckedIOException(e); } } + /** + * @param view The view needs to be the one used by the searcher + * @param esf + * @return + */ public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) { var neighbors = new PackedNeighbors(view); - return FusedADCPQDecoder.newDecoder(neighbors, pq, queryVector, reusableResults.get(), vsf, esf); + var hierarchyCachedFeatures = view.getInlineSourceFeatures(); + return FusedPQDecoder.newDecoder(neighbors, pq, hierarchyCachedFeatures, queryVector, reusableNeighborCodes.get(), reusableResults.get(), vsf, esf); } @Override @@ -97,38 +105,74 @@ public void writeHeader(DataOutput out) throws IOException { // generate the fused set based on the neighbors of the node, not just the node itself @Override public void writeInline(DataOutput out, Feature.State state_) throws IOException { - if (compressedNeighbors == null) { - compressedNeighbors = vectorTypeSupport.createByteSequence(pq.compressedVectorSize() * maxDegree); - } - var state = (FusedADC.State) state_; - var pqv = state.pqVectors; + var state = (FusedPQ.State) state_; var neighbors = state.view.getNeighborsIterator(0, state.nodeId); - int n = 0; - compressedNeighbors.zero(); + int count = 0; while (neighbors.hasNext()) { - var compressed = pqv.get(neighbors.nextInt()); - for (int j = 0; j < pqv.getCompressedSize(); j++) { - compressedNeighbors.set(j * maxDegree + n, compressed.get(j)); - } - n++; + int node = neighbors.nextInt(); + var compressed = state.compressedVectorFunction.apply(node); + vectorTypeSupport.writeByteSequence(out, compressed.copy()); + count++; + } + pqCodeScratch.get().zero(); + for (; count < maxDegree; count++) { + vectorTypeSupport.writeByteSequence(out, pqCodeScratch.get()); } - - vectorTypeSupport.writeByteSequence(out, compressedNeighbors); } public static class State implements Feature.State { public final ImmutableGraphIndex.View view; - public final PQVectors pqVectors; + public final IntFunction> compressedVectorFunction; public final int nodeId; public State(ImmutableGraphIndex.View view, PQVectors pqVectors, int nodeId) { + this(view, pqVectors::get, nodeId); + } + + public 
State(ImmutableGraphIndex.View view, IntFunction> compressedVectorFunction, int nodeId) { this.view = view; - this.pqVectors = pqVectors; + this.compressedVectorFunction = compressedVectorFunction; this.nodeId = nodeId; } } + @Override + public void writeSourceFeature(DataOutput out, Feature.State state_) throws IOException { + var state = (FusedPQ.State) state_; + var compressed = state.compressedVectorFunction.apply(state.nodeId); + var temp = pqCodeScratch.get(); + for (int i = 0; i < compressed.length(); i++) { + temp.set(i, compressed.get(i)); + } + vectorTypeSupport.writeByteSequence(out, temp); + } + + public static class FusedPQInlineSource implements InlineSource { + private ByteSequence code; + + public FusedPQInlineSource(ByteSequence code) { + this.code = code; + } + + @Override + public long ramBytesUsed() { + return code.length(); + } + + public ByteSequence getCode() { + return code; + } + } + + @Override + public InlineSource loadSourceFeature(RandomAccessReader in) throws IOException { + int length = pq.getSubspaceCount(); + var code = vectorTypeSupport.createByteSequence(length); + vectorTypeSupport.readByteSequence(in, code); + return new FusedPQInlineSource(code); + } + public class PackedNeighbors { private final OnDiskGraphIndex.View view; @@ -136,12 +180,17 @@ public PackedNeighbors(OnDiskGraphIndex.View view) { this.view = view; } - public ByteSequence getPackedNeighbors(int node) { + public void readInto(int node, ByteSequence neighborCodes) { try { - var reader = view.featureReaderForNode(node, FeatureId.FUSED_ADC); - var tlNeighbors = reusableNeighbors.get(); - vectorTypeSupport.readByteSequence(reader, tlNeighbors); - return tlNeighbors; + view.getPackedNeighbors(node, FeatureId.FUSED_PQ, + reader -> { + try { + vectorTypeSupport.readByteSequence(reader, neighborCodes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + ); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java index 59e2b359c..0e80bd467 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java @@ -28,7 +28,7 @@ /** * Implements the storage of full-resolution vectors inline into an OnDiskGraphIndex. These can be used for exact scoring. */ -public class InlineVectors implements Feature { +public class InlineVectors extends AbstractFeature { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); private final int dimension; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java index 2489ada21..866bd171c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java @@ -33,7 +33,7 @@ /** * Implements the storage of NuVeQ vectors in an on-disk graph index. These can be used for reranking. 
*/ -public class NVQ implements Feature { +public class NVQ extends AbstractFeature { private final NVQuantization nvq; private final NVQScorer scorer; private final ThreadLocal reusableQuantizedVector; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java index b5d4cc476..d7cb8080b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java @@ -29,7 +29,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -public class SeparatedNVQ implements SeparatedFeature { +public class SeparatedNVQ extends AbstractSeparatedFeature { private final NVQuantization nvq; private final NVQScorer scorer; private final ThreadLocal reusableQuantizedVector; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java index f6bff8472..50bcef545 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java @@ -25,7 +25,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -public class SeparatedVectors implements SeparatedFeature { +public class SeparatedVectors extends AbstractSeparatedFeature { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); private final int dimension; private long offset; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java index 0dcb95823..9ae58fab1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/ScoreFunction.java @@ -17,7 +17,6 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.vector.VectorizationProvider; -import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; /** @@ -42,20 +41,31 @@ public interface ScoreFunction { float similarityTo(int node2); /** - * @return the similarity to all of the nodes that `node2` has an edge towards. + * Computes the similarity to the neighborIndex-th neighbor of origin. + * Before calling this function, enableSimilarityToNeighbors must be called first with the same origin. + * This function only works if it is called for the same origin node multiple times. * Used when expanding the neighbors of a search candidate. + * @param origin the node we are expanding + * @param neighborIndex the index of the neighbor we are scoring, a number between 0 and the number of neighbors of the origin node. + * @return the score */ - default VectorFloat edgeLoadingSimilarityTo(int node2) { + default float similarityToNeighbor(int origin, int neighborIndex) { throw new UnsupportedOperationException("bulk similarity not supported"); } /** - * @return true if `edgeLoadingSimilarityTo` is supported + * Load the corresponding data so that similarityToNeighbor can be used with the neighbors of the origin node. 
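+     * @param origin the node whose neighbor data should be loaded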
*/ - default boolean supportsEdgeLoadingSimilarity() { + default void enableSimilarityToNeighbors(int origin) {} + + /** + * @return true if `similarityToNeighbor` is supported + */ + default boolean supportsSimilarityToNeighbors() { return false; } + interface ExactScoreFunction extends ScoreFunction { default boolean isExact() { return true; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java deleted file mode 100644 index d55ffbd8c..000000000 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.github.jbellis.jvector.quantization; - -import io.github.jbellis.jvector.graph.disk.feature.FusedADC; -import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.vector.VectorSimilarityFunction; -import io.github.jbellis.jvector.vector.VectorUtil; -import io.github.jbellis.jvector.vector.VectorizationProvider; -import io.github.jbellis.jvector.vector.types.ByteSequence; -import io.github.jbellis.jvector.vector.types.VectorFloat; -import io.github.jbellis.jvector.vector.types.VectorTypeSupport; - -import java.util.Arrays; - -/** - * Performs similarity comparisons with compressed vectors without decoding them. - * These decoders use Quick(er) ADC-style transposed vectors fused into a graph. - */ -public abstract class FusedADCPQDecoder implements ScoreFunction.ApproximateScoreFunction { - private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); - protected final ProductQuantization pq; - protected final VectorFloat query; - protected final ExactScoreFunction esf; - protected final ByteSequence partialQuantizedSums; - // connected to the Graph View by caller - protected final FusedADC.PackedNeighbors neighbors; - // caller passes this to us for re-use across calls - protected final VectorFloat results; - // decoder state - protected final VectorFloat partialSums; - protected final VectorFloat partialBestDistances; - protected final int invocationThreshold; - protected int invocations = 0; - protected float bestDistance; - protected float worstDistance; - protected float delta; - protected boolean supportsQuantizedSimilarity = false; - protected final VectorSimilarityFunction vsf; - - // Implements section 3.4 of "Quicker ADC : Unlocking the Hidden Potential of Product Quantization with SIMD" - // The main difference is that since our graph structure rapidly converges towards the best results, - // we don't need to scan K values to have enough confidence that our worstDistance bound is reasonable. 
- protected FusedADCPQDecoder(ProductQuantization pq, VectorFloat query, int invocationThreshold, FusedADC.PackedNeighbors neighbors, VectorFloat results, ExactScoreFunction esf, VectorSimilarityFunction vsf) { - this.pq = pq; - this.query = query; - this.esf = esf; - this.invocationThreshold = invocationThreshold; - this.neighbors = neighbors; - this.results = results; - this.vsf = vsf; - - // compute partialSums, partialBestDistances, and bestDistance from the codebooks - // cosine similarity is a special case where we need to compute the squared magnitude of the query - // in the same loop, so we skip this and compute it in the cosine constructor - partialSums = pq.reusablePartialSums(); - partialBestDistances = pq.reusablePartialBestDistances(); - if (vsf != VectorSimilarityFunction.COSINE) { - VectorFloat center = pq.globalCentroid; - var centeredQuery = center == null ? query : VectorUtil.sub(query, center); - for (var i = 0; i < pq.getSubspaceCount(); i++) { - int offset = pq.subvectorSizesAndOffsets[i][1]; - int size = pq.subvectorSizesAndOffsets[i][0]; - var codebook = pq.codebooks[i]; - VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, vsf, partialSums, partialBestDistances); - } - bestDistance = VectorUtil.sum(partialBestDistances); - } - - // these will be computed by edgeLoadingSimilarityTo as we search - partialQuantizedSums = pq.reusablePartialQuantizedSums(); - } - - @Override - public VectorFloat edgeLoadingSimilarityTo(int origin) { - var permutedNodes = neighbors.getPackedNeighbors(origin); - results.zero(); - - if (supportsQuantizedSimilarity) { - // we have seen enough data to compute `delta`, so take the fast path using the permuted nodes - VectorUtil.bulkShuffleQuantizedSimilarity(permutedNodes, pq.compressedVectorSize(), partialQuantizedSums, delta, bestDistance, results, vsf); - return results; - } - - // we have not yet computed worstDistance or delta, so we need to assemble the results manually - // from the PQ codebooks - var nodeCount = results.length(); - for (int i = 0; i < pq.getSubspaceCount(); i++) { - for (int j = 0; j < nodeCount; j++) { - results.set(j, results.get(j) + partialSums.get(i * pq.getClusterCount() + Byte.toUnsignedInt(permutedNodes.get(i * nodeCount + j)))); - } - } - - // update worstDistance from our new set of results - for (int i = 0; i < nodeCount; i++) { - var result = results.get(i); - invocations++; - updateWorstDistance(result); - results.set(i, distanceToScore(result)); - } - - // once we have enough data, set up delta, partialQuantizedSums, and partialQuantizedMagnitudes for the fast path - if (invocations >= invocationThreshold) { - delta = (worstDistance - bestDistance) / 65535; - VectorUtil.quantizePartials(delta, partialSums, partialBestDistances, partialQuantizedSums); - supportsQuantizedSimilarity = true; - } - - return results; - } - - @Override - public boolean supportsEdgeLoadingSimilarity() { - return true; - } - - @Override - public float similarityTo(int node2) { - return esf.similarityTo(node2); - } - - protected abstract float distanceToScore(float distance); - - protected abstract void updateWorstDistance(float distance); - - static class DotProductDecoder extends FusedADCPQDecoder { - public DotProductDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { - super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.DOT_PRODUCT); - worstDistance = 
Float.MAX_VALUE; // initialize at best value, update as we search - } - - @Override - protected float distanceToScore(float distance) { - return (distance + 1) / 2; - } - - @Override - protected void updateWorstDistance(float distance) { - worstDistance = Math.min(worstDistance, distance); - } - } - - static class EuclideanDecoder extends FusedADCPQDecoder { - public EuclideanDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { - super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.EUCLIDEAN); - worstDistance = 0; // initialize at best value, update as we search - } - - @Override - protected float distanceToScore(float distance) { - return 1 / (1 + distance); - } - - @Override - protected void updateWorstDistance(float distance) { - worstDistance = Math.max(worstDistance, distance); - } - } - - - // CosineDecoder differs from DotProductDecoder/EuclideanDecoder because there are two different tables of quantized fragments to sum: query to codebook entry dot products, - // and codebook entry to codebook entry dot products. The latter can be calculated once per ProductQuantization, but for lookups to go at the appropriate speed, they must - // also be quantized. We use a similar quantization to partial sums, but we know exactly the worst/best bounds, so overflow does not matter. - static class CosineDecoder extends FusedADCPQDecoder { - private final float queryMagnitudeSquared; - private final VectorFloat partialSquaredMagnitudes; - private final ByteSequence partialQuantizedSquaredMagnitudes; - // prior to quantization, we need a good place on-heap to aggregate these for worstDistance tracking/result calculation - private final float[] resultSumAggregates; - private final float[] resultMagnitudeAggregates; - // store these to avoid repeated volatile lookups - private float minSquaredMagnitude; - private float squaredMagnitudeDelta; - - protected CosineDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, VectorFloat results, ExactScoreFunction esf) { - super(pq, query, neighbors.maxDegree(), neighbors, results, esf, VectorSimilarityFunction.COSINE); - worstDistance = Float.MAX_VALUE; // initialize at best value, update as we search - - // this part is not query-dependent, so we can cache it - partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> { - if (current != null) { - squaredMagnitudeDelta = pq.squaredMagnitudeDelta; - minSquaredMagnitude = pq.minSquaredMagnitude; - return current; - } - - // we only need these for quantization, minSquaredMagnitude/squaredMagnitudeDelta are sufficient for dequantization - float maxMagnitude = 0; - VectorFloat partialMinMagnitudes = vts.createFloatVector(pq.getSubspaceCount()); - - var partialSquaredMagnitudes = vts.createFloatVector(pq.getSubspaceCount() * pq.getClusterCount()); - for (int m = 0; m < pq.getSubspaceCount(); ++m) { - int size = pq.subvectorSizesAndOffsets[m][0]; - var codebook = pq.codebooks[m]; - float minPartialMagnitude = Float.POSITIVE_INFINITY; - float maxPartialMagnitude = 0; - for (int j = 0; j < pq.getClusterCount(); ++j) { - var partialMagnitude = VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size); - minPartialMagnitude = Math.min(minPartialMagnitude, partialMagnitude); - maxPartialMagnitude = Math.max(maxPartialMagnitude, partialMagnitude); - partialSquaredMagnitudes.set((m * pq.getClusterCount()) + j, partialMagnitude); - } - - 
partialMinMagnitudes.set(m, minPartialMagnitude); - maxMagnitude += maxPartialMagnitude; - minSquaredMagnitude += minPartialMagnitude; - } - squaredMagnitudeDelta = (maxMagnitude - minSquaredMagnitude) / 65535; - var partialQuantizedSquaredMagnitudes = vts.createByteSequence(pq.getSubspaceCount() * pq.getClusterCount() * 2); - VectorUtil.quantizePartials(squaredMagnitudeDelta, partialSquaredMagnitudes, partialMinMagnitudes, partialQuantizedSquaredMagnitudes); - - // publish for future use in other decoders using this PQ - pq.squaredMagnitudeDelta = squaredMagnitudeDelta; - pq.minSquaredMagnitude = minSquaredMagnitude; - pq.partialQuantizedSquaredMagnitudes().set(partialQuantizedSquaredMagnitudes); - - return partialSquaredMagnitudes; - }); - partialQuantizedSquaredMagnitudes = pq.partialQuantizedSquaredMagnitudes().get(); - - // compute partialSums, partialBestDistances, bestDistance, and queryMagnitudeSquared from the codebooks - VectorFloat center = pq.globalCentroid; - float queryMagSum = 0.0f; - var centeredQuery = center == null ? query : VectorUtil.sub(query, center); - for (var i = 0; i < pq.getSubspaceCount(); i++) { - int offset = pq.subvectorSizesAndOffsets[i][1]; - int size = pq.subvectorSizesAndOffsets[i][0]; - var codebook = pq.codebooks[i]; - // cosine numerator is the same partial sums as if we were using DOT_PRODUCT - VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, VectorSimilarityFunction.DOT_PRODUCT, partialSums, partialBestDistances); - queryMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, size); - } - this.queryMagnitudeSquared = queryMagSum; - bestDistance = VectorUtil.sum(partialBestDistances); - - this.resultSumAggregates = new float[results.length()]; - this.resultMagnitudeAggregates = new float[results.length()]; - } - - @Override - public VectorFloat edgeLoadingSimilarityTo(int origin) { - var permutedNodes = neighbors.getPackedNeighbors(origin); - - if (supportsQuantizedSimilarity) { - results.zero(); - // we have seen enough data to compute `delta`, so take the fast path using the permuted nodes - VectorUtil.bulkShuffleQuantizedSimilarityCosine(permutedNodes, pq.compressedVectorSize(), partialQuantizedSums, delta, bestDistance, partialQuantizedSquaredMagnitudes, squaredMagnitudeDelta, minSquaredMagnitude, queryMagnitudeSquared, results); - return results; - } - - // we have not yet computed worstDistance or delta, so we need to assemble the results manually - // from the PQ codebooks - var nodeCount = results.length(); - Arrays.fill(resultSumAggregates, 0); - Arrays.fill(resultMagnitudeAggregates, 0); - for (int i = 0; i < pq.getSubspaceCount(); i++) { - for (int j = 0; j < nodeCount; j++) { - resultSumAggregates[j] += partialSums.get(i * pq.getClusterCount() + Byte.toUnsignedInt(permutedNodes.get(i * nodeCount + j))); - resultMagnitudeAggregates[j] += partialSquaredMagnitudes.get(i * pq.getClusterCount() + Byte.toUnsignedInt(permutedNodes.get(i * nodeCount + j))); - } - } - - // update worstDistance from our new set of results - for (int i = 0; i < nodeCount; i++) { - updateWorstDistance(resultSumAggregates[i]); - var result = resultSumAggregates[i] / (float) Math.sqrt(resultMagnitudeAggregates[i] * queryMagnitudeSquared); - invocations++; - results.set(i, distanceToScore(result)); - } - - // once we have enough data, set up delta and partialQuantizedSums for the fast path - if (invocations >= invocationThreshold) { - delta = (worstDistance - bestDistance) / 65535; - 
VectorUtil.quantizePartials(delta, partialSums, partialBestDistances, partialQuantizedSums); - supportsQuantizedSimilarity = true; - } - - return results; - } - - protected float distanceToScore(float distance) { - return (1 + distance) / 2; - }; - - protected void updateWorstDistance(float distance) { - worstDistance = Math.min(worstDistance, distance); - }; - } - - public static FusedADCPQDecoder newDecoder(FusedADC.PackedNeighbors neighbors, ProductQuantization pq, VectorFloat query, - VectorFloat results, VectorSimilarityFunction similarityFunction, ExactScoreFunction esf) { - switch (similarityFunction) { - case DOT_PRODUCT: - return new DotProductDecoder(neighbors, pq, query, results, esf); - case EUCLIDEAN: - return new EuclideanDecoder(neighbors, pq, query, results, esf); - case COSINE: - return new CosineDecoder(neighbors, pq, query, results, esf); - default: - throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction); - } - } -} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java new file mode 100644 index 000000000..b52d0ca66 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedPQDecoder.java @@ -0,0 +1,236 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.quantization; + +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; +import io.github.jbellis.jvector.graph.disk.feature.FusedFeature; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.ByteSequence; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.agrona.collections.Int2ObjectHashMap; + +/** + * Performs similarity comparisons with compressed vectors without decoding them. + * These decoders use vectors fused into a graph. 
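+ * <p>Conceptually, scoring one PQ-encoded neighbor reduces to a table lookup per subspace. A sketch
+ * of the idea (the actual work is delegated to VectorUtil.assembleAndSum; the names subspaceCount,
+ * clusterCount, and code are illustrative):
+ * <pre>{@code
+ * float sum = 0;
+ * for (int m = 0; m < subspaceCount; m++) {
+ *     // partialSums holds query-to-centroid score fragments, one block of clusterCount entries per subspace
+ *     sum += partialSums.get(m * clusterCount + Byte.toUnsignedInt(code.get(m)));
+ * }
+ * // distanceToScore then maps the accumulated distance or dot product into a score
+ * }</pre>
+ * <p>For cosine, a second table of precomputed centroid squared magnitudes is looked up the same way,
+ * giving approximately dot / sqrt(magnitude * queryMagnitudeSquared).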
+ */ +public abstract class FusedPQDecoder implements ScoreFunction.ApproximateScoreFunction { + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + protected final ProductQuantization pq; + Int2ObjectHashMap hierarchyCachedFeatures; + protected final VectorFloat query; + protected final ExactScoreFunction esf; + // connected to the Graph View by caller + protected final FusedPQ.PackedNeighbors packedNeighbors; + // caller passes this to us for re-use across calls + protected final ByteSequence neighborCodes; + // decoder state + protected final VectorFloat partialSums; + protected final VectorSimilarityFunction vsf; + protected int origin; + + protected FusedPQDecoder(ProductQuantization pq, + Int2ObjectHashMap hierarchyCachedFeatures, + VectorFloat query, FusedPQ.PackedNeighbors packedNeighbors, + ByteSequence neighborCodes, VectorFloat results, ExactScoreFunction esf, + VectorSimilarityFunction vsf) { + this.pq = pq; + this.hierarchyCachedFeatures = hierarchyCachedFeatures; + this.query = query; + this.esf = esf; + this.packedNeighbors = packedNeighbors; + this.neighborCodes = neighborCodes; + this.vsf = vsf; + this.origin = -1; + + // compute partialSums + // cosine similarity is a special case where we need to compute the squared magnitude of the query + // in the same loop, so we skip this and compute it in the cosine constructor + partialSums = pq.reusablePartialSums(); + if (vsf != VectorSimilarityFunction.COSINE) { + VectorFloat center = pq.globalCentroid; + var centeredQuery = center == null ? query : VectorUtil.sub(query, center); + for (var i = 0; i < pq.getSubspaceCount(); i++) { + int offset = pq.subvectorSizesAndOffsets[i][1]; + int size = pq.subvectorSizesAndOffsets[i][0]; + var codebook = pq.codebooks[i]; + VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, vsf, partialSums); + } + } + } + + @Override + public boolean supportsSimilarityToNeighbors() { + return true; + } + + @Override + public void enableSimilarityToNeighbors(int origin) { + if (this.origin != origin) { + this.origin = origin; + packedNeighbors.readInto(origin, neighborCodes); + } + } + + @Override + public float similarityTo(int node2) { + if (!hierarchyCachedFeatures.containsKey(node2)) { + throw new IllegalArgumentException("Node " + node2 + " is not in the hierarchy"); + } + + var code2 = (FusedPQ.FusedPQInlineSource) hierarchyCachedFeatures.get(node2); + float sim = VectorUtil.assembleAndSum(partialSums, pq.getClusterCount(), code2.getCode()); + return distanceToScore(sim); + } + + @Override + public float similarityToNeighbor(int origin, int neighborIndex) { + if (this.origin != origin) { + throw new IllegalArgumentException("origin must be the same as the origin used to enable similarityToNeighbor"); + } + int position = neighborIndex * pq.getSubspaceCount(); + float sim = VectorUtil.assembleAndSum(partialSums, pq.getClusterCount(), neighborCodes, position, pq.getSubspaceCount()); + return distanceToScore(sim); + } + + protected abstract float distanceToScore(float distance); + + static class DotProductDecoder extends FusedPQDecoder { + public DotProductDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, + Int2ObjectHashMap hierarchyCachedFeatures, + VectorFloat query, ByteSequence neighborCodes, VectorFloat results, + ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.DOT_PRODUCT); + } + + @Override + 
protected float distanceToScore(float distance) { + return (distance + 1) / 2; + } + } + + static class EuclideanDecoder extends FusedPQDecoder { + public EuclideanDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, + Int2ObjectHashMap hierarchyCachedFeatures, + VectorFloat query, ByteSequence neighborCodes, VectorFloat results, + ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.EUCLIDEAN); + } + + @Override + protected float distanceToScore(float distance) { + return 1 / (1 + distance); + } + } + + + // CosineDecoder differs from DotProductDecoder/EuclideanDecoder because there are two different tables of fragments to sum: query to codebook entry dot products, + // and codebook entry to codebook entry dot products. The latter can be calculated once per ProductQuantization. + static class CosineDecoder extends FusedPQDecoder { + private final float queryMagnitudeSquared; + private final VectorFloat partialSquaredMagnitudes; + + protected CosineDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, + Int2ObjectHashMap hierarchyCachedFeatures, + VectorFloat query, ByteSequence neighborCodes, VectorFloat results, + ExactScoreFunction esf) { + super(pq, hierarchyCachedFeatures, query, neighbors, neighborCodes, results, esf, VectorSimilarityFunction.COSINE); + + // this part is not query-dependent, so we can cache it + partialSquaredMagnitudes = pq.partialSquaredMagnitudes().updateAndGet(current -> { + if (current != null) { + return current; + } + + var partialSquaredMagnitudes = vts.createFloatVector(pq.getSubspaceCount() * pq.getClusterCount()); + for (int m = 0; m < pq.getSubspaceCount(); ++m) { + int size = pq.subvectorSizesAndOffsets[m][0]; + var codebook = pq.codebooks[m]; + float minPartialMagnitude = Float.POSITIVE_INFINITY; + float maxPartialMagnitude = 0; + for (int j = 0; j < pq.getClusterCount(); ++j) { + var partialMagnitude = VectorUtil.dotProduct(codebook, j * size, codebook, j * size, size); + minPartialMagnitude = Math.min(minPartialMagnitude, partialMagnitude); + maxPartialMagnitude = Math.max(maxPartialMagnitude, partialMagnitude); + partialSquaredMagnitudes.set((m * pq.getClusterCount()) + j, partialMagnitude); + } + } + return partialSquaredMagnitudes; + }); + + // compute partialSums + VectorFloat center = pq.globalCentroid; + float queryMagSum = 0.0f; + var centeredQuery = center == null ? 
query : VectorUtil.sub(query, center); + for (var i = 0; i < pq.getSubspaceCount(); i++) { + int offset = pq.subvectorSizesAndOffsets[i][1]; + int size = pq.subvectorSizesAndOffsets[i][0]; + var codebook = pq.codebooks[i]; + // cosine numerator is the same partial sums as if we were using DOT_PRODUCT + VectorUtil.calculatePartialSums(codebook, i, size, pq.getClusterCount(), centeredQuery, offset, VectorSimilarityFunction.DOT_PRODUCT, partialSums); + queryMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, size); + } + + this.queryMagnitudeSquared = queryMagSum; + } + + @Override + public float similarityTo(int node2) { + if (!hierarchyCachedFeatures.containsKey(node2)) { + throw new IllegalArgumentException("Node " + node2 + " is not in the hierarchy"); + } + + var code2 = (FusedPQ.FusedPQInlineSource) hierarchyCachedFeatures.get(node2); + float cos = VectorUtil.pqDecodedCosineSimilarity(code2.getCode(), 0, pq.getSubspaceCount(), pq.getClusterCount(), partialSums, partialSquaredMagnitudes, queryMagnitudeSquared); + return distanceToScore(cos); + } + + @Override + public float similarityToNeighbor(int origin, int neighborIndex) { + if (this.origin != origin) { + throw new IllegalArgumentException("origin must be the same as the origin used to enable similarityToNeighbor"); + } + int position = neighborIndex * pq.getSubspaceCount(); + float cos = VectorUtil.pqDecodedCosineSimilarity(neighborCodes, position, pq.getSubspaceCount(), pq.getClusterCount(), partialSums, partialSquaredMagnitudes, queryMagnitudeSquared); + return distanceToScore(cos); + } + + + protected float distanceToScore(float distance) { + return (1 + distance) / 2; + }; + } + + public static FusedPQDecoder newDecoder(FusedPQ.PackedNeighbors neighbors, ProductQuantization pq, + Int2ObjectHashMap hierarchyCachedFeatures, VectorFloat query, + ByteSequence reusableNeighborCodes, VectorFloat results, + VectorSimilarityFunction similarityFunction, ExactScoreFunction esf) { + switch (similarityFunction) { + case DOT_PRODUCT: + return new DotProductDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + case EUCLIDEAN: + return new EuclideanDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + case COSINE: + return new CosineDecoder(neighbors, pq, hierarchyCachedFeatures, query, reusableNeighborCodes, results, esf); + default: + throw new IllegalArgumentException("Unsupported similarity function: " + similarityFunction); + } + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index 0e98bbf32..12d683d59 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -72,12 +72,7 @@ public class ProductQuantization implements VectorCompressor>, A final float anisotropicThreshold; // parallel cost multiplier private final float[][] centroidNormsSquared; // precomputed norms of the centroids, for encoding private final ThreadLocal> partialSums; // for dot product, euclidean, and cosine partials - private final ThreadLocal> partialBestDistances; // for partial best distances during fused ADC - private final ThreadLocal> partialQuantizedSums; // for quantized sums during fused ADC private final AtomicReference> partialSquaredMagnitudes; // for 
cosine partials - private final AtomicReference> partialQuantizedSquaredMagnitudes; // for quantized squared magnitude partials during cosine fused ADC - protected volatile float squaredMagnitudeDelta = 0; // for cosine fused ADC squared magnitude quantization delta (since this is invariant for a given PQ) - protected volatile float minSquaredMagnitude = 0; // for cosine fused ADC minimum squared magnitude (invariant for a given PQ) /** * Initializes the codebooks by clustering the input data using Product Quantization. @@ -211,10 +206,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv, } this.anisotropicThreshold = anisotropicThreshold; this.partialSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(getSubspaceCount() * getClusterCount())); - this.partialBestDistances = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(getSubspaceCount())); - this.partialQuantizedSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(getSubspaceCount() * getClusterCount() * 2)); this.partialSquaredMagnitudes = new AtomicReference<>(null); - this.partialQuantizedSquaredMagnitudes= new AtomicReference<>(null); centroidNormsSquared = new float[M][clusterCount]; @@ -531,22 +523,10 @@ VectorFloat reusablePartialSums() { return partialSums.get(); } - ByteSequence reusablePartialQuantizedSums() { - return partialQuantizedSums.get(); - } - - VectorFloat reusablePartialBestDistances() { - return partialBestDistances.get(); - } - AtomicReference> partialSquaredMagnitudes() { return partialSquaredMagnitudes; } - AtomicReference> partialQuantizedSquaredMagnitudes() { - return partialQuantizedSquaredMagnitudes; - } - public void write(DataOutput out, int version) throws IOException { if (version > OnDiskGraphIndex.CURRENT_VERSION) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java index 26405c549..dcce077c7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java @@ -56,12 +56,6 @@ public void set(int n, byte value) { data[n] = value; } - @Override - public void setLittleEndianShort(int shortIndex, short value) { - data[shortIndex * 2] = (byte) (value & 0xFF); - data[shortIndex * 2 + 1] = (byte) ((value >> 8) & 0xFF); - } - @Override public void zero() { Arrays.fill(data, (byte) 0); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java index 231325440..903d51aca 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java @@ -57,14 +57,6 @@ public void set(int n, byte value) { data.set(offset + n, value); } - @Override - public void setLittleEndianShort(int shortIndex, short value) { - // Can't call setLittleEndianShort because the method shifts the index and we don't require - // that the slice is aligned to a short boundary - data.set(offset + shortIndex * 2, (byte) (value & 0xFF)); - data.set(offset + shortIndex * 2 + 1, (byte) ((value >> 8) & 0xFF)); - } - @Override public void zero() { for (int i = 0; i < length; i++) { diff --git 
a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index 867e1c85d..5843dc5f6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -364,43 +364,6 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int } } - @Override - public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialBest) { - float best = vsf == VectorSimilarityFunction.EUCLIDEAN ? Float.MAX_VALUE : -Float.MAX_VALUE; - float val; - int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - switch (vsf) { - case DOT_PRODUCT: - val = dotProduct(codebook, i * size, query, queryOffset, size); - partialSums.set(codebookBase + i, val); - best = Math.max(best, val); - break; - case EUCLIDEAN: - val = squareDistance(codebook, i * size, query, queryOffset, size); - partialSums.set(codebookBase + i, val); - best = Math.min(best, val); - break; - default: - throw new UnsupportedOperationException("Unsupported similarity function " + vsf); - } - } - partialBest.set(codebookIndex, best); - } - - @Override - public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) { - var codebookSize = partials.length() / partialBases.length(); - for (int i = 0; i < partialBases.length(); i++) { - var localBest = partialBases.get(i); - for (int j = 0; j < codebookSize; j++) { - var val = partials.get(i * codebookSize + j); - var quantized = (short) Math.min((val - localBest) / delta, 65535); - quantizedPartials.setLittleEndianShort(i * codebookSize + j, quantized); - } - } - } - @Override public float max(VectorFloat v) { float max = -Float.MAX_VALUE; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index e7e8b068f..83cb5885b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -174,33 +174,14 @@ public static float assembleAndSumPQ(VectorFloat data, int subspaceCount, Byt return impl.assembleAndSumPQ(data, subspaceCount, dataOffsets1, dataOffsetsOffset1, dataOffsets2, dataOffsetsOffset2, clusterCount); } - public static void bulkShuffleQuantizedSimilarity(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartials, float delta, float minDistance, VectorFloat results, VectorSimilarityFunction vsf) { - impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results); - } - - public static void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int codebookCount, - ByteSequence quantizedPartialSums, float sumDelta, float minDistance, - ByteSequence quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, - float queryMagnitudeSquared, VectorFloat results) { - impl.bulkShuffleQuantizedSimilarityCosine(shuffles, codebookCount, quantizedPartialSums, sumDelta, minDistance, quantizedPartialMagnitudes, magnitudeDelta, minMagnitude, queryMagnitudeSquared, results); - } - public static int 
hammingDistance(long[] v1, long[] v2) { return impl.hammingDistance(v1, v2); } - public static void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialBestDistances) { - impl.calculatePartialSums(codebook, codebookIndex, size, clusterCount, query, offset, vsf, partialSums, partialBestDistances); - } - public static void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums) { impl.calculatePartialSums(codebook, codebookIndex, size, clusterCount, query, offset, vsf, partialSums); } - public static void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBase, ByteSequence quantizedPartials) { - impl.quantizePartials(delta, partials, partialBase, quantizedPartials); - } - /** * Calculates the maximum value in the vector. * @param v vector diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index cc1f74f1b..d8223ab12 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -132,111 +132,8 @@ public interface VectorUtilSupport { int hammingDistance(long[] v1, long[] v2); - // default implementation used here because Panama SIMD can't express necessary SIMD operations and degrades to scalar - /** - * Calculates the similarity score of multiple product quantization-encoded vectors against a single query vector, - * using quantized precomputed similarity score fragments derived from codebook contents and evaluations during a search. - * @param shuffles a sequence of shuffles to be used against partial pre-computed fragments. These are transposed PQ-encoded - * vectors using the same codebooks as the partials. Due to the transposition, rather than this being - * contiguous encoded vectors, the first component of all vectors is stored contiguously, then the second, and so on. - * @param codebookCount The number of codebooks used in the PQ encoding. - * @param quantizedPartials The quantized precomputed score fragments for each codebook entry. These are stored as a contiguous vector of all - * the fragments for one codebook, followed by all the fragments for the next codebook, and so on. These have been - * quantized by quantizePartialSums. - * @param vsf The similarity function to use. - * @param results The output vector to store the similarity scores. This should be pre-allocated to the same size as the number of shuffles. 
- */ - default void bulkShuffleQuantizedSimilarity(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartials, float delta, float minDistance, VectorSimilarityFunction vsf, VectorFloat results) { - for (int i = 0; i < codebookCount; i++) { - for (int j = 0; j < results.length(); j++) { - var shuffle = Byte.toUnsignedInt(shuffles.get(i * results.length() + j)) * 2; - var lowByte = quantizedPartials.get(i * 512 + shuffle); - var highByte = quantizedPartials.get(i * 512 + shuffle + 1); - var val = ((Byte.toUnsignedInt(highByte) << 8) | Byte.toUnsignedInt(lowByte)); - results.set(j, results.get(j) + val); - } - } - - for (int i = 0; i < results.length(); i++) { - switch (vsf) { - case EUCLIDEAN: - results.set(i, 1 / (1 + (delta * results.get(i)) + minDistance)); - break; - case DOT_PRODUCT: - results.set(i, (1 + (delta * results.get(i)) + minDistance) / 2); - break; - default: - throw new UnsupportedOperationException("Unsupported similarity function " + vsf); - } - } - } - - // default implementation used here because Panama SIMD can't express necessary SIMD operations and degrades to scalar - /** - * Calculates the similarity score of multiple product quantization-encoded vectors against a single query vector, - * using quantized precomputed similarity score fragments derived from codebook contents and evaluations during a search. - * @param shuffles a sequence of shuffles to be used against partial pre-computed fragments. These are transposed PQ-encoded - * vectors using the same codebooks as the partials. Due to the transposition, rather than this being - * contiguous encoded vectors, the first component of all vectors is stored contiguously, then the second, and so on. - * @param codebookCount The number of codebooks used in the PQ encoding. - * @param quantizedPartialSums The quantized precomputed dot product fragments between query vector and codebook entries. - * These are stored as a contiguous vector of all the fragments for one codebook, followed by - * all the fragments for the next codebook, and so on. These have been quantized by quantizePartials. - * @param sumDelta The delta used to quantize quantizedPartialSums. - * @param minDistance The minimum distance used to quantize quantizedPartialSums. - * @param quantizedPartialSquaredMagnitudes The quantized precomputed squared magnitudes of each codebook entry. Quantized through the - * same process as quantizedPartialSums. - * @param magnitudeDelta The delta used to quantize quantizedPartialSquaredMagnitudes. - * @param minMagnitude The minimum magnitude used to quantize quantizedPartialSquaredMagnitudes. - * @param queryMagnitudeSquared The squared magnitude of the query vector. - * @param results The output vector to store the similarity distances. This should be pre-allocated to the same size as the number of shuffles. 
- */ - default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int codebookCount, - ByteSequence quantizedPartialSums, float sumDelta, float minDistance, - ByteSequence quantizedPartialSquaredMagnitudes, float magnitudeDelta, float minMagnitude, - float queryMagnitudeSquared, VectorFloat results) { - float[] sums = new float[results.length()]; - float[] magnitudes = new float[results.length()]; - for (int i = 0; i < codebookCount; i++) { - for (int j = 0; j < results.length(); j++) { - var shuffle = Byte.toUnsignedInt(shuffles.get(i * results.length() + j)) * 2; - var lowByte = quantizedPartialSums.get(i * 512 + shuffle); - var highByte = quantizedPartialSums.get(i * 512 + shuffle + 1); - var val = ((Byte.toUnsignedInt(highByte) << 8) | Byte.toUnsignedInt(lowByte)); - sums[j] += val; - lowByte = quantizedPartialSquaredMagnitudes.get(i * 512 + shuffle); - highByte = quantizedPartialSquaredMagnitudes.get(i * 512 + shuffle + 1); - val = ((Byte.toUnsignedInt(highByte) << 8) | Byte.toUnsignedInt(lowByte)); - magnitudes[j] += val; - } - } - - for (int i = 0; i < results.length(); i++) { - float unquantizedSum = sumDelta * sums[i] + minDistance; - float unquantizedMagnitude = magnitudeDelta * magnitudes[i] + minMagnitude; - double divisor = Math.sqrt(unquantizedMagnitude * queryMagnitudeSquared); - results.set(i, (1 + (float) (unquantizedSum / divisor)) / 2); - } - } - void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums); - void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int offset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialMins); - - /** - * Quantizes values in partials (of length N = M * K) into unsigned little-endian 16-bit integers stored in quantizedPartials in the same order. - * partialBases is of length M. For each indexed chunk of K values in partials, each value in the chunk is quantized by subtracting the value - * in partialBases as the same index and dividing by delta. If the value is greater than 65535, 65535 will be used. - * - * The caller is responsible for ensuring than no value in partialSums is larger than its corresponding partialBase. 
- * - * @param delta the divisor to use for quantization - * @param partials the values to quantize - * @param partialBases the base values to subtract from the partials - * @param quantizedPartials the output sequence to store the quantized values - */ - void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials); - float max(VectorFloat v); float min(VectorFloat v); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java index 1ebbe8196..35d5a748e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java @@ -34,12 +34,6 @@ public interface ByteSequence extends Accountable void set(int i, byte value); - /** - * @param shortIndex index (as if this was a short array) inside the sequence to set the short value - * @param value short value to set - */ - void setLittleEndianShort(int shortIndex, short value); - void zero(); void copyFrom(ByteSequence src, int srcOffset, int destOffset, int length); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java index 4623cbe9d..96e41dbce 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java @@ -71,7 +71,7 @@ public static void main(String[] args) throws IOException { ); List> featureSets = Arrays.asList( EnumSet.of(FeatureId.NVQ_VECTORS), -// EnumSet.of(FeatureId.NVQ_VECTORS, FeatureId.FUSED_ADC), + EnumSet.of(FeatureId.NVQ_VECTORS, FeatureId.FUSED_PQ), EnumSet.of(FeatureId.INLINE_VECTORS) ); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index a4d62645f..2f84a1da4 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -21,10 +21,10 @@ import io.github.jbellis.jvector.example.benchmarks.BenchmarkTablePrinter; import io.github.jbellis.jvector.example.benchmarks.CountBenchmark; import io.github.jbellis.jvector.example.benchmarks.LatencyBenchmark; +import io.github.jbellis.jvector.example.benchmarks.Metric; import io.github.jbellis.jvector.example.benchmarks.QueryBenchmark; import io.github.jbellis.jvector.example.benchmarks.QueryTester; import io.github.jbellis.jvector.example.benchmarks.ThroughputBenchmark; -import io.github.jbellis.jvector.example.benchmarks.*; import io.github.jbellis.jvector.example.benchmarks.diagnostics.DiagnosticLevel; import io.github.jbellis.jvector.example.util.CompressorParameters; import io.github.jbellis.jvector.example.util.DataSet; @@ -35,7 +35,7 @@ import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; -import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; import io.github.jbellis.jvector.graph.disk.feature.NVQ; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; @@ -255,15 +255,15 @@ 
private static Map, ImmutableGraphIndex> buildOnDisk(List {
             var writer = entry.getValue();
             var features = entry.getKey();
             Map> writeSuppliers;
-            if (features.contains(FeatureId.FUSED_ADC)) {
+            if (features.contains(FeatureId.FUSED_PQ)) {
                 writeSuppliers = new EnumMap<>(FeatureId.class);
                 var view = builder.getGraph().getView();
-                writeSuppliers.put(FeatureId.FUSED_ADC, ordinal -> new FusedADC.State(view, pq, ordinal));
+                writeSuppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(view, pq, ordinal));
             } else {
                 writeSuppliers = Map.of();
             }
@@ -307,13 +307,13 @@ private static BuilderWithSuppliers builderWithSuppliers(Set features
                 builder.with(new InlineVectors(floatVectors.dimension()));
                 suppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(floatVectors.getVector(ordinal)));
                 break;
-            case FUSED_ADC:
+            case FUSED_PQ:
                 if (pq == null) {
-                    System.out.println("Skipping Fused ADC feature due to null ProductQuantization");
+                    System.out.println("Skipping Fused PQ feature due to null ProductQuantization");
                     continue;
                 }
                 // no supplier as these will be used for writeInline, when we don't have enough information to fuse neighbors
-                builder.with(new FusedADC(onHeapGraph.maxDegree(), pq));
+                builder.with(new FusedPQ(onHeapGraph.maxDegree(), pq));
                 break;
             case NVQ_VECTORS:
                 int nSubVectors = floatVectors.dimension() == 2 ? 1 : 2;
@@ -396,7 +396,7 @@ private static Map, ImmutableGraphIndex> buildInMemory(List queryVector, Immutabl
             var scoringView = (ImmutableGraphIndex.ScoringView) view;
             ScoreFunction.ApproximateScoreFunction asf;
-            if (features.contains(FeatureId.FUSED_ADC)) {
+            if (features.contains(FeatureId.FUSED_PQ)) {
                 asf = scoringView.approximateScoreFunctionFor(queryVector, ds.similarityFunction);
             } else {
                 asf = cv.precomputedScoreFunctionFor(queryVector, ds.similarityFunction);
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java
index 9ec728808..0f70d68e1 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/QueryExecutor.java
@@ -34,10 +34,7 @@
      */
     public static SearchResult executeQuery(ConfiguredSystem cs, int topK, int rerankK, boolean usePruning, int i) {
         var queryVector = cs.getDataSet().queryVectors.get(i);
-        var searcher = cs.getSearcher();
-        searcher.usePruning(usePruning);
-        var sf = cs.scoreProviderFor(queryVector, searcher.getView());
-        return searcher.search(sf, topK, rerankK, 0.0f, 0.0f, Bits.ALL);
+        return executeQuery(cs, topK, rerankK, usePruning, queryVector);
     }

     // Overload to allow single query injection (e.g., for warm-up with random vectors)
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java
index 23fd75e03..5177fdf4a 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/yaml/ConstructionParameters.java
@@ -30,18 +30,43 @@
 public class ConstructionParameters extends CommonParameters {
     public List addHierarchy;
     public List refineFinalGraph;
     public List reranking;
+    public List fusedGraph;
     public Boolean useSavedIndexIfExists;

     public List> getFeatureSets() {
-        return reranking.stream().map(item -> {
-            switch (item) {
-                case "FP":
-                    return
EnumSet.of(FeatureId.INLINE_VECTORS); - case "NVQ": - return EnumSet.of(FeatureId.NVQ_VECTORS); - default: - throw new IllegalArgumentException("Only 'FP' and 'NVQ' are supported"); + List> featureSets = null; + for (var fusedItem : fusedGraph) { + var newFeatures = reranking.stream().map(item -> { + EnumSet features; + + switch (item) { + case "FP": + if (fusedItem) { + features = EnumSet.of(FeatureId.INLINE_VECTORS, FeatureId.FUSED_PQ); + } else { + features = EnumSet.of(FeatureId.INLINE_VECTORS); + } + break; + case "NVQ": + if (fusedItem) { + features = EnumSet.of(FeatureId.NVQ_VECTORS, FeatureId.FUSED_PQ); + } else { + features = EnumSet.of(FeatureId.NVQ_VECTORS); + } + break; + default: + throw new IllegalArgumentException("Only 'FP' and 'NVQ' are supported"); + } + + return features; + }).collect(Collectors.toList()); + if (featureSets == null) { + featureSets = newFeatures; + } else { + featureSets.addAll(newFeatures); } - }).collect(Collectors.toList()); + } + + return featureSets; } } \ No newline at end of file diff --git a/jvector-examples/yaml-configs/ada002-100k.yml b/jvector-examples/yaml-configs/ada002-100k.yml index 141f952b3..628071d94 100644 --- a/jvector-examples/yaml-configs/ada002-100k.yml +++ b/jvector-examples/yaml-configs/ada002-100k.yml @@ -1,4 +1,4 @@ -version: 5 +version: 6 dataset: ada002-100k @@ -8,6 +8,7 @@ construction: neighborOverflow: [1.2f] addHierarchy: [Yes] refineFinalGraph: [Yes] + fusedGraph: [Yes, No] compression: - type: PQ parameters: @@ -31,4 +32,9 @@ search: m: 192 # k: 256 # optional parameter. By default, k=256 centerData: No - anisotropicThreshold: -1.0 # optional parameter. By default, anisotropicThreshold=-1 (i.e., no anisotropy) \ No newline at end of file + anisotropicThreshold: -1.0 # optional parameter. 
By default, anisotropicThreshold=-1 (i.e., no anisotropy) + benchmarks: + throughput: [AVG] + latency: [AVG, STD, P999] + count: [visited, expanded base layer] + accuracy: [recall] \ No newline at end of file diff --git a/jvector-examples/yaml-configs/autoDefault.yml b/jvector-examples/yaml-configs/autoDefault.yml index b770be43c..2381461e8 100644 --- a/jvector-examples/yaml-configs/autoDefault.yml +++ b/jvector-examples/yaml-configs/autoDefault.yml @@ -8,6 +8,7 @@ construction: neighborOverflow: [1.2f] addHierarchy: [Yes] refineFinalGraph: [Yes] + fusedGraph: [No] compression: - type: PQ parameters: diff --git a/jvector-examples/yaml-configs/colbert-1M.yml b/jvector-examples/yaml-configs/colbert-1M.yml index fe700c93f..880d5da41 100644 --- a/jvector-examples/yaml-configs/colbert-1M.yml +++ b/jvector-examples/yaml-configs/colbert-1M.yml @@ -1,4 +1,4 @@ -version: 5 +version: 6 dataset: colbert-1M @@ -8,6 +8,7 @@ construction: neighborOverflow: [1.2f] addHierarchy: [Yes] refineFinalGraph: [Yes] + fusedGraph: [No] compression: - type: PQ parameters: diff --git a/jvector-examples/yaml-configs/default.yml b/jvector-examples/yaml-configs/default.yml index 98b98d970..74a42f973 100644 --- a/jvector-examples/yaml-configs/default.yml +++ b/jvector-examples/yaml-configs/default.yml @@ -1,4 +1,4 @@ -version: 5 +version: 6 dataset: default @@ -8,6 +8,7 @@ construction: neighborOverflow: [1.2f] addHierarchy: [Yes] refineFinalGraph: [Yes] + fusedGraph: [No] compression: - type: PQ parameters: diff --git a/jvector-native/src/main/c/jextract_vector_simd.sh b/jvector-native/src/main/c/jextract_vector_simd.sh index 42c89b053..d44d375dd 100755 --- a/jvector-native/src/main/c/jextract_vector_simd.sh +++ b/jvector-native/src/main/c/jextract_vector_simd.sh @@ -52,6 +52,10 @@ if [ "$(printf '%s\n' "$MIN_GCC_VERSION" "$CURRENT_GCC_VERSION" | sort -V | head gcc -fPIC -O3 -march=icelake-server -c jvector_simd.c -o jvector_simd.o gcc -fPIC -O3 -march=x86-64 -c jvector_simd_check.c -o jvector_simd_check.o gcc -shared -o ../resources/libjvector.so jvector_simd_check.o jvector_simd.o + + rm -rf jvector_common.o + rm -rf jvector_simd.o + rm -rf jvector_simd_check.o else echo "WARNING: GCC version $CURRENT_GCC_VERSION is too old. Please upgrade to GCC $MIN_GCC_VERSION or newer." 
fi @@ -73,4 +77,4 @@ jextract \ jvector_simd.h # Set critical linker option with heap-based segments for all generated methods -sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java \ No newline at end of file diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 812d14e44..d9c909c0f 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -544,4 +544,4 @@ void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int co } } partialBestDistances[codebookIndex] = best; -} +} \ No newline at end of file diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index 83a2219dc..55f1a46c1 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -34,4 +34,4 @@ void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -#endif +#endif \ No newline at end of file diff --git a/jvector-native/src/main/c/jvector_simd_check.c b/jvector-native/src/main/c/jvector_simd_check.c index 660deda03..bf134805e 100644 --- a/jvector-native/src/main/c/jvector_simd_check.c +++ b/jvector-native/src/main/c/jvector_simd_check.c @@ -34,4 +34,4 @@ bool check_compatibility(void) { } return avx512f_supported && avx512cd_supported && avx512bw_supported && avx512dq_supported && avx512vl_supported; -} +} \ No newline at end of file diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java index fe0917e7f..dece3bf63 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java @@ -87,11 +87,6 @@ public void set(int n, byte value) { segment.set(ValueLayout.JAVA_BYTE, n, value); } - @Override - public void setLittleEndianShort(int shortIndex, short value) { - segment.set(LITTLE_ENDIAN_SHORT_LAYOUT_UNALIGNED, shortIndex * 2, value); - } - @Override public void zero() { segment.fill((byte) 0); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 207b5ce98..48cd7d66e 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -18,6 +18,7 @@ import java.nio.ByteOrder; +import io.github.jbellis.jvector.annotations.Experimental; import io.github.jbellis.jvector.vector.cnative.NativeSimdOps; 
import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; @@ -27,37 +28,36 @@ import jdk.incubator.vector.VectorSpecies; /** + * Experimental! * VectorUtilSupport implementation that prefers native/Panama SIMD. */ +@Experimental final class NativeVectorUtilSupport extends PanamaVectorUtilSupport { + public NativeVectorUtilSupport() {} + @Override - protected FloatVector fromVectorFloat(VectorSpecies SPEC, VectorFloat vector, int offset) - { + protected FloatVector fromVectorFloat(VectorSpecies SPEC, VectorFloat vector, int offset) { return FloatVector.fromMemorySegment(SPEC, ((MemorySegmentVectorFloat) vector).get(), vector.offset(offset), ByteOrder.LITTLE_ENDIAN); } @Override - protected FloatVector fromVectorFloat(VectorSpecies SPEC, VectorFloat vector, int offset, int[] indices, int indicesOffset) - { + protected FloatVector fromVectorFloat(VectorSpecies SPEC, VectorFloat vector, int offset, int[] indices, int indicesOffset) { throw new UnsupportedOperationException("Assembly not supported with memory segments."); } @Override - protected void intoVectorFloat(FloatVector vector, VectorFloat v, int offset) - { + protected void intoVectorFloat(FloatVector vector, VectorFloat v, int offset) { vector.intoMemorySegment(((MemorySegmentVectorFloat) v).get(), v.offset(offset), ByteOrder.LITTLE_ENDIAN); } @Override - protected ByteVector fromByteSequence(VectorSpecies SPEC, ByteSequence vector, int offset) - { + protected ByteVector fromByteSequence(VectorSpecies SPEC, ByteSequence vector, int offset) { return ByteVector.fromMemorySegment(SPEC, ((MemorySegmentByteSequence) vector).get(), offset, ByteOrder.LITTLE_ENDIAN); } @Override - protected void intoByteSequence(ByteVector vector, ByteSequence v, int offset) - { + protected void intoByteSequence(ByteVector vector, ByteSequence v, int offset) { vector.intoMemorySegment(((MemorySegmentByteSequence) v).get(), offset, ByteOrder.LITTLE_ENDIAN); } @@ -79,7 +79,6 @@ public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence b return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat) data).get(), dataBase, ((MemorySegmentByteSequence) baseOffsets).get(), baseOffsetsOffset, baseOffsetsLength); } - @Override public float assembleAndSumPQ( VectorFloat codebookPartialSums, @@ -95,52 +94,12 @@ public float assembleAndSumPQ( } @Override - public void calculatePartialSums(VectorFloat codebook, int codebookBase, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums) { - switch (vsf) { - case DOT_PRODUCT -> NativeSimdOps.calculate_partial_sums_dot_f32_512(((MemorySegmentVectorFloat)codebook).get(), codebookBase, size, clusterCount, ((MemorySegmentVectorFloat)query).get(), queryOffset, ((MemorySegmentVectorFloat)partialSums).get()); - case EUCLIDEAN -> NativeSimdOps.calculate_partial_sums_euclidean_f32_512(((MemorySegmentVectorFloat)codebook).get(), codebookBase, size, clusterCount, ((MemorySegmentVectorFloat)query).get(), queryOffset, ((MemorySegmentVectorFloat)partialSums).get()); - case COSINE -> throw new UnsupportedOperationException("Cosine similarity not supported for calculatePartialSums"); - } - } - - @Override - public void calculatePartialSums(VectorFloat codebook, int codebookBase, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialBestDistances) { - switch (vsf) { - case DOT_PRODUCT -> 
NativeSimdOps.calculate_partial_sums_best_dot_f32_512(((MemorySegmentVectorFloat)codebook).get(), codebookBase, size, clusterCount, ((MemorySegmentVectorFloat)query).get(), queryOffset, ((MemorySegmentVectorFloat)partialSums).get(), ((MemorySegmentVectorFloat)partialBestDistances).get()); - case EUCLIDEAN -> NativeSimdOps.calculate_partial_sums_best_euclidean_f32_512(((MemorySegmentVectorFloat)codebook).get(), codebookBase, size, clusterCount, ((MemorySegmentVectorFloat)query).get(), queryOffset, ((MemorySegmentVectorFloat)partialSums).get(), ((MemorySegmentVectorFloat)partialBestDistances).get()); - case COSINE -> throw new UnsupportedOperationException("Cosine similarity not supported for calculatePartialSums"); - } - } - - @Override - public void bulkShuffleQuantizedSimilarity(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartials, float delta, float bestDistance, VectorSimilarityFunction vsf, VectorFloat results) { - assert shuffles.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. Found: " + shuffles.offset(); - switch (vsf) { - case DOT_PRODUCT -> NativeSimdOps.bulk_quantized_shuffle_dot_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartials).get(), delta, bestDistance, ((MemorySegmentVectorFloat) results).get()); - case EUCLIDEAN -> NativeSimdOps.bulk_quantized_shuffle_euclidean_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartials).get(), delta, bestDistance, ((MemorySegmentVectorFloat) results).get()); - case COSINE -> throw new UnsupportedOperationException("Cosine similarity not supported for bulkShuffleQuantizedSimilarity"); - } - } - - @Override - public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int codebookCount, - ByteSequence quantizedPartialSums, float sumDelta, float minDistance, - ByteSequence quantizedPartialSquaredMagnitudes, float magnitudeDelta, float minMagnitude, - float queryMagnitudeSquared, VectorFloat results) { - assert shuffles.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. Found: " + shuffles.offset(); - NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance, - ((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get()); - } - - @Override - public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) - { + public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); } @Override - public float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) - { + public float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { assert encoded.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. 
Found: " + encoded.offset(); // encoded is a pointer into a PQ chunk - we need to index into it by encodedOffset and provide encodedLength to the native code return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encodedOffset, encodedLength, clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java index 3c876e517..7bb4fa514 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorizationProvider.java @@ -51,4 +51,4 @@ public VectorUtilSupport getVectorUtilSupport() { public VectorTypeSupport getVectorTypeSupport() { return vectorTypeSupport; } -} +} \ No newline at end of file diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java index 3eb95f455..80d7b2d94 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/LibraryLoader.java @@ -54,4 +54,4 @@ public static boolean loadJvector() { return false; } -} +} \ No newline at end of file diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 2005d0d5f..014bdf4b0 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -22,15 +22,15 @@ public class NativeSimdOps { static final boolean TRACE_DOWNCALLS = Boolean.getBoolean("jextract.trace.downcalls"); static void traceDowncall(String name, Object... 
args) { - String traceArgs = Arrays.stream(args) - .map(Object::toString) - .collect(Collectors.joining(", ")); - System.out.printf("%s(%s)\n", name, traceArgs); + String traceArgs = Arrays.stream(args) + .map(Object::toString) + .collect(Collectors.joining(", ")); + System.out.printf("%s(%s)\n", name, traceArgs); } static MemorySegment findOrThrow(String symbol) { return SYMBOL_LOOKUP.find(symbol) - .orElseThrow(() -> new UnsatisfiedLinkError("unresolved symbol: " + symbol)); + .orElseThrow(() -> new UnsatisfiedLinkError("unresolved symbol: " + symbol)); } static MethodHandle upcallHandle(Class fi, String name, FunctionDescriptor fdesc) { @@ -98,7 +98,7 @@ public static int __bool_true_false_are_defined() { private static class check_compatibility { public static final FunctionDescriptor DESC = FunctionDescriptor.of( - NativeSimdOps.C_BOOL ); + NativeSimdOps.C_BOOL ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("check_compatibility"); @@ -148,19 +148,19 @@ public static boolean check_compatibility() { } return (boolean)mh$.invokeExact(); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class dot_product_f32 { public static final FunctionDescriptor DESC = FunctionDescriptor.of( - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("dot_product_f32"); @@ -211,19 +211,19 @@ public static float dot_product_f32(int preferred_size, MemorySegment a, int aof } return (float)mh$.invokeExact(preferred_size, a, aoffset, b, boffset, length); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class euclidean_f32 { public static final FunctionDescriptor DESC = FunctionDescriptor.of( - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("euclidean_f32"); @@ -274,18 +274,18 @@ public static float euclidean_f32(int preferred_size, MemorySegment a, int aoffs } return (float)mh$.invokeExact(preferred_size, a, aoffset, b, boffset, length); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class bulk_quantized_shuffle_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("bulk_quantized_shuffle_dot_f32_512"); @@ -336,18 +336,18 @@ public static void 
bulk_quantized_shuffle_dot_f32_512(MemorySegment shuffles, in } mh$.invokeExact(shuffles, codebookCount, quantizedPartials, delta, minDistance, results); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class bulk_quantized_shuffle_euclidean_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("bulk_quantized_shuffle_euclidean_f32_512"); @@ -398,22 +398,22 @@ public static void bulk_quantized_shuffle_euclidean_f32_512(MemorySegment shuffl } mh$.invokeExact(shuffles, codebookCount, quantizedPartials, delta, minDistance, results); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class bulk_quantized_shuffle_cosine_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("bulk_quantized_shuffle_cosine_f32_512"); @@ -464,18 +464,18 @@ public static void bulk_quantized_shuffle_cosine_f32_512(MemorySegment shuffles, } mh$.invokeExact(shuffles, codebookCount, quantizedPartialSums, sumDelta, minDistance, quantizedPartialMagnitudes, magnitudeDelta, minMagnitude, queryMagnitudeSquared, results); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class assemble_and_sum_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.of( - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("assemble_and_sum_f32_512"); @@ -526,20 +526,20 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M } return (float)mh$.invokeExact(data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class pq_decoded_cosine_similarity_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.of( - NativeSimdOps.C_FLOAT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_POINTER, - 
NativeSimdOps.C_FLOAT + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("pq_decoded_cosine_similarity_f32_512"); @@ -590,19 +590,19 @@ public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffse } return (float)mh$.invokeExact(baseOffsets, baseOffsetsOffset, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class calculate_partial_sums_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_dot_f32_512"); @@ -653,19 +653,19 @@ public static void calculate_partial_sums_dot_f32_512(MemorySegment codebook, in } mh$.invokeExact(codebook, codebookBase, size, clusterCount, query, queryOffset, partialSums); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class calculate_partial_sums_euclidean_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_euclidean_f32_512"); @@ -716,20 +716,20 @@ public static void calculate_partial_sums_euclidean_f32_512(MemorySegment codebo } mh$.invokeExact(codebook, codebookBase, size, clusterCount, query, queryOffset, partialSums); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } private static class calculate_partial_sums_best_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_best_dot_f32_512"); @@ -780,20 +780,20 @@ public static void calculate_partial_sums_best_dot_f32_512(MemorySegment codeboo } mh$.invokeExact(codebook, codebookBase, size, clusterCount, query, queryOffset, partialSums, partialBestDistances); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new 
AssertionError("should not reach here", ex$); } } private static class calculate_partial_sums_best_euclidean_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_INT, - NativeSimdOps.C_POINTER, - NativeSimdOps.C_POINTER + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER ); public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_best_euclidean_f32_512"); @@ -844,8 +844,7 @@ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment c } mh$.invokeExact(codebook, codebookBase, size, clusterCount, query, queryOffset, partialSums, partialBestDistances); } catch (Throwable ex$) { - throw new AssertionError("should not reach here", ex$); + throw new AssertionError("should not reach here", ex$); } } -} - +} \ No newline at end of file diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 05a3f7195..f94164630 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -24,10 +24,13 @@ import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; -import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.quantization.NVQuantization; import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorUtil; @@ -53,6 +56,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import java.util.function.IntFunction; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -151,15 +155,38 @@ public static void writeGraph(ImmutableGraphIndex graph, RandomAccessVectorValue OnDiskGraphIndex.write(graph, ravv, outputPath); } - public static void writeFusedGraph(ImmutableGraphIndex graph, RandomAccessVectorValues ravv, PQVectors pqv, Path outputPath) throws IOException { - try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath) - .with(new InlineVectors(ravv.dimension())) - .with(new FusedADC(graph.maxDegree(), pqv.getCompressor())).build()) - { - var suppliers = new EnumMap>(FeatureId.class); - suppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); - suppliers.put(FeatureId.FUSED_ADC, ordinal -> new FusedADC.State(graph.getView(), pqv, ordinal)); - writer.write(suppliers); + + public static void writeFusedGraph(ImmutableGraphIndex graph, RandomAccessVectorValues ravv, PQVectors pqv, FeatureId featureId, Path outputPath) throws IOException { + writeFusedGraph(graph, ravv, pqv, 
featureId, null, outputPath); + } + + public static void writeFusedGraph(ImmutableGraphIndex graph, RandomAccessVectorValues ravv, PQVectors pqv, + FeatureId featureId, Map oldToNewOrdinals, + Path outputPath) throws IOException { + var builder = new OnDiskGraphIndexWriter.Builder(graph, outputPath) + .with(new FusedPQ(graph.maxDegree(), pqv.getCompressor())); + + if (oldToNewOrdinals != null) { + builder = builder.withMap(oldToNewOrdinals); + } + + var suppliers = new EnumMap>(FeatureId.class); + suppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal)); + + if (featureId == FeatureId.INLINE_VECTORS) { + builder.with(new InlineVectors(ravv.dimension())); + suppliers.put(featureId, ordinal -> new InlineVectors.State(ravv.getVector(ordinal))); + } else if (featureId == FeatureId.NVQ_VECTORS) { + int nSubVectors = ravv.dimension() == 2 ? 1 : 2; + var nvq = NVQuantization.compute(ravv, nSubVectors); + builder.with(new NVQ(nvq)); + suppliers.put(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvq.encode(ravv.getVector(ordinal)))); + } else { + throw new IllegalArgumentException("Either INLINE_VECTORS or NVQ_VECTORS are needed for reranking"); + } + + try (var finalWriter = builder.build()) { + finalWriter.write(suppliers); } } @@ -316,6 +343,17 @@ public NodesIterator getNeighborsIterator(int level, int node) { layerSizes.get(level) - 1); } + @Override + public void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, NeighborProcessor neighborProcessor) { + for (var it = getNeighborsIterator(level, node); it.hasNext(); ) { + var friendOrd = it.nextInt(); + if (visited.mark(friendOrd)) { + float friendSimilarity = scoreFunction.similarityTo(friendOrd); + neighborProcessor.process(friendOrd, friendSimilarity); + } + } + } + @Deprecated @Override public int size() { @@ -454,6 +492,17 @@ public NodesIterator getNeighborsIterator(int level, int node) { return new NodesIterator.ArrayNodesIterator(adjacency.get(node)); } + @Override + public void processNeighbors(int level, int node, ScoreFunction scoreFunction, IntMarker visited, NeighborProcessor neighborProcessor) { + for (var it = getNeighborsIterator(level, node); it.hasNext(); ) { + var friendOrd = it.nextInt(); + if (visited.mark(friendOrd)) { + float friendSimilarity = scoreFunction.similarityTo(friendOrd); + neighborProcessor.process(friendOrd, friendSimilarity); + } + } + } + @Deprecated @Override public int size() { diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index c1eaf4ec3..9b3e1c54a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -28,7 +28,7 @@ import io.github.jbellis.jvector.graph.TestVectorGraph; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; -import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.FusedPQ; import io.github.jbellis.jvector.graph.disk.feature.InlineVectors; import io.github.jbellis.jvector.graph.disk.feature.NVQ; import io.github.jbellis.jvector.graph.disk.feature.SeparatedNVQ; @@ -497,7 +497,7 @@ public void testIncrementalWrites() throws IOException { var pqv = (PQVectors) pq.encodeAll(ravv); try (var writer = 
new OnDiskGraphIndexWriter.Builder(graph, incrementalFadcPath) .with(new InlineVectors(ravv.dimension())) - .with(new FusedADC(graph.getDegree(0), pq)) + .with(new FusedPQ(graph.getDegree(0), pq)) .build()) { // write inline vectors incrementally @@ -506,8 +506,8 @@ public void testIncrementalWrites() throws IOException { writer.writeInline(i, state); } // write graph structure, fused ADC - writer.write(Feature.singleStateFactory(FeatureId.FUSED_ADC, i -> new FusedADC.State(graph.getView(), pqv, i))); - writer.write(Map.of()); + writer.write(Feature.singleStateFactory(FeatureId.FUSED_PQ, i -> new FusedPQ.State(graph.getView(), pqv, i))); + writer.write(Map.of(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal))); } // graph and vectors should be identical diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java deleted file mode 100644 index d6aefbe58..000000000 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.github.jbellis.jvector.quantization; - -import com.carrotsearch.randomizedtesting.RandomizedTest; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import io.github.jbellis.jvector.TestUtil; -import io.github.jbellis.jvector.disk.SimpleMappedReader; -import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; -import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; -import io.github.jbellis.jvector.vector.VectorSimilarityFunction; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; - -import static io.github.jbellis.jvector.TestUtil.createRandomVectors; -import static org.junit.Assert.assertEquals; - -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) -public class TestADCGraphIndex extends RandomizedTest { - - private Path testDirectory; - - @Before - public void setup() throws IOException { - testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); - } - - @After - public void tearDown() { - TestUtil.deleteQuietly(testDirectory); - } - - @Test - public void testFusedGraph() throws Exception { - // generate random graph, M=32, 256-dimension vectors - var graph = new TestUtil.RandomlyConnectedGraphIndex(1000, 32, getRandom()); - var outputPath = testDirectory.resolve("large_graph"); - var vectors = createRandomVectors(1000, 512); - var ravv = new ListRandomAccessVectorValues(vectors, 512); - var pq = ProductQuantization.compute(ravv, 8, 256, false); - var pqv = (PQVectors) pq.encodeAll(ravv); - - TestUtil.writeFusedGraph(graph, ravv, pqv, outputPath); - - try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath); - var onDiskGraph = 
OnDiskGraphIndex.load(readerSupplier, 0)) - { - TestUtil.assertGraphEquals(graph, onDiskGraph); - try (var cachedOnDiskView = onDiskGraph.getView()) - { - for (var similarityFunction : VectorSimilarityFunction.values()) { - var queryVector = TestUtil.randomVector(getRandom(), 512); - var pqScoreFunction = pqv.precomputedScoreFunctionFor(queryVector, similarityFunction); - var reranker = cachedOnDiskView.rerankerFor(queryVector, similarityFunction); - for (int i = 0; i < 50; i++) { - var fusedScoreFunction = cachedOnDiskView.approximateScoreFunctionFor(queryVector, similarityFunction); - var ordinal = getRandom().nextInt(graph.size(0)); - // first pass compares fused ADC's direct similarity to reranker's similarity, used for comparisons to a specific node - var neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal); - for (; neighbors.hasNext(); ) { - var neighbor = neighbors.next(); - var similarity = fusedScoreFunction.similarityTo(neighbor); - assertEquals(reranker.similarityTo(neighbor), similarity, 0.01); - } - // second pass compares fused ADC's edge similarity prior to having enough information for quantization to PQ - neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal); - var edgeSimilarities = fusedScoreFunction.edgeLoadingSimilarityTo(ordinal); - for (int j = 0; neighbors.hasNext(); j++) { - var neighbor = neighbors.next(); - assertEquals(pqScoreFunction.similarityTo(neighbor), edgeSimilarities.get(j), 0.01); - } - // third pass compares fused ADC's edge similarity after quantization to edge similarity before quantization - var edgeSimilaritiesCopy = edgeSimilarities.copy(); // results of second pass - var fusedEdgeSimilarities = fusedScoreFunction.edgeLoadingSimilarityTo(ordinal); // results of third pass - for (int j = 0; j < fusedEdgeSimilarities.length(); j++) { - assertEquals(fusedEdgeSimilarities.get(j), edgeSimilaritiesCopy.get(j), 0.01); - } - } - } - } - } - } -} diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestFusedGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestFusedGraphIndex.java new file mode 100644 index 000000000..c3a7bf2b1 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestFusedGraphIndex.java @@ -0,0 +1,381 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package io.github.jbellis.jvector.quantization;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
+import io.github.jbellis.jvector.TestUtil;
+import io.github.jbellis.jvector.disk.SimpleMappedReader;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.GraphSearcher;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
+import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.MockVectorValues;
+import io.github.jbellis.jvector.graph.NodeQueue;
+import io.github.jbellis.jvector.graph.SearchResult;
+import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
+import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
+import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
+import io.github.jbellis.jvector.util.Bits;
+import io.github.jbellis.jvector.util.BoundedLongHeap;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static io.github.jbellis.jvector.TestUtil.createRandomVectors;
+import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
+public class TestFusedGraphIndex extends RandomizedTest {
+
+    private Path testDirectory;
+    private Random random;
+
+    @Before
+    public void setup() throws IOException {
+        testDirectory = Files.createTempDirectory(this.getClass().getSimpleName());
+        random = getRandom();
+    }
+
+    @After
+    public void tearDown() {
+        TestUtil.deleteQuietly(testDirectory);
+    }
+
+    @Test
+    public void testFusedGraph() throws Exception {
+        // generate random graph, M=32, 512-dimension vectors
+        var graph = new TestUtil.RandomlyConnectedGraphIndex(1000, 32, random);
+        var outputPath = testDirectory.resolve("large_graph");
+        var vectors = createRandomVectors(1000, 512);
+        var ravv = new ListRandomAccessVectorValues(vectors, 512);
+        var pq = ProductQuantization.compute(ravv, 8, 256, false);
+        var pqv = (PQVectors) pq.encodeAll(ravv);
+
+        TestUtil.writeFusedGraph(graph, ravv, pqv, FeatureId.INLINE_VECTORS, outputPath);
+
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier, 0))
+        {
+            TestUtil.assertGraphEquals(graph, onDiskGraph);
+            try (var cachedOnDiskView = onDiskGraph.getView())
+            {
+                for (var similarityFunction : VectorSimilarityFunction.values()) {
+                    var queryVector = TestUtil.randomVector(random, 512);
+                    var pqScoreFunction = pqv.precomputedScoreFunctionFor(queryVector, similarityFunction);
+                    var fusedScoreFunction = cachedOnDiskView.approximateScoreFunctionFor(queryVector, similarityFunction);
+
+                    for (int i = 0; i < 50; i++) {
+                        var ordinal = random.nextInt(graph.size(0));
+                        fusedScoreFunction.enableSimilarityToNeighbors(ordinal);
+
+                        var neighbors = cachedOnDiskView.getNeighborsIterator(0, ordinal);
+                        for (int j = 0; neighbors.hasNext(); j++) {
+                            var neighbor = neighbors.next();
+                            assertEquals(
+                                    pqScoreFunction.similarityTo(neighbor),
+                                    fusedScoreFunction.similarityToNeighbor(ordinal, j),
+                                    1e-3
+                            );
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    @Test
+    // build a random graph, then check that it has at least 90% recall
+    public void testRecallOnGraphWithRandomVectors() throws IOException {
+        for (var similarityFunction : VectorSimilarityFunction.values()) {
+            for (var addHierarchy : List.of(false, true)) {
+                for (var featureId: List.of(FeatureId.INLINE_VECTORS, FeatureId.NVQ_VECTORS)) {
+                    testRecallOnGraphWithRandomVectors(addHierarchy, similarityFunction, featureId);
+                }
+            }
+        }
+    }
+
+    // build a random graph, then check that it has at least 90% recall
+    public void testRecallOnGraphWithRandomVectors(boolean addHierarchy, VectorSimilarityFunction similarityFunction, FeatureId featureId) throws IOException {
+        var outputPath = testDirectory.resolve("random_fused_graph" + random.nextInt());
+
+        int size = 1_000;
+        int dim = 32;
+        MockVectorValues vectors = vectorValues(size, dim);
+
+        int topK = 5;
+        int efSearch = 20;
+
+        GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 32, 32, 1.2f, 1.2f, addHierarchy);
+        var tempGraph = builder.build(vectors);
+
+        var pq = ProductQuantization.compute(vectors, 8, 256, false);
+        var pqv = (PQVectors) pq.encodeAll(vectors);
+
+        TestUtil.writeFusedGraph(tempGraph, vectors, pqv, featureId, outputPath);
+
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var graph = OnDiskGraphIndex.load(readerSupplier, 0)) {
+            var searcher = new GraphSearcher(graph);
+
+            Map<Boolean, Integer> totalMatches = new HashMap<>();
+            totalMatches.put(true, 0); // true will be used for fused computations
+            totalMatches.put(false, 0); // false will be used for unfused computations
+
+            for (int i = 0; i < 100; i++) {
+                SearchResult.NodeScore[] actual;
+                VectorFloat<?> query = randomVector(dim);
+
+                NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP);
+                for (int j = 0; j < size; j++) {
+                    expected.push(j, similarityFunction.compare(query, vectors.getVector(j)));
+                }
+
+                for (var fused : List.of(true, false)) {
+                    SearchScoreProvider ssp = scoreProviderFor(fused, query, similarityFunction, searcher.getView(), pqv);
+                    actual = searcher.search(ssp, topK, efSearch, 0.0f, 0.0f, Bits.ALL).getNodes();
+                    var actualNodeIds = Arrays.stream(actual, 0, topK).mapToInt(nodeScore -> nodeScore.node).toArray();
+
+                    assertEquals(topK, actualNodeIds.length);
+                    totalMatches.put(fused, totalMatches.get(fused) + computeOverlap(actualNodeIds, expected.nodesCopy()));
+                }
+            }
+            assertEquals(totalMatches.get(true), totalMatches.get(false));
+            for (var fused : List.of(true, false)) {
+                double overlap = totalMatches.get(fused) / (double) (100 * topK);
+                assertTrue("overlap=" + overlap, overlap > 0.90);
+            }
+        }
+    }
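For reference, the recall bookkeeping in the test above reduces to a single ratio. The following minimal sketch restates it standalone; the method name and array layout are illustrative only, not part of this patch:

```java
// Illustrative only: recall@K as accumulated by the test above.
// exact[q] and approx[q] hold the top-K node ids for query q, sorted ascending.
static double recallAtK(int[][] exact, int[][] approx, int k) {
    int matches = 0;
    for (int q = 0; q < exact.length; q++) {
        // two-pointer overlap count over sorted id arrays, as in computeOverlap below
        for (int i = 0, j = 0; i < exact[q].length && j < approx[q].length; ) {
            if (exact[q][i] == approx[q][j]) { matches++; i++; j++; }
            else if (exact[q][i] < approx[q][j]) i++;
            else j++;
        }
    }
    return matches / (double) (exact.length * k); // e.g. 100 queries x topK
}
```

With 100 queries and topK = 5, the `overlap > 0.90` assertion requires at least 451 of the 500 exact neighbors to be recovered.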
+
+    @Test
+    // build a random graph, then check that fused and unfused scores agree on every edge
+    public void testScoresWithRandomVectors() throws IOException {
+        for (var similarityFunction : VectorSimilarityFunction.values()) {
+            for (var addHierarchy : List.of(false, true)) {
+                for (var featureId: List.of(FeatureId.INLINE_VECTORS, FeatureId.NVQ_VECTORS)) {
+                    testScoresWithRandomVectors(addHierarchy, similarityFunction, featureId);
+                }
+            }
+        }
+    }
+
+    public void testScoresWithRandomVectors(boolean addHierarchy, VectorSimilarityFunction similarityFunction, FeatureId featureId) throws IOException {
+        var outputPath = testDirectory.resolve("random_fused_graph" + random.nextInt());
+
+        int size = 1_000;
+        int dim = 32;
+        MockVectorValues vectors = vectorValues(size, dim);
+
+        GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 32, 32, 1.2f, 1.2f, addHierarchy);
+        var tempGraph = builder.build(vectors);
+
+        var pq = ProductQuantization.compute(vectors, 8, 256, false);
+        var pqv = (PQVectors) pq.encodeAll(vectors);
+
+        TestUtil.writeFusedGraph(tempGraph, vectors, pqv, featureId, outputPath);
+
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var graph = OnDiskGraphIndex.load(readerSupplier, 0)) {
+
+            for (int iQuery = 0; iQuery < 10; iQuery++) {
+                VectorFloat<?> query = randomVector(dim);
+
+                var scoreFunction = scoreProviderFor(false, query, similarityFunction, graph.getView(), pqv).scoreFunction();
+                var fusedScoreFunction = scoreProviderFor(true, query, similarityFunction, graph.getView(), pqv).scoreFunction();
+
+                for (int node = 0; node < size; node++) {
+                    fusedScoreFunction.enableSimilarityToNeighbors(node);
+                    var it = graph.getView().getNeighborsIterator(0, node);
+                    int neighIndex = 0;
+                    while (it.hasNext()) {
+                        int neighbor = it.next();
+                        assertEquals(
+                                scoreFunction.similarityTo(neighbor),
+                                fusedScoreFunction.similarityToNeighbor(node, neighIndex),
+                                1e-6
+                        );
+                        neighIndex++;
+                    }
+                }
+            }
+        }
+    }
+
+    @Test
+    public void testReorderingRenumbering() throws IOException {
+        testReorderingRenumbering(false);
+        testReorderingRenumbering(true);
+    }
+
+    public void testReorderingRenumbering(boolean addHierarchy) throws IOException {
+        var outputPath = testDirectory.resolve("renumbered_graph" + random.nextInt());
+
+        // graph of 1,000 random vectors
+        int size = 1_000;
+        int dim = 32;
+        MockVectorValues ravv = vectorValues(size, dim);
+
+        var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.COSINE, 32, 10, 1.0f, 1.0f, addHierarchy);
+        var original = TestUtil.buildSequentially(builder, ravv);
+
+        // create renumbering map
+        Map<Integer, Integer> oldToNewMap = new HashMap<>();
+        for (int i = 0; i < ravv.size(); i++) {
+            oldToNewMap.put(i, ravv.size() - 1 - i);
+        }
+
+        var pq = ProductQuantization.compute(ravv, 8, 256, false);
+        var pqv = (PQVectors) pq.encodeAll(ravv);
+
+        // write the graph
+        TestUtil.writeFusedGraph(original, ravv, pqv, FeatureId.INLINE_VECTORS, oldToNewMap, outputPath);
+
+        // check that written graph ordinals match the new ones
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var onDiskGraph = OnDiskGraphIndex.load(readerSupplier);
+             var onDiskView = onDiskGraph.getView())
+        {
+            for (int i = 0; i < ravv.size(); i++) {
+                assertEquals(onDiskView.getVector(i), ravv.getVector(ravv.size() - 1 - i));
+            }
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Test
+    // build a random graph, then check that it has at least 90% recall
+    public void testRecallOnGraphWithRenumbering() throws IOException {
+        for (var addHierarchy : List.of(true)) {
+            testRecallOnGraphWithRenumbering(addHierarchy, VectorSimilarityFunction.COSINE, FeatureId.INLINE_VECTORS);
+        }
+    }
+
+    // build a random graph, then check that it has at least 90% recall
+    public void testRecallOnGraphWithRenumbering(boolean addHierarchy, VectorSimilarityFunction similarityFunction, FeatureId featureId) throws IOException {
+        var outputPath = testDirectory.resolve("random_fused_graph");
+
+        int size = 1_000;
+        int dim = 32;
+        MockVectorValues vectors = vectorValues(size, dim);
+
+        int topK = 5;
+        int efSearch = 20;
+
+        GraphIndexBuilder builder = new GraphIndexBuilder(vectors, similarityFunction, 32, 32, 1.2f, 1.2f, addHierarchy);
+        var tempGraph = builder.build(vectors);
+
+        var pq = ProductQuantization.compute(vectors, 8, 256, false);
+        var pqv = (PQVectors) pq.encodeAll(vectors);
+
+        // create renumbering map
+        Map<Integer, Integer> oldToNewMap = new HashMap<>();
+        for (int i = 0; i < vectors.size(); i++) {
+            oldToNewMap.put(i, vectors.size() - 1 - i);
+        }
+
+        TestUtil.writeFusedGraph(tempGraph, vectors, pqv, featureId, oldToNewMap, outputPath);
+
+        try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
+             var graph = OnDiskGraphIndex.load(readerSupplier, 0)) {
+            var searcher = new GraphSearcher(graph);
+
+            int totalMatches = 0;
+
+            for (int i = 0; i < 100; i++) {
+                SearchResult.NodeScore[] actual;
+                VectorFloat<?> query = randomVector(dim);
+
+                NodeQueue expected = new NodeQueue(new BoundedLongHeap(topK), NodeQueue.Order.MIN_HEAP);
+                for (int j = 0; j < size; j++) {
+                    expected.push(j, similarityFunction.compare(query, vectors.getVector(j)));
+                }
+                int[] expectedNodeIds = expected.nodesCopy();
+                for (int j = 0; j < expectedNodeIds.length; j++) {
+                    expectedNodeIds[j] = oldToNewMap.get(expectedNodeIds[j]);
+                }
+
+                SearchScoreProvider ssp = scoreProviderFor(true, query, similarityFunction, searcher.getView(), pqv);
+                actual = searcher.search(ssp, topK, efSearch, 0.0f, 0.0f, Bits.ALL).getNodes();
+                var actualNodeIds = Arrays.stream(actual, 0, topK).mapToInt(nodeScore -> nodeScore.node).toArray();
+
+                assertEquals(topK, actualNodeIds.length);
+                totalMatches += computeOverlap(actualNodeIds, expectedNodeIds);
+            }
+
+            double overlap = totalMatches / (double) (100 * topK);
+            assertTrue("overlap=" + overlap, overlap > 0.90);
+
+        }
+        Files.deleteIfExists(outputPath);
+    }
+
+    public SearchScoreProvider scoreProviderFor(boolean fused, VectorFloat<?> queryVector, VectorSimilarityFunction similarityFunction, ImmutableGraphIndex.View view, CompressedVectors cv) {
+        var scoringView = (ImmutableGraphIndex.ScoringView) view;
+        ScoreFunction.ApproximateScoreFunction asf;
+        if (fused) {
+            asf = scoringView.approximateScoreFunctionFor(queryVector, similarityFunction);
+        } else {
+            asf = cv.precomputedScoreFunctionFor(queryVector, similarityFunction);
+        }
+        var rr = scoringView.rerankerFor(queryVector, similarityFunction);
+        return new DefaultSearchScoreProvider(asf, rr);
+    }
+
+    MockVectorValues vectorValues(int size, int dimension) {
+        return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random));
+    }
+
+    VectorFloat<?> randomVector(int dim) {
+        return TestUtil.randomVector(random, dim);
+    }
+
+    private int computeOverlap(int[] a, int[] b) {
+        Arrays.sort(a);
+        Arrays.sort(b);
+        int overlap = 0;
+        for (int i = 0, j = 0; i < a.length && j < b.length; ) {
+            if (a[i] == b[j]) {
+                ++overlap;
+                ++i;
+                ++j;
+            } else if (a[i] > b[j]) {
+                ++j;
+            } else {
+                ++i;
+            }
+        }
+        return overlap;
+    }
+}
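Taken together, the helpers exercised by this test give the end-to-end shape of the renamed feature. The following condensed sketch is assembled from `writeFusedGraph` and `scoreProviderFor` above; `ravv`, `outputPath`, `query`, and the similarity function, degree, and search parameters are placeholder assumptions, not part of the patch:

```java
// Sketch only: write a graph with Fused PQ, then search it.
// Assumes ravv is a RandomAccessVectorValues over float32 vectors.
void fusedPqRoundTrip(RandomAccessVectorValues ravv, Path outputPath, VectorFloat<?> query) throws IOException {
    var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.DOT_PRODUCT, 32, 32, 1.2f, 1.2f, true);
    var graph = builder.build(ravv);

    // Fused PQ is tied to 256-cluster ProductQuantization, as in the tests above
    var pq = ProductQuantization.compute(ravv, 8, 256, false);
    var pqv = (PQVectors) pq.encodeAll(ravv);

    try (var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath)
            .with(new InlineVectors(ravv.dimension()))                 // full-resolution vectors for reranking
            .with(new FusedPQ(graph.maxDegree(), pqv.getCompressor())) // PQ codes packed into the adjacency lists
            .build()) {
        var suppliers = new EnumMap<FeatureId, IntFunction<Feature.State>>(FeatureId.class);
        suppliers.put(FeatureId.INLINE_VECTORS, ordinal -> new InlineVectors.State(ravv.getVector(ordinal)));
        suppliers.put(FeatureId.FUSED_PQ, ordinal -> new FusedPQ.State(graph.getView(), pqv, ordinal));
        writer.write(suppliers);
    }

    try (var readerSupplier = new SimpleMappedReader.Supplier(outputPath);
         var onDiskGraph = OnDiskGraphIndex.load(readerSupplier, 0)) {
        var searcher = new GraphSearcher(onDiskGraph);
        // the fused score function reads the packed PQ codes straight from the graph,
        // so the codebooks themselves need not be resident in memory during search
        var view = (ImmutableGraphIndex.ScoringView) searcher.getView();
        var asf = view.approximateScoreFunctionFor(query, VectorSimilarityFunction.DOT_PRODUCT);
        var reranker = view.rerankerFor(query, VectorSimilarityFunction.DOT_PRODUCT);
        var results = searcher.search(new DefaultSearchScoreProvider(asf, reranker), 10, 50, 0.0f, 0.0f, Bits.ALL).getNodes();
    }
}
```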
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/vector/TestArraySliceByteSequence.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/vector/TestArraySliceByteSequence.java
index 514b1eef3..281bbc6f3 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/vector/TestArraySliceByteSequence.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/vector/TestArraySliceByteSequence.java
@@ -68,11 +68,6 @@ void testSetOperations() {
         slice.set(0, (byte) 10);
         assertEquals(10, slice.get(0));
         assertEquals(10, baseSequence.get(1));
-
-        // Test setLittleEndianShort
-        slice.setLittleEndianShort(0, (short) 258); // 258 = 0x0102
-        assertEquals(2, slice.get(0)); // least significant byte
-        assertEquals(1, slice.get(1)); // most significant byte
     }
 
     @Test
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
index eacb10866..dc18c4bf2 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
@@ -23,7 +23,6 @@
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.IntVector;
 import jdk.incubator.vector.LongVector;
-import jdk.incubator.vector.ShortVector;
 import jdk.incubator.vector.VectorMask;
 import jdk.incubator.vector.VectorOperators;
 import jdk.incubator.vector.VectorSpecies;
@@ -1035,29 +1034,23 @@ public float min(VectorFloat v) {
         return min;
     }
 
-    @Override
-    public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) {
-        var codebookSize = partials.length() / partialBases.length();
-        var codebookCount = partialBases.length();
-
-        for (int i = 0; i < codebookCount; i++) {
-            var vectorizedLength = FloatVector.SPECIES_512.loopBound(codebookSize);
-            var codebookBase = partialBases.get(i);
-            var codebookBaseVector = FloatVector.broadcast(FloatVector.SPECIES_512, codebookBase);
-            int j = 0;
-            for (; j < vectorizedLength; j += FloatVector.SPECIES_512.length()) {
-                var partialVector = fromVectorFloat(FloatVector.SPECIES_512, partials, i * codebookSize + j);
-                var quantized = (partialVector.sub(codebookBaseVector)).div(delta);
-                quantized = quantized.max(FloatVector.zero(FloatVector.SPECIES_512)).min(FloatVector.broadcast(FloatVector.SPECIES_512, 65535));
-                var quantizedBytes = (ShortVector) quantized.convertShape(VectorOperators.F2S, ShortVector.SPECIES_256, 0);
-                intoByteSequence(quantizedBytes.reinterpretAsBytes(), quantizedPartials, 2 * (i * codebookSize + j));
-            }
-            for (; j < codebookSize; j++) {
-                var val = partials.get(i * codebookSize + j);
-                var quantized = (short) Math.min((val - codebookBase) / delta, 65535);
-                quantizedPartials.setLittleEndianShort(i * codebookSize + j, quantized);
-            }
-        }
+    private static int combineBytes(int i, int shuffle, ByteSequence quantizedPartials) {
+        var lowByte = quantizedPartials.get(i * 512 + shuffle);
+        var highByte = quantizedPartials.get((i * 512) + 256 + shuffle);
+        return ((Byte.toUnsignedInt(highByte) << 8) | Byte.toUnsignedInt(lowByte));
+    }
+
+    private static float combineBytes(int i, int shuffle, VectorFloat partials) {
+        return partials.get(i * 256 + shuffle);
+    }
+
+    private static int computeSingleShuffle(int codebookPosition, int neighborPosition, ByteSequence shuffles, int codebookCount) {
+        int blockSize = ByteVector.SPECIES_PREFERRED.length();
+
+        int blockIndex = neighborPosition / blockSize;
+        int positionWithinBlock = neighborPosition % blockSize;
+        int offset = blockIndex * blockSize * codebookCount;
+        return Byte.toUnsignedInt(shuffles.get(offset + blockSize * codebookPosition + positionWithinBlock));
     }
 
     @Override
@@ -1364,7 +1357,7 @@ public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValu
     @Override
     public float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence quantizedVector,
-                                         float alpha, float x0, float minValue, float maxValue) {
+                                          float alpha, float x0, float minValue, float maxValue) {
         FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
         int vectorizedLength = 
ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); @@ -1406,7 +1399,7 @@ public float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence quan @Override public float nvqDotProduct8bit(VectorFloat vector, ByteSequence quantizedVector, - float alpha, float x0, float minValue, float maxValue) { + float alpha, float x0, float minValue, float maxValue) { FloatVector dotProdVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); @@ -1446,8 +1439,8 @@ public float nvqDotProduct8bit(VectorFloat vector, ByteSequence quantizedV @Override public float[] nvqCosine8bit(VectorFloat vector, - ByteSequence quantizedVector, float alpha, float x0, float minValue, float maxValue, - VectorFloat centroid) { + ByteSequence quantizedVector, float alpha, float x0, float minValue, float maxValue, + VectorFloat centroid) { if (vector.length() != centroid.length()) { throw new IllegalArgumentException("Vectors must have the same length"); } @@ -1546,35 +1539,8 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int } } - @Override - public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums, VectorFloat partialBest) { - float best = vsf == VectorSimilarityFunction.EUCLIDEAN ? Float.MAX_VALUE : -Float.MAX_VALUE; - float val; - int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - switch (vsf) { - case DOT_PRODUCT: - val = dotProduct(codebook, i * size, query, queryOffset, size); - partialSums.set(codebookBase + i, val); - best = Math.max(best, val); - break; - case EUCLIDEAN: - val = squareDistance(codebook, i * size, query, queryOffset, size); - partialSums.set(codebookBase + i, val); - best = Math.min(best, val); - break; - default: - throw new UnsupportedOperationException("Unsupported similarity function " + vsf); - } - } - partialBest.set(codebookIndex, best); - } - - - @Override public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); } -} - +} \ No newline at end of file
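A note on the packed layout decoded by the `computeSingleShuffle` helper added to `PanamaVectorUtilSupport` above: neighbor PQ codes are stored in blocks of `ByteVector.SPECIES_PREFERRED.length()` neighbors, transposed so that each codebook's codes for a block are contiguous. The worked example below assumes a 32-lane preferred species and 16 codebooks; the real lane count varies by platform:

```java
// Worked example of the index arithmetic in computeSingleShuffle (values assumed).
public class ShuffleLayoutExample {
    public static void main(String[] args) {
        int blockSize = 32;          // neighbors per block (assumed lane count)
        int codebookCount = 16;      // PQ subspaces
        int neighborPosition = 40;   // 41st neighbor in the adjacency list
        int codebookPosition = 3;    // 4th codebook

        int blockIndex = neighborPosition / blockSize;           // 1: neighbor lives in the second block
        int positionWithinBlock = neighborPosition % blockSize;  // 8
        int offset = blockIndex * blockSize * codebookCount;     // 512: skip one full transposed block
        int index = offset + blockSize * codebookPosition + positionWithinBlock; // 512 + 96 + 8 = 616

        System.out.println(index); // 616
    }
}
```

Within a block, all 32 codes for codebook 0 come first, then all 32 for codebook 1, and so on; that per-codebook contiguity is the shape the vectorized scoring loop consumes one SIMD register at a time.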