From 1788350e6daab0b7d2d03adf60d5e72a7d8308d6 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Thu, 12 Oct 2023 09:03:30 -0500 Subject: [PATCH] add save and load methods for OHGI --- .../jvector/disk/OnDiskGraphIndex.java | 1 + .../jvector/graph/ConcurrentNeighborSet.java | 22 ++++++---- .../jvector/graph/GraphIndexBuilder.java | 30 +++++++++++++- .../jvector/graph/OnHeapGraphIndex.java | 40 ++++++++++++++++++- .../io/github/jbellis/jvector/TestUtil.java | 6 +-- .../jvector/disk/TestOnDiskGraphIndex.java | 1 - .../jbellis/jvector/graph/TestDeletions.java | 24 ++++++++++- 7 files changed, 109 insertions(+), 15 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java index c34a0ff48..c0d189916 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java @@ -172,6 +172,7 @@ public static void write(GraphIndex graph, throw new IllegalArgumentException("Run builder.cleanup() before writing the graph"); } } + var view = graph.getView(); // graph-level properties diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java index 809dcdd65..f40a8b5a9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java @@ -50,17 +50,25 @@ public class ConcurrentNeighborSet { /** the proportion of edges that are diverse at alpha=1.0. updated by removeAllNonDiverse */ private float shortEdges = Float.NaN; - public ConcurrentNeighborSet( - int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) { + public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) { + this(nodeId, maxConnections, similarity, 1.0f); + } + + public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) { + this(nodeId, maxConnections, similarity, alpha, new ConcurrentNeighborArray(maxConnections)); + } + + ConcurrentNeighborSet(int nodeId, + int maxConnections, + NeighborSimilarity similarity, + float alpha, + ConcurrentNeighborArray neighbors) + { this.nodeId = nodeId; this.maxConnections = maxConnections; this.similarity = similarity; - neighborsRef = new AtomicReference<>(new ConcurrentNeighborArray(maxConnections)); this.alpha = alpha; - } - - public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) { - this(nodeId, maxConnections, similarity, 1.0f); + this.neighborsRef = new AtomicReference<>(neighbors); } private ConcurrentNeighborSet(ConcurrentNeighborSet old) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9ffbcbe09..04ba033e4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.annotations.VisibleForTesting; +import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.PoolingSupport; @@ -24,6 +25,7 @@ import io.github.jbellis.jvector.vector.VectorEncoding; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import java.io.IOException; import java.util.AbstractMap; import java.util.Arrays; import java.util.Comparator; @@ -49,6 +51,7 @@ public class GraphIndexBuilder { private final VectorSimilarityFunction similarityFunction; private final float neighborOverflow; + private final float alpha; private final VectorEncoding vectorEncoding; private final PoolingSupport> graphSearcher; @@ -62,6 +65,7 @@ public class GraphIndexBuilder { // and `vectorsCopy` later on when defining the ScoreFunction for search. private final PoolingSupport> vectors; private final PoolingSupport> vectorsCopy; + private final NeighborSimilarity similarity; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -90,6 +94,7 @@ public GraphIndexBuilder( this.vectorEncoding = Objects.requireNonNull(vectorEncoding); this.similarityFunction = Objects.requireNonNull(similarityFunction); this.neighborOverflow = neighborOverflow; + this.alpha = alpha; if (M <= 0) { throw new IllegalArgumentException("maxConn must be positive"); } @@ -98,7 +103,7 @@ public GraphIndexBuilder( } this.beamWidth = beamWidth; - NeighborSimilarity similarity = node1 -> { + similarity = node1 -> { try (var v = vectors.get(); var vc = vectorsCopy.get()) { T v1 = v.get().vectorValue(node1); return (NeighborSimilarity.ExactScoreFunction) node2 -> scoreBetween(v1, vc.get().vectorValue(node2)); @@ -453,4 +458,27 @@ public int length() { throw new UnsupportedOperationException(); } } + + public void load(RandomAccessReader in) throws IOException { + if (graph.size() != 0) { + throw new IllegalStateException("Cannot load into a non-empty graph"); + } + + int size = in.readInt(); + int entryNode = in.readInt(); + int maxDegree = in.readInt(); + + for (int i = 0; i < size; i++) { + int node = in.readInt(); + int nNeighbors = in.readInt(); + var ca = new ConcurrentNeighborSet.ConcurrentNeighborArray(maxDegree); + for (int j = 0; j < nNeighbors; j++) { + int neighbor = in.readInt(); + ca.addInOrder(neighbor, similarity.score(node, neighbor)); + } + graph.addNode(node, new ConcurrentNeighborSet(node, maxDegree, similarity, alpha, ca)); + } + + graph.updateEntryNode(entryNode); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index c2560552b..4b921815a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -24,6 +24,8 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.disk.Io; +import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; @@ -31,6 +33,8 @@ import io.github.jbellis.jvector.util.RamUsageEstimator; import org.jctools.maps.NonBlockingHashMapLong; +import java.io.DataOutput; +import java.io.IOException; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -56,7 +60,6 @@ public class OnHeapGraphIndex implements GraphIndex, Accountable { // max neighbors/edges per node final int maxDegree; private final BiFunction neighborFactory; - private boolean hasPurgedNodes; OnHeapGraphIndex( int M, BiFunction neighborFactory) { @@ -98,6 +101,14 @@ public void addNode(int node) { maxNodeId.accumulateAndGet(node, Math::max); } + /** + * Only for use by Builder loading a saved graph + */ + void addNode(int node, ConcurrentNeighborSet neighbors) { + nodes.put(node, neighbors); + maxNodeId.accumulateAndGet(node, Math::max); + } + /** * Mark the given node deleted. Does NOT remove the node from the graph. */ @@ -249,7 +260,6 @@ public BitSet getDeletedNodes() { void removeNode(int node) { nodes.remove(node); - hasPurgedNodes = true; } @Override @@ -315,4 +325,30 @@ public void close() { // no-op } } + + public void save(DataOutput out) throws IOException { + if (deletedNodes.cardinality() > 0) { + throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first"); + } + + // graph-level properties + var view = getView(); + out.writeInt(size()); + out.writeInt(view.entryNode()); + out.writeInt(maxDegree()); + + // neighbors + for (var entry : nodes.entrySet()) { + var i = (int) (long) entry.getKey(); + var neighbors = entry.getValue().iterator(); + out.writeInt(i); + + out.writeInt(neighbors.size()); + int n = 0; + for ( ; n < neighbors.size(); n++) { + out.writeInt(neighbors.nextInt()); + } + assert !neighbors.hasNext(); + } + } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 95bedae84..e185c7268 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -116,10 +116,10 @@ public static byte[] randomVector8(Random random, int dim) { } public static void writeGraph(GraphIndex graph, RandomAccessVectorValues vectors, Path outputPath) throws IOException { - try (var indexOutputWriter = openFileForWriting(outputPath)) + try (var out = openFileForWriting(outputPath)) { - OnDiskGraphIndex.write(graph, vectors, Function.identity(), indexOutputWriter); - indexOutputWriter.flush(); + OnDiskGraphIndex.write(graph, vectors, Function.identity(), out); + out.flush(); } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestOnDiskGraphIndex.java index 77941732c..63fae269a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestOnDiskGraphIndex.java @@ -35,7 +35,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import static io.github.jbellis.jvector.TestUtil.getNeighborNodes; import static org.junit.Assert.assertArrayEquals; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java index 296cd7989..d0dac65db 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java @@ -3,13 +3,21 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.LuceneTestCase; import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.disk.SimpleMappedReader; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorEncoding; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Arrays; +import java.util.function.Function; +import static io.github.jbellis.jvector.TestUtil.assertGraphEquals; +import static io.github.jbellis.jvector.TestUtil.openFileForWriting; import static io.github.jbellis.jvector.graph.GraphIndexTestCase.createRandomFloatVectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; @@ -45,7 +53,7 @@ public void testMarkDeleted() { } @Test - public void testCleanup() { + public void testCleanup() throws IOException { // graph of 10 vectors int dimension = 2; var ravv = MockVectorValues.fromValues(createRandomFloatVectors(10, dimension, getRandom())); @@ -75,5 +83,19 @@ public void testCleanup() { var v = Arrays.copyOf(ravv.vectorValue(nodeToIsolate), ravv.dimension); var results = GraphSearcher.search(v, 10, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(nodeToIsolate, results.getNodes()[0].node); + + // check that we can save and load the graph with "holes" from the deletion + var testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + var outputPath = testDirectory.resolve("on_heap_graph"); + try (var out = openFileForWriting(outputPath)) { + graph.save(out); + out.flush(); + } + var b2 = new GraphIndexBuilder<>(ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f); + try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString())) { + b2.load(marr); + } + var reloadedGraph = b2.getGraph(); + assertGraphEquals(graph, reloadedGraph); } }