Skip to content

Commit

Permalink
add save and load methods for OHGI
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Oct 12, 2023
1 parent 31a54e9 commit 1788350
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ public static <T> void write(GraphIndex<T> graph,
throw new IllegalArgumentException("Run builder.cleanup() before writing the graph");
}
}

var view = graph.getView();

// graph-level properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,25 @@ public class ConcurrentNeighborSet {
/** the proportion of edges that are diverse at alpha=1.0. updated by removeAllNonDiverse */
private float shortEdges = Float.NaN;

public ConcurrentNeighborSet(
int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) {
public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) {
this(nodeId, maxConnections, similarity, 1.0f);
}

public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity, float alpha) {
this(nodeId, maxConnections, similarity, alpha, new ConcurrentNeighborArray(maxConnections));
}

ConcurrentNeighborSet(int nodeId,
int maxConnections,
NeighborSimilarity similarity,
float alpha,
ConcurrentNeighborArray neighbors)
{
this.nodeId = nodeId;
this.maxConnections = maxConnections;
this.similarity = similarity;
neighborsRef = new AtomicReference<>(new ConcurrentNeighborArray(maxConnections));
this.alpha = alpha;
}

public ConcurrentNeighborSet(int nodeId, int maxConnections, NeighborSimilarity similarity) {
this(nodeId, maxConnections, similarity, 1.0f);
this.neighborsRef = new AtomicReference<>(neighbors);
}

private ConcurrentNeighborSet(ConcurrentNeighborSet old) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.PoolingSupport;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;

