From 66d78756b463fe0c532376d6525dfa846cbcb0db Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Thu, 19 Sep 2024 10:42:57 -0500 Subject: [PATCH] Improve performance of reconnectOrphanedNodes by limiting neighbor connection targets to nodes that were reachable from the entry node at the start of the pass. Instead of using exclusion bits for connection targets, perform several rounds of resumes and post-filter for connectionTargets. Introduce slf4j-api to log basic debugging information when reconnecting orphaned nodes. --- .../jvector/graph/GraphIndexBuilder.java | 96 ++++++++++++------- pom.xml | 5 + 2 files changed, 66 insertions(+), 35 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 6a98ef8ce..8e390bbfe 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -21,6 +21,7 @@ import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.AtomicFixedBitSet; +import io.github.jbellis.jvector.util.BitSet; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.ExceptionUtils; import io.github.jbellis.jvector.util.ExplicitThreadLocal; @@ -29,6 +30,9 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.IntArrayList; import org.agrona.collections.IntArrayQueue; +import org.agrona.collections.IntHashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.IOException; @@ -55,6 +59,8 @@ * that spawning a new Thread per call is not advisable. This includes virtual threads. 
*/ public class GraphIndexBuilder implements Closeable { + private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class); + private final int beamWidth; private final ExplicitThreadLocal naturalScratch; private final ExplicitThreadLocal concurrentScratch; @@ -234,13 +240,12 @@ public void cleanup() { } private void reconnectOrphanedNodes() { - var searchPathNeighbors = new ConcurrentHashMap(); // It's possible that reconnecting one node will result in disconnecting another, since we are maintaining - // the maxConnections invariant. So, we do a best effort of 3 loops. We claim the entry node as an + // the maxConnections invariant. So, we do a best effort of 5 loops. We claim the entry node as an // already used connectionTarget so that we don't clutter its edge list. var connectionTargets = ConcurrentHashMap.newKeySet(); connectionTargets.add(graph.entry()); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 5; i++) { // find all nodes reachable from the entry node var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound()); connectedNodes.set(graph.entry()); @@ -248,65 +253,86 @@ private void reconnectOrphanedNodes() { var entryNeighbors = (NodeArray) self1; parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size()).parallel().forEach(node -> findConnected(connectedNodes, entryNeighbors.getNode(node)))).join(); - // reconnect unreachable nodes - var nReconnected = new AtomicInteger(); + // Gather basic debug information about efficacy/efficiency of reconnection attempts + var nReconnectAttempts = new AtomicInteger(); + var nReconnectedViaNeighbors = new AtomicInteger(); + var nResumesRun = new AtomicInteger(); + var nReconnectedViaSearch = new AtomicInteger(); + simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> { if (connectedNodes.get(node) || !graph.containsNode(node)) { return; } - nReconnected.incrementAndGet(); + nReconnectAttempts.incrementAndGet(); - // first, 
attempt to connect one of our own neighbors to us + // first, attempt to connect one of our own connected neighbors to us. Filtering + // to connected nodes tends to help for partitioned graphs with large partitions. ConcurrentNeighborMap.Neighbors self = graph.getNeighbors(node); var neighbors = (NodeArray) self; - if (connectToClosestNeighbor(node, neighbors, connectionTargets)) { + if (connectToClosestNeighbor(node, neighbors, connectionTargets, connectedNodes)) { + nReconnectedViaNeighbors.incrementAndGet(); return; } - // no unused candidate found -- search for more neighbors and try again - neighbors = searchPathNeighbors.get(node); - // run search again if neighbors is empty or if every neighbor is already in connection targets - if (neighbors == null || isSubset(neighbors, connectionTargets)) { - SearchResult result; - try (var gs = searchers.get()) { - var excludeBits = createExcludeBits(node, connectionTargets); - var ssp = scoreProvider.searchProviderFor(node); - int ep = graph.entry(); - result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, excludeBits); - } catch (Exception e) { - throw new RuntimeException(e); + // if we can't find a connected neighbor to reconnect to, we'll have to search. We start with a small + // search, and we resume the search in a bounded loop to try to find an eligible connection target. + // This significantly improves behavior for large (1M+ node) partitioned graphs. We don't add + // connectionTargets to excludeBits because large partitions lead to excessively large excludeBits, + // greatly degrading search performance. 
+ SearchResult result; + try (var gs = searchers.get()) { + IntHashSet neighborIds = new IntHashSet(); + for (int j = 0; j < neighbors.size(); j++) { + neighborIds.add(neighbors.getNode(j)); } + + var excludeBits = createExcludeBits(node, neighborIds); + var ssp = scoreProvider.searchProviderFor(node); + int ep = graph.entry(); + result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, excludeBits); neighbors = new NodeArray(result.getNodes().length); toScratchCandidates(result.getNodes(), neighbors); - searchPathNeighbors.put(node, neighbors); + var reconnected = connectToClosestNeighbor(node, neighbors, connectionTargets, connectedNodes); + + var j = 0; + while (!reconnected && j < 50) { + j++; + nResumesRun.incrementAndGet(); + result = gs.resume(beamWidth, beamWidth); + toScratchCandidates(result.getNodes(), neighbors); + reconnected = connectToClosestNeighbor(node, neighbors, connectionTargets, connectedNodes); + } + + if (reconnected) + nReconnectedViaSearch.incrementAndGet(); + } catch (Exception e) + { + throw new RuntimeException(e); } - connectToClosestNeighbor(node, neighbors, connectionTargets); })).join(); - if (nReconnected.get() == 0) { - break; - } - } - } - private boolean isSubset(NodeArray neighbors, Set nodeIds) { - for (int i = 0; i < neighbors.size(); i++) { - if (!nodeIds.contains(neighbors.getNode(i))) { - return false; + logger.debug("Reconnecting {} nodes out of {} on pass {}. {} neighbor reconnects. {} searches/resumes run. {} nodes reconnected via search", + nReconnectAttempts.get(), graph.size(), i, nReconnectedViaNeighbors.get(), nResumesRun.get(), nReconnectedViaSearch.get()); + + if (nReconnectAttempts.get() == 0) { + break; } } - return true; } /** - * Connect `node` to the closest neighbor that is not already a connection target. + * Connect `node` to the closest connected neighbor that is not already a connection target. * @return true if such a neighbor was found. 
*/ - private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set connectionTargets) { - // connect this node to the closest neighbor that hasn't already been used as a connection target + private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set connectionTargets, BitSet connectedNodes) { + // connect this node to the closest connected neighbor that hasn't already been used as a connection target // (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be // overwritten by the next node to need reconnection if we don't choose a unique target) for (int i = 0; i < neighbors.size(); i++) { var neighborNode = neighbors.getNode(i); + if (!connectedNodes.get(neighborNode)) + continue; + var neighborScore = neighbors.getScore(i); if (connectionTargets.add(neighborNode)) { graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore); diff --git a/pom.xml b/pom.xml index 7f09347f3..83fdc21d8 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,11 @@ agrona 1.20.0 + + org.slf4j + slf4j-api + 2.0.16 +