
Improve performance of reconnectOrphanedNodes #359


Merged
merged 10 commits on Sep 27, 2024

GraphIndexBuilder.java

@@ -21,6 +21,7 @@
 import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
 import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
 import io.github.jbellis.jvector.util.AtomicFixedBitSet;
+import io.github.jbellis.jvector.util.BitSet;
 import io.github.jbellis.jvector.util.Bits;
 import io.github.jbellis.jvector.util.ExceptionUtils;
 import io.github.jbellis.jvector.util.ExplicitThreadLocal;
@@ -29,6 +30,8 @@
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 import org.agrona.collections.IntArrayList;
 import org.agrona.collections.IntArrayQueue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import java.io.Closeable;
 import java.io.IOException;
@@ -56,6 +59,8 @@
  * that spawning a new Thread per call is not advisable. This includes virtual threads.
  */
 public class GraphIndexBuilder implements Closeable {
+    private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class);
+
     private final int beamWidth;
     private final ExplicitThreadLocal<NodeArray> naturalScratch;
     private final ExplicitThreadLocal<NodeArray> concurrentScratch;
@@ -235,86 +240,98 @@ public void cleanup() {
     }

     private void reconnectOrphanedNodes() {
-        var searchPathNeighbors = new ConcurrentHashMap<Integer, NodeArray>();
-        // It's possible that reconnecting one node will result in disconnecting another, since we are maintaining
-        // the maxConnections invariant. So, we do a best effort of 3 loops. We claim the entry node as an
-        // already used connectionTarget so that we don't clutter its edge list.
-        var connectionTargets = ConcurrentHashMap.<Integer>newKeySet();
-        connectionTargets.add(graph.entry());
-        for (int i = 0; i < 3; i++) {
-            // find all nodes reachable from the entry node
+        // Reconnection is best-effort: reconnecting one node may result in disconnecting another, since we are maintaining
+        // the maxConnections invariant. So, we do a maximum of 5 loops.
+        for (int i = 0; i < 5; i++) {
+            // determine the nodes reachable from the entry point at the start of this pass
             var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
             connectedNodes.set(graph.entry());
-            ConcurrentNeighborMap.Neighbors self1 = graph.getNeighbors(graph.entry());
-            var entryNeighbors = (NodeArray) self1;
+            var entryNeighbors = graph.getNeighbors(graph.entry());
             parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size()).parallel().forEach(node -> findConnected(connectedNodes, entryNeighbors.getNode(node)))).join();
+            // Set of nodes that may be used as connection targets, initialized to all nodes reachable from the entry
+            // point. But since reconnection edges are usually worse (by distance and/or diversity) than the original
+            // ones, we update this as edges are added to avoid reusing the same target node more than once.
+            var connectionTargets = connectedNodes.copy();
+            // It's particularly important for the entry node to have high quality edges, so mark it
+            // as an invalid Target before we start.
+            connectionTargets.clear(graph.entry());
+
+            // Gather basic debug information about efficacy/efficiency of reconnection attempts
+            var nReconnectAttempts = new AtomicInteger();
+            var nReconnectedViaNeighbors = new AtomicInteger();
+            var nResumesRun = new AtomicInteger();
+            var nReconnectedViaSearch = new AtomicInteger();

             // reconnect unreachable nodes
-            var nReconnected = new AtomicInteger();
             simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> {
                 if (connectedNodes.get(node) || !graph.containsNode(node)) {
                     return;
                 }
-                nReconnected.incrementAndGet();
+                nReconnectAttempts.incrementAndGet();

-                // first, attempt to connect one of our own neighbors to us
+                // first, attempt to connect one of our own connected neighbors to us. Filtering
+                // to connected nodes tends to help for partitioned graphs with large partitions.
                 ConcurrentNeighborMap.Neighbors self = graph.getNeighbors(node);
                 var neighbors = (NodeArray) self;
-                if (connectToClosestNeighbor(node, neighbors, connectionTargets)) {
+                if (connectToClosestNeighbor(node, neighbors, connectionTargets) != null) {
+                    nReconnectedViaNeighbors.incrementAndGet();
                     return;
                 }

-                // no unused candidate found -- search for more neighbors and try again
-                neighbors = searchPathNeighbors.get(node);
-                // run search again if neighbors is empty or if every neighbor is already in connection targets
-                if (neighbors == null || isSubset(neighbors, connectionTargets)) {
-                    SearchResult result;
-                    try (var gs = searchers.get()) {
-                        var excludeBits = createExcludeBits(node, connectionTargets);
-                        var ssp = scoreProvider.searchProviderFor(node);
-                        int ep = graph.entry();
-                        result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, excludeBits);
-                    } catch (Exception e) {
-                        throw new RuntimeException(e);
-                    }
-                    neighbors = new NodeArray(result.getNodes().length);
-                    toScratchCandidates(result.getNodes(), neighbors);
-                    searchPathNeighbors.put(node, neighbors);
+                // if we can't find a connected neighbor to reconnect to, we'll have to search. We start with a small
+                // search, and we resume the search in a bounded loop to try to find an eligible connection target.
+                // This significantly improves behavior for large (1M+ node) partitioned graphs. We don't add
+                // connectionTargets to excludeBits because large partitions lead to excessively large excludeBits,
+                // greatly degrading search performance.
+                SearchResult result;
+                try (var gs = searchers.get()) {
+                    var excludeBits = Bits.inverseOf(connectionTargets);
+                    var ssp = scoreProvider.searchProviderFor(node);
+                    int ep = graph.entry();
+                    result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, excludeBits);
+                } catch (Exception e) {
+                    throw new RuntimeException(e);
+                }
+                neighbors = new NodeArray(result.getNodes().length);
+                toScratchCandidates(result.getNodes(), neighbors);
+                var reconnected = connectToClosestNeighbor(node, neighbors, connectionTargets);
+                if (reconnected != null) {
+                    nReconnectedViaSearch.incrementAndGet();
+                    // since we went to the trouble of finding the closest available neighbor, let `backlink`
+                    // check to see if it should be added as an edge to the original node as well
+                    var na = new NodeArray(1);
+                    na.addInOrder(reconnected.node, reconnected.score);
+                    graph.nodes.backlink(na, node, 1.0f);
                 }
-                connectToClosestNeighbor(node, neighbors, connectionTargets);
             })).join();
-            if (nReconnected.get() == 0) {
-                break;
-            }
-        }
-    }
-
-    private boolean isSubset(NodeArray neighbors, Set<Integer> nodeIds) {
-        for (int i = 0; i < neighbors.size(); i++) {
-            if (!nodeIds.contains(neighbors.getNode(i))) {
-                return false;
+            logger.debug("Reconnecting {} nodes out of {} on pass {}. {} neighbor reconnects. {} searches/resumes run. {} nodes reconnected via search",
+                         nReconnectAttempts.get(), graph.size(), i, nReconnectedViaNeighbors.get(), nResumesRun.get(), nReconnectedViaSearch.get());

+            if (nReconnectAttempts.get() == 0) {
+                break;
             }
         }
-        return true;
     }

     /**
-     * Connect `node` to the closest neighbor that is not already a connection target.
-     * @return true if such a neighbor was found.
+     * Connect `node` to the closest connected neighbor that is not already a connection target.
+     *
+     * @return the neighbor id if such a neighbor was found.
      */
-    private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set<Integer> connectionTargets) {
-        // connect this node to the closest neighbor that hasn't already been used as a connection target
+    private SearchResult.NodeScore connectToClosestNeighbor(int node, NodeArray neighbors, BitSet connectionTargets) {
+        // connect this node to the closest connected neighbor that hasn't already been used as a connection target
         // (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be
         // overwritten by the next node to need reconnection if we don't choose a unique target)
         for (int i = 0; i < neighbors.size(); i++) {
             var neighborNode = neighbors.getNode(i);
+            if (!connectionTargets.get(neighborNode))
+                continue;

             var neighborScore = neighbors.getScore(i);
-            if (connectionTargets.add(neighborNode)) {
-                graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore);
-                return true;
-            }
+            graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore);
+            connectionTargets.clear(neighborNode);
+            return new SearchResult.NodeScore(neighborNode, neighborScore);
         }
-        return false;
+        return null;
     }

     private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
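
To make the new bookkeeping easier to follow outside the diff, here is a minimal, self-contained sketch of the "use each connection target at most once" pattern that the reworked connectToClosestNeighbor implements. It uses plain java.util.BitSet and hypothetical candidate arrays and helper names rather than jvector's NodeArray and ConcurrentNeighborMap, so treat it as an illustration of the idea, not the project's API.

import java.util.BitSet;

// Illustration only: candidate ids/scores are assumed sorted best-first, mirroring NodeArray's
// ordering; addEdge is a hypothetical stand-in for graph.nodes.insertEdgeNotDiverse.
class ReconnectTargetSketch {
    // Connect `orphan` to the best candidate that is still an eligible target. Eligible bits start
    // as a copy of the connected-node set and are cleared as they are claimed, so later orphans
    // are steered toward different targets. Returns the chosen target id, or -1 if none was eligible.
    static int connectToClosestEligible(int orphan, int[] candidateIds, float[] candidateScores, BitSet eligibleTargets) {
        for (int i = 0; i < candidateIds.length; i++) {
            int target = candidateIds[i];
            if (!eligibleTargets.get(target)) {
                continue; // unreachable, or already claimed by an earlier reconnection
            }
            eligibleTargets.clear(target);               // claim it so it is not reused
            addEdge(target, orphan, candidateScores[i]); // add the reconnection edge
            return target;
        }
        return -1; // caller falls back to a graph search for more candidates
    }

    private static void addEdge(int fromNode, int toNode, float score) {
        // placeholder for the real edge insertion
    }
}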

AtomicFixedBitSet.java

@@ -185,5 +185,13 @@ public long ramBytesUsed() {
         long storageSize = (long) storage.length() * longSizeInBytes + arrayOverhead;
         return BASE_RAM_BYTES_USED + storageSize;
     }
+
+    public AtomicFixedBitSet copy() {
+        AtomicFixedBitSet copy = new AtomicFixedBitSet(length());
+        for (int i = 0; i < storage.length(); i++) {
+            copy.storage.set(i, storage.get(i));
+        }
+        return copy;
+    }
 }

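The new copy() duplicates the backing storage word by word, so the result is an independent AtomicFixedBitSet but not a globally atomic snapshot if other threads are still flipping bits while it runs. Below is a short usage sketch in the spirit of reconnectOrphanedNodes above; the wrapper class and method names are hypothetical.

import io.github.jbellis.jvector.util.AtomicFixedBitSet;

// Sketch: snapshot the reachable-node set, then consume the snapshot destructively as the
// pool of allowed reconnection targets, leaving the original set untouched.
class CopyUsageSketch {
    static AtomicFixedBitSet snapshotConnectionTargets(AtomicFixedBitSet connectedNodes, int entryNode) {
        AtomicFixedBitSet targets = connectedNodes.copy(); // independent copy of the current bits
        targets.clear(entryNode);                          // keep the entry node's edge list free of reconnection edges
        return targets;
    }
}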

pom.xml (5 additions, 0 deletions)

@@ -185,6 +185,11 @@
       <artifactId>agrona</artifactId>
       <version>1.20.0</version>
     </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-api</artifactId>
+      <version>2.0.16</version>
+    </dependency>
   </dependencies>
   <dependencyManagement>
     <dependencies>
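
One practical note on the new dependency: slf4j-api is only the logging facade, so the logger.debug output added to GraphIndexBuilder will show up only if the consuming application (or the test classpath) also provides an SLF4J binding such as logback-classic or slf4j-simple; without one, SLF4J falls back to its no-op logger and the reconnection statistics are silently discarded.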