diff --git a/.gitignore b/.gitignore index 9fc38bae4..25335ca03 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target/ +local/ .mvn/wrapper/maven-wrapper.jar .java-version diff --git a/README.md b/README.md index 8fef5996b..a393d4682 100644 --- a/README.md +++ b/README.md @@ -310,6 +310,8 @@ but as soon as you supply a filter, it wants at least one match in every submodu You can run `SiftSmall` and `Bench` directly to get an idea of what all is going on here. `Bench` will automatically download required datasets to the `fvec` and `hdf5` directories. The files used by `SiftSmall` can be found in the [siftsmall directory](./siftsmall) in the project root. +*Looking for memory sizing guidance?* See [devdocs/memory_estimator.md](devdocs/memory_estimator.md) for the full reference. + To run either class, you can use the Maven exec-plugin via the following incantations: > `mvn compile exec:exec@bench` diff --git a/devdocs/memory_estimator.md b/devdocs/memory_estimator.md new file mode 100644 index 000000000..e66ce03f3 --- /dev/null +++ b/devdocs/memory_estimator.md @@ -0,0 +1,60 @@ + + +## Estimating Memory Requirements + +JVector includes a core `MemoryCostEstimator` utility that projects the RAM footprint of a graph build before you invest in the full construction. Unlike the generic `RamUsageEstimator` (which inspects existing Java objects), this tool is domain-aware: it spins up a *representative* mini-index with `GraphIndexBuilder`, records the bytes consumed by graph structures, PQ vectors/codebooks, and thread-local buffers, then extrapolates what a full-scale build would cost. In other words, it answers “how much RAM will this configuration require?” rather than “how much does this in-memory object currently use?”. + +Key differences from `RamUsageEstimator`: + +- **Configuration-driven.** You supply an `IndexConfig`; the estimator knows about graph degree, hierarchy, overflow ratio, PQ knobs, etc. +- **Includes build/serving buffers.** It models the per-thread scratch space that `GraphIndexBuilder` and `GraphSearcher` allocate, which a generic object walk would miss. +- **Ram bytes roll-up.** Instead of calling `ramBytesUsed()` on each structure and summing blindly, the estimator samples per-layer node measurements, separates fixed from per-node overhead, and extrapolates—giving more nuance than the aggregate numbers you get from `RamUsageEstimator`. +- **Predictive rather than introspective.** You can estimate requirements before ever building the full index; `RamUsageEstimator` only reports on objects you already have. + +Use `RamUsageEstimator` when you need an exact footprint of *current* objects (e.g., debugging) and `MemoryCostEstimator` when planning or sizing JVector indexes. + +### Quick Start + +```java +int dimension = 768; + +MemoryCostEstimator.IndexConfig config = MemoryCostEstimator.IndexConfig.defaultConfig(dimension); +MemoryCostEstimator.MemoryModel model = MemoryCostEstimator.createModel(config, 2_000); + +MemoryCostEstimator.Estimate serving = model.estimateBytes(10_000_000); // 10M vectors, steady state +MemoryCostEstimator.Estimate indexing = model.estimateBytesWithIndexingBuffers(10_000_000, 16); // 16 build threads +MemoryCostEstimator.Estimate searching = model.estimateBytesWithSearchBuffers(10_000_000, 64); // 64 query threads + +long centralBytes = serving.value(); +long margin = serving.marginBytes(); +System.out.printf("Servicing: %d bytes ± %d bytes (%.0f%%)\n", centralBytes, margin, serving.marginFraction() * 100); +``` + +Pick a sample size between 1,000 and 10,000 vectors; larger samples tighten the projection at the cost of a longer warm-up build. The returned `MemoryModel` offers `estimateBytes` for steady-state usage plus helpers to account for thread-local scratch space during indexing or query serving. Each method now returns an `Estimate` (central value and ±margin so you can reason about best/worst cases directly). + +### Accuracy + +`MemoryCostEstimatorAccuracyTest` exercises the estimator against real graph builds with and without PQ. The measured footprint must stay within a 5% tolerance (roughly twice the worst error we have observed so far). Running the full six-configuration, seven-dimensionality matrix takes ~23 seconds on an M2 Max laptop; budget additional time on slower hardware. For deep dives you can crank the sample size to the 10,000-vector cap—just note that the 64d cases grow to a few seconds each while the 4096d + PQ runs stretch past a minute apiece. Current coverage includes + +* Hierarchical, M=16, no PQ — 64–4096 dims stay within +0.6–1.7% relative error (well below the 20% guardrail) +* Hierarchical, M=16, PQ (16×256 up to 1024d, 32×256 for ≥2048d) — 64–4096 dims land between +0.04% and +1.45% +* Flat, M=32, no PQ — 64–4096 dims sit in the +0.0–1.0% range +* Flat, M=24, PQ (16×256 up to 1024d, 32×256 for ≥2048d) — 64–4096 dims stay within +0.0–1.2% +* Hierarchical, high degree (M=48, no PQ) — 64–4096 dims remain inside +0.1–1.0% +* Hierarchical, cosine similarity (M=16, no PQ) — 64–4096 dims land between +0.6–2.3% + +If you tune different configurations regularly (e.g., larger degrees or alternative similarity functions), extend the matrix so the backing data stays relevant. diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/MemoryCostEstimator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/MemoryCostEstimator.java new file mode 100644 index 000000000..7d28fac00 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/MemoryCostEstimator.java @@ -0,0 +1,367 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector; + +import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.OnHeapGraphIndex; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.PhysicalCoreExecutor; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Predictive sizing utility for JVector indexes. Unlike generic heap estimators such as {@code + * RamUsageEstimator}, this class is configuration-aware: it constructs a small, representative + * sample index using {@link GraphIndexBuilder} (and {@link ProductQuantization} when configured), + * records the measured footprint of graph structures, PQ vectors/codebooks, and thread-local + * buffers, then extrapolates the expected memory cost for larger datasets. Use this when planning + * or capacity-sizing an index rather than when inspecting already-instantiated objects. + */ +public final class MemoryCostEstimator { + private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private static final int MAX_SAMPLE_SIZE = 10_000; + private static final double DEFAULT_MARGIN_FRACTION = 0.20; // 20% + + private MemoryCostEstimator() { + } + + /** + * Configuration describing the index build and optional PQ settings. + */ + public static class IndexConfig { + public final int dimension; + public final int maxDegree; + public final float overflowRatio; + public final boolean useHierarchy; + public final VectorSimilarityFunction similarityFunction; + public final Integer pqSubspaces; + public final Integer pqClusters; + public final Boolean pqCenter; + + public IndexConfig(int dimension, + int maxDegree, + float overflowRatio, + boolean useHierarchy, + VectorSimilarityFunction similarityFunction, + Integer pqSubspaces, + Integer pqClusters, + Boolean pqCenter) { + this.dimension = dimension; + this.maxDegree = maxDegree; + this.overflowRatio = overflowRatio; + this.useHierarchy = useHierarchy; + this.similarityFunction = similarityFunction; + this.pqSubspaces = pqSubspaces; + this.pqClusters = pqClusters; + this.pqCenter = pqCenter; + } + + public static IndexConfig defaultConfig(int dimension) { + return new IndexConfig( + dimension, + 16, + 1.5f, + true, + VectorSimilarityFunction.EUCLIDEAN, + 16, + 256, + Boolean.TRUE + ); + } + + public static IndexConfig withoutPQ(int dimension, int maxDegree, boolean useHierarchy) { + return new IndexConfig( + dimension, + maxDegree, + 1.5f, + useHierarchy, + VectorSimilarityFunction.EUCLIDEAN, + null, + null, + null + ); + } + + public boolean usesPQ() { + return pqSubspaces != null && pqClusters != null && pqCenter != null; + } + } + + /** + * Captures per-vector and fixed costs derived from a sample build. + */ + public static class MemoryModel { + public final IndexConfig config; + public final int sampleSize; + + public final long bytesPerNodeGraph; + public final long bytesPerNodePQ; + public final long fixedCodebookBytes; + public final long fixedGraphOverhead; + public final double hierarchyFactor; + public final long bytesPerThreadIndexing; + public final long bytesPerThreadSearch; + private final double marginFraction; + + public MemoryModel(IndexConfig config, + int sampleSize, + long bytesPerNodeGraph, + long fixedGraphOverhead, + double hierarchyFactor, + long bytesPerNodePQ, + long fixedCodebookBytes) { + this.config = config; + this.sampleSize = sampleSize; + this.bytesPerNodeGraph = Math.max(0L, bytesPerNodeGraph); + this.fixedGraphOverhead = Math.max(0L, fixedGraphOverhead); + this.hierarchyFactor = hierarchyFactor; + + if (config.usesPQ()) { + this.fixedCodebookBytes = Math.max(0L, fixedCodebookBytes); + this.bytesPerNodePQ = Math.max(0L, bytesPerNodePQ); + } else { + this.fixedCodebookBytes = 0L; + this.bytesPerNodePQ = 0L; + } + + this.bytesPerThreadIndexing = estimateThreadBuffersIndexing(config); + this.bytesPerThreadSearch = estimateThreadBuffersSearch(config); + this.marginFraction = DEFAULT_MARGIN_FRACTION; + } + + private static long estimateThreadBuffersIndexing(IndexConfig config) { + int beamWidth = 100; + int scratchSize = Math.max(beamWidth, config.maxDegree + 1); + + int objectHeader = 16; + int referenceBytes = 8; + int arrayHeader = 16; + + long nodeArrayBytes = objectHeader + Integer.BYTES + (2L * referenceBytes) + (2L * arrayHeader) + + (long) scratchSize * (Integer.BYTES + Float.BYTES); + long twoNodeArrays = 2 * nodeArrayBytes; + + long nodeQueueBytes = objectHeader + referenceBytes + arrayHeader + 100L * Long.BYTES; + long graphSearcherBytes = 3 * nodeQueueBytes + (objectHeader + 8192L); + long ravvWrapperBytes = objectHeader + 2L * referenceBytes; + + return twoNodeArrays + graphSearcherBytes + ravvWrapperBytes; + } + + private static long estimateThreadBuffersSearch(IndexConfig config) { + int objectHeader = 16; + int referenceBytes = 8; + int arrayHeader = 16; + + long nodeQueueBytes = objectHeader + referenceBytes + arrayHeader + 100L * Long.BYTES; + long graphSearcherBytes = 3 * nodeQueueBytes + (objectHeader + 4096L); + + long pqPartials = 0L; + if (config.usesPQ()) { + pqPartials = (long) config.pqSubspaces * 256 * Float.BYTES; + } + + return graphSearcherBytes + pqPartials; + } + + public Estimate estimateBytes(int numVectors) { + long graphBytes = fixedGraphOverhead + (long) (bytesPerNodeGraph * numVectors * hierarchyFactor); + long pqBytes = config.usesPQ() ? fixedCodebookBytes + bytesPerNodePQ * numVectors : 0L; + return estimate(graphBytes + pqBytes); + } + + public Estimate estimateBytesWithIndexingBuffers(int numVectors, int numThreads) { + Estimate base = estimateBytes(numVectors); + long adjusted = base.value + bytesPerThreadIndexing * (long) numThreads; + return estimate(adjusted); + } + + public Estimate estimateBytesWithSearchBuffers(int numVectors, int numThreads) { + Estimate base = estimateBytes(numVectors); + long adjusted = base.value + bytesPerThreadSearch * (long) numThreads; + return estimate(adjusted); + } + + private Estimate estimate(long value) { + return new Estimate(value, marginFraction); + } + } + + /** + * Build a sample index to derive a {@link MemoryModel} for the supplied configuration. The + * sample size should remain modest (default cap: 10 000) because this method actually + * runs {@link GraphIndexBuilder} and (optionally) {@link ProductQuantization} to gather real + * measurements. If you only need object-level metrics for already instantiated structures, + * prefer {@link io.github.jbellis.jvector.util.RamUsageEstimator} instead. + */ + public static MemoryModel createModel(IndexConfig config, int sampleSize) throws Exception { + if (sampleSize <= 0) { + throw new IllegalArgumentException("sampleSize must be positive; received " + sampleSize); + } + if (sampleSize > MAX_SAMPLE_SIZE) { + throw new IllegalArgumentException( + "sampleSize " + sampleSize + " exceeds the maximum " + MAX_SAMPLE_SIZE + + "; reduce it to keep MemoryCostEstimator runtime manageable."); + } + + List> vectors = generateRandomVectors(config.dimension, sampleSize); + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, config.dimension); + + long graphBytes; + long pqBytes = 0L; + long fixedCodebookBytes = 0L; + long bytesPerNodePQ = 0L; + long bytesPerNodeGraph; + long fixedGraphOverhead; + double hierarchyFactor; + + BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, config.similarityFunction); + try (GraphIndexBuilder builder = new GraphIndexBuilder( + bsp, + ravv.dimension(), + config.maxDegree, + 100, + config.overflowRatio, + 1.2f, + config.useHierarchy, + true + )) { + ImmutableGraphIndex graph = builder.build(ravv); + graphBytes = graph.ramBytesUsed(); + + if (!(graph instanceof OnHeapGraphIndex)) { + throw new IllegalStateException("MemoryCostEstimator expects an OnHeapGraphIndex sample"); + } + OnHeapGraphIndex onHeapGraph = (OnHeapGraphIndex) graph; + + long sumNodeBytes = 0L; + long level0NodeBytes = 0L; + for (int level = 0; level <= onHeapGraph.getMaxLevel(); level++) { + long perNode = onHeapGraph.ramBytesUsedOneNode(level); + long nodesAtLevel = onHeapGraph.size(level); + long levelBytes = perNode * nodesAtLevel; + sumNodeBytes += levelBytes; + if (level == 0) { + level0NodeBytes = levelBytes; + } + } + + fixedGraphOverhead = Math.max(0L, graphBytes - sumNodeBytes); + bytesPerNodeGraph = sampleSize == 0 ? 0L : level0NodeBytes / sampleSize; + hierarchyFactor = level0NodeBytes == 0 + ? 1.0 + : Math.max(1.0, (double) sumNodeBytes / (double) level0NodeBytes); + if (!config.useHierarchy) { + hierarchyFactor = 1.0; + } + + if (config.usesPQ()) { + ProductQuantization pq = ProductQuantization.compute( + ravv, + config.pqSubspaces, + config.pqClusters, + config.pqCenter, + KMeansPlusPlusClusterer.UNWEIGHTED, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool() + ); + PQVectors pqVectors = pq.encodeAll(ravv, PhysicalCoreExecutor.pool()); + pqBytes = pqVectors.ramBytesUsed(); + fixedCodebookBytes = pq.ramBytesUsed(); + bytesPerNodePQ = sampleSize == 0 ? 0L : Math.max(0L, (pqBytes - fixedCodebookBytes) / sampleSize); + } + } + + return new MemoryModel( + config, + sampleSize, + bytesPerNodeGraph, + fixedGraphOverhead, + hierarchyFactor, + bytesPerNodePQ, + fixedCodebookBytes + ); + } + + /** + * Represents a point estimate with a relative margin of error. + */ + public static final class Estimate { + private final long value; + private final double marginFraction; + + private Estimate(long value, double marginFraction) { + this.value = value; + this.marginFraction = marginFraction; + } + + /** Central estimate in bytes. */ + public long value() { + return value; + } + + /** Margin of error in bytes (computed as value * marginFraction, rounded). */ + public long marginBytes() { + return Math.round(value * marginFraction); + } + + /** Relative margin as a fraction (e.g., 0.20 == ±20%). */ + public double marginFraction() { + return marginFraction; + } + + /** Lower bound (value - margin). */ + public long lowerBound() { + return Math.max(0L, value - marginBytes()); + } + + /** Upper bound (value + margin). */ + public long upperBound() { + return value + marginBytes(); + } + } + + private static List> generateRandomVectors(int dimension, int count) { + List> vectors = new ArrayList<>(count); + Random rng = ThreadLocalRandom.current(); + + for (int i = 0; i < count; i++) { + VectorFloat vec = VTS.createFloatVector(dimension); + for (int d = 0; d < dimension; d++) { + vec.set(d, rng.nextFloat() * 2f - 1f); + } + vectors.add(vec); + } + + return vectors; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java index 0bdb1763b..1cf462bd8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/RamUsageEstimator.java @@ -44,6 +44,10 @@ *