import java.io.IOException;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Comparator;
Expand All @@ -49,6 +51,7 @@ public class GraphIndexBuilder<T> {

private final VectorSimilarityFunction similarityFunction;
private final float neighborOverflow;
private final float alpha;
private final VectorEncoding vectorEncoding;
private final PoolingSupport<GraphSearcher<?>> graphSearcher;

Expand All @@ -62,6 +65,7 @@ public class GraphIndexBuilder<T> {
// and `vectorsCopy` later on when defining the ScoreFunction for search.
private final PoolingSupport<RandomAccessVectorValues<T>> vectors;
private final PoolingSupport<RandomAccessVectorValues<T>> vectorsCopy;
private final NeighborSimilarity similarity;

/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
Expand Down Expand Up @@ -90,6 +94,7 @@ public GraphIndexBuilder(
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
this.similarityFunction = Objects.requireNonNull(similarityFunction);
this.neighborOverflow = neighborOverflow;
this.alpha = alpha;
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}
Expand All @@ -98,7 +103,7 @@ public GraphIndexBuilder(
}
this.beamWidth = beamWidth;

NeighborSimilarity similarity = node1 -> {
similarity = node1 -> {
try (var v = vectors.get(); var vc = vectorsCopy.get()) {
T v1 = v.get().vectorValue(node1);
return (NeighborSimilarity.ExactScoreFunction) node2 -> scoreBetween(v1, vc.get().vectorValue(node2));
Expand Down Expand Up @@ -453,4 +458,27 @@ public int length() {
throw new UnsupportedOperationException();
}
}

public void load(RandomAccessReader in) throws IOException {
if (graph.size() != 0) {
throw new IllegalStateException("Cannot load into a non-empty graph");
}

int size = in.readInt();
int entryNode = in.readInt();
int maxDegree = in.readInt();

for (int i = 0; i < size; i++) {
int node = in.readInt();
int nNeighbors = in.readInt();
var ca = new ConcurrentNeighborSet.ConcurrentNeighborArray(maxDegree);
for (int j = 0; j < nNeighbors; j++) {
int neighbor = in.readInt();
ca.addInOrder(neighbor, similarity.score(node, neighbor));
}
graph.addNode(node, new ConcurrentNeighborSet(node, maxDegree, similarity, alpha, ca));
}

graph.updateEntryNode(entryNode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@

package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.GrowableBitSet;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.jctools.maps.NonBlockingHashMapLong;

import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -56,7 +60,6 @@ public class OnHeapGraphIndex<T> implements GraphIndex<T>, Accountable {
// max neighbors/edges per node
final int maxDegree;
private final BiFunction<Integer, Integer, ConcurrentNeighborSet> neighborFactory;
private boolean hasPurgedNodes;

OnHeapGraphIndex(
int M, BiFunction<Integer, Integer, ConcurrentNeighborSet> neighborFactory) {
Expand Down Expand Up @@ -98,6 +101,14 @@ public void addNode(int node) {
maxNodeId.accumulateAndGet(node, Math::max);
}

/**
* Only for use by Builder loading a saved graph
*/
void addNode(int node, ConcurrentNeighborSet neighbors) {
nodes.put(node, neighbors);
maxNodeId.accumulateAndGet(node, Math::max);
}

/**
* Mark the given node deleted. Does NOT remove the node from the graph.
*/
Expand Down Expand Up @@ -249,7 +260,6 @@ public BitSet getDeletedNodes() {

void removeNode(int node) {
nodes.remove(node);
hasPurgedNodes = true;
}

@Override
Expand Down Expand Up @@ -315,4 +325,30 @@ public void close() {
// no-op
}
}

public void save(DataOutput out) throws IOException {
if (deletedNodes.cardinality() > 0) {
throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first");
}

// graph-level properties
var view = getView();
out.writeInt(size());
out.writeInt(view.entryNode());
out.writeInt(maxDegree());

// neighbors
for (var entry : nodes.entrySet()) {
var i = (int) (long) entry.getKey();
var neighbors = entry.getValue().iterator();
out.writeInt(i);

out.writeInt(neighbors.size());
int n = 0;
for ( ; n < neighbors.size(); n++) {
out.writeInt(neighbors.nextInt());
}
assert !neighbors.hasNext();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ public static byte[] randomVector8(Random random, int dim) {
}

public static <T> void writeGraph(GraphIndex<T> graph, RandomAccessVectorValues<T> vectors, Path outputPath) throws IOException {
try (var indexOutputWriter = openFileForWriting(outputPath))
try (var out = openFileForWriting(outputPath))
{
OnDiskGraphIndex.write(graph, vectors, Function.identity(), indexOutputWriter);
indexOutputWriter.flush();
OnDiskGraphIndex.write(graph, vectors, Function.identity(), out);
out.flush();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static io.github.jbellis.jvector.TestUtil.getNeighborNodes;
import static org.junit.Assert.assertArrayEquals;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.github.jbellis.jvector.LuceneTestCase;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.junit.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.function.Function;

import static io.github.jbellis.jvector.TestUtil.assertGraphEquals;
import static io.github.jbellis.jvector.TestUtil.openFileForWriting;
import static io.github.jbellis.jvector.graph.GraphIndexTestCase.createRandomFloatVectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
Expand Down Expand Up @@ -45,7 +53,7 @@ public void testMarkDeleted() {
}

@Test
public void testCleanup() {
public void testCleanup() throws IOException {
// graph of 10 vectors
int dimension = 2;
var ravv = MockVectorValues.fromValues(createRandomFloatVectors(10, dimension, getRandom()));
Expand Down Expand Up @@ -75,5 +83,19 @@ public void testCleanup() {
var v = Arrays.copyOf(ravv.vectorValue(nodeToIsolate), ravv.dimension);
var results = GraphSearcher.search(v, 10, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL);
assertEquals(nodeToIsolate, results.getNodes()[0].node);

// check that we can save and load the graph with "holes" from the deletion
var testDirectory = Files.createTempDirectory(this.getClass().getSimpleName());
var outputPath = testDirectory.resolve("on_heap_graph");
try (var out = openFileForWriting(outputPath)) {
graph.save(out);
out.flush();
}
var b2 = new GraphIndexBuilder<>(ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f);
try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString())) {
b2.load(marr);
}
var reloadedGraph = b2.getGraph();
assertGraphEquals(graph, reloadedGraph);
}
}

0 comments on commit 1788350

Please sign in to comment.