This class uses assumptions that were discovered for the Hotspot virtual machine. If you use a * non-OpenJDK/Oracle-based JVM, the measurements may be slightly wrong. * + *

For predictive, configuration-driven sizing of JVector indexes (as opposed to inspecting + * already instantiated objects) see {@code devdocs/memory_estimator.md} and + * {@link io.github.jbellis.jvector.MemoryCostEstimator}. + * * @see #shallowSizeOf(Object) * @see #shallowSizeOfInstance(Class) */ diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MemoryCostEstimatorAccuracyTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MemoryCostEstimatorAccuracyTest.java new file mode 100644 index 000000000..190608388 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/MemoryCostEstimatorAccuracyTest.java @@ -0,0 +1,423 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.LuceneTestCase; +import io.github.jbellis.jvector.MemoryCostEstimator; +import io.github.jbellis.jvector.MemoryCostEstimator.IndexConfig; +import io.github.jbellis.jvector.MemoryCostEstimator.MemoryModel; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.OnHeapGraphIndex; +import io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.util.PhysicalCoreExecutor; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ForkJoinPool; + +import static org.junit.Assert.assertTrue; + +/** + * Validates {@link MemoryCostEstimator} projections across multiple deployment styles and vector + * dimensionalities (64 through 4096). Sample sizes scale with dimensionality to keep runtime + * reasonable (see {@link #sampleSizeFor(int)}). Recent runs produced the following relative errors + * (|estimate − measured|/measured): + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Relative error (%) by configuration and dimensionality
Configuration64d128d256d512d1024d2048d4096d
Hierarchical, no PQ (M=16)1.591.591.590.600.800.801.65
Hierarchical, PQ (M=16, 16×256 ≤1024d, 32×256 ≥2048d)1.451.401.290.400.040.160.09
Flat, no PQ (M=32)0.950.950.950.910.840.690.00
Flat, PQ (M=24, 16×256 ≤1024d, 32×256 ≥2048d)1.211.171.100.880.550.230.00
Hierarchical, high-M (M=48)0.100.210.510.190.420.600.96
Hierarchical, cosine2.261.591.590.600.800.801.65
+ * All scenarios remain well within the ±5 % tolerance enforced by the assertions. The full + * sweep (six configurations × seven dimensionalities) completes in ~23 s on an M2 Max laptop; + * expect higher runtimes if CPU parallelism or vector acceleration is limited. If you instead + * force every run to use the 10 000 vector sample cap, the 64d scenarios finish in a few seconds + * while the 4096d cases stretch past a minute apiece because Product Quantization dominates wall + * time at the higher dimensionalities. + */ +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class MemoryCostEstimatorAccuracyTest extends LuceneTestCase { + private static final VectorTypeSupport VTS = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private static final int[] DIMENSIONS = {64, 128, 256, 512, 1024, 2048, 4096}; + private static final int BASE_SAMPLE_SIZE = 1_000; + private static final int BASE_VERIFICATION_SIZE = 5_000; + private static final double TOLERANCE = 0.05; // 5% relative error (roughly 2x worst observed) + + @Test + public void testEstimateAccuracyHierarchicalWithPQ() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = new IndexConfig( + dimension, + 16, + 1.5f, + true, + VectorSimilarityFunction.EUCLIDEAN, + pqSubspacesFor(dimension), + 256, + Boolean.TRUE + ); + + verifyEstimateAccuracy(config, dimension, "hierarchical+PQ"); + } + } + + @Test + public void testEstimateAccuracyHierarchicalWithoutPQ() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = IndexConfig.withoutPQ(dimension, 16, true); + verifyEstimateAccuracy(config, dimension, "hierarchical"); + } + } + + @Test + public void testEstimateAccuracyFlatNoPQ() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = IndexConfig.withoutPQ(dimension, 32, false); + verifyEstimateAccuracy(config, dimension, "flat"); + } + } + + @Test + public void testEstimateAccuracyFlatPQ() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = new IndexConfig( + dimension, + 24, + 1.4f, + false, + VectorSimilarityFunction.EUCLIDEAN, + pqSubspacesFor(dimension) / 2, + 256, + Boolean.FALSE + ); + + verifyEstimateAccuracy(config, dimension, "flat+PQ"); + } + } + + @Test + public void testEstimateAccuracyHierarchicalHighDegree() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = new IndexConfig( + dimension, + 48, + 1.5f, + true, + VectorSimilarityFunction.EUCLIDEAN, + null, + null, + null + ); + + verifyEstimateAccuracy(config, dimension, "hierarchical-highM"); + } + } + + @Test + public void testEstimateAccuracyHierarchicalCosine() throws Exception { + for (int dimension : DIMENSIONS) { + IndexConfig config = new IndexConfig( + dimension, + 16, + 1.5f, + true, + VectorSimilarityFunction.COSINE, + null, + null, + null + ); + + verifyEstimateAccuracy(config, dimension, "hierarchical-cosine"); + } + } + + private void verifyEstimateAccuracy(IndexConfig config, int dimension, String label) throws Exception { + int sampleSize = Math.max(sampleSizeFor(dimension), minimumSampleSize(config)); + int verificationSize = verificationSizeFor(dimension); + + MemoryModel model = buildModel(config, sampleSize, 7L); + Measurement measurement = measure(config, verificationSize, 13L); + + long estimatedBytes = model.estimateBytes(verificationSize).value(); + long measuredBytes = measurement.totalBytes(); + double relativeError = Math.abs(estimatedBytes - measuredBytes) / (double) measuredBytes; + + System.out.printf( + "Memory estimator: dim=%d, %s -> estimated=%d, measured=%d, error=%.2f%%%n", + dimension, + label, + estimatedBytes, + measuredBytes, + relativeError * 100.0 + ); + + assertTrue( + String.format( + "Estimated bytes=%d, measured bytes=%d, relative error=%.2f%%", + estimatedBytes, + measuredBytes, + relativeError * 100.0 + ), + relativeError <= TOLERANCE + ); + } + + private MemoryModel buildModel(IndexConfig config, int sampleSize, long seed) throws Exception { + Measurement measurement = measure(config, sampleSize, seed); + return new MemoryModel( + config, + sampleSize, + measurement.bytesPerNodeGraph(), + measurement.fixedGraphOverhead(), + measurement.hierarchyFactor(), + measurement.bytesPerNodePQ(), + measurement.fixedCodebookBytes() + ); + } + + private int sampleSizeFor(int dimension) { + // Scaling the sample size keeps 64d runs sub-second while preventing 4096d + PQ + // configurations from ballooning past a minute; the 10k cap is still available for + // deeper diagnostics, but not used in this quick regression sweep. + if (dimension >= 4096) { + return 200; + } + if (dimension >= 2048) { + return 300; + } + if (dimension >= 1024) { + return 400; + } + if (dimension >= 512) { + return 600; + } + return BASE_SAMPLE_SIZE; + } + + private int verificationSizeFor(int dimension) { + if (dimension >= 4096) { + return 1_000; + } + if (dimension >= 2048) { + return 2_000; + } + if (dimension >= 1024) { + return 3_000; + } + if (dimension >= 512) { + return 4_000; + } + return BASE_VERIFICATION_SIZE; + } + + private int minimumSampleSize(IndexConfig config) { + if (config.usesPQ()) { + return config.pqClusters; + } + return 1; + } + + private int pqSubspacesFor(int dimension) { + if (dimension >= 4096) { + return 64; + } + if (dimension >= 2048) { + return 48; + } + if (dimension >= 1024) { + return 32; + } + return 16; + } + + private Measurement measure(IndexConfig config, int size, long seed) throws Exception { + List> vectors = generateVectors(config.dimension, size, seed); + var ravv = new ListRandomAccessVectorValues(vectors, config.dimension); + + long graphBytes; + long pqBytes = 0L; + long bytesPerNodeGraph; + long fixedGraphOverhead; + double hierarchyFactor; + long bytesPerNodePQ = 0L; + long fixedCodebookBytes = 0L; + + BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, config.similarityFunction); + try (GraphIndexBuilder builder = new GraphIndexBuilder( + bsp, + ravv.dimension(), + config.maxDegree, + 100, + config.overflowRatio, + 1.2f, + config.useHierarchy, + true + )) { + ImmutableGraphIndex graph = builder.build(ravv); + graphBytes = graph.ramBytesUsed(); + + OnHeapGraphIndex onHeapGraph = (OnHeapGraphIndex) graph; + long sumNodeBytes = 0L; + long level0NodeBytes = 0L; + for (int level = 0; level <= onHeapGraph.getMaxLevel(); level++) { + long perNode = onHeapGraph.ramBytesUsedOneNode(level); + long nodesAtLevel = onHeapGraph.size(level); + long levelBytes = perNode * nodesAtLevel; + sumNodeBytes += levelBytes; + if (level == 0) { + level0NodeBytes = levelBytes; + } + } + + fixedGraphOverhead = Math.max(0L, graphBytes - sumNodeBytes); + bytesPerNodeGraph = size == 0 ? 0L : level0NodeBytes / size; + hierarchyFactor = level0NodeBytes == 0 + ? 1.0 + : Math.max(1.0, (double) sumNodeBytes / (double) level0NodeBytes); + if (!config.useHierarchy) { + hierarchyFactor = 1.0; + } + + if (config.usesPQ()) { + ProductQuantization pq = ProductQuantization.compute( + ravv, + config.pqSubspaces, + config.pqClusters, + config.pqCenter, + KMeansPlusPlusClusterer.UNWEIGHTED, + PhysicalCoreExecutor.pool(), + ForkJoinPool.commonPool() + ); + + PQVectors pqVectors = pq.encodeAll(ravv, PhysicalCoreExecutor.pool()); + pqBytes = pqVectors.ramBytesUsed(); + fixedCodebookBytes = pq.ramBytesUsed(); + bytesPerNodePQ = size == 0 ? 0L : Math.max(0L, (pqBytes - fixedCodebookBytes) / size); + } + } + + return new Measurement( + graphBytes, + pqBytes, + bytesPerNodeGraph, + fixedGraphOverhead, + hierarchyFactor, + bytesPerNodePQ, + fixedCodebookBytes + ); + } + + private List> generateVectors(int dimension, int count, long seed) { + Random rng = new Random(seed); + List> vectors = new ArrayList<>(count); + + for (int i = 0; i < count; i++) { + VectorFloat vec = VTS.createFloatVector(dimension); + for (int d = 0; d < dimension; d++) { + vec.set(d, rng.nextFloat() * 2f - 1f); + } + vectors.add(vec); + } + + return vectors; + } + + private static class Measurement { + private final long graphBytes; + private final long pqBytes; + private final long bytesPerNodeGraph; + private final long fixedGraphOverhead; + private final double hierarchyFactor; + private final long bytesPerNodePQ; + private final long fixedCodebookBytes; + + private Measurement(long graphBytes, + long pqBytes, + long bytesPerNodeGraph, + long fixedGraphOverhead, + double hierarchyFactor, + long bytesPerNodePQ, + long fixedCodebookBytes) { + this.graphBytes = graphBytes; + this.pqBytes = pqBytes; + this.bytesPerNodeGraph = bytesPerNodeGraph; + this.fixedGraphOverhead = fixedGraphOverhead; + this.hierarchyFactor = hierarchyFactor; + this.bytesPerNodePQ = bytesPerNodePQ; + this.fixedCodebookBytes = fixedCodebookBytes; + } + + long bytesPerNodeGraph() { + return bytesPerNodeGraph; + } + + long fixedGraphOverhead() { + return fixedGraphOverhead; + } + + double hierarchyFactor() { + return hierarchyFactor; + } + + long bytesPerNodePQ() { + return bytesPerNodePQ; + } + + long fixedCodebookBytes() { + return fixedCodebookBytes; + } + + long totalBytes() { + return graphBytes + pqBytes; + } + } +}