Commit

Adds DenseIntMap for building graph with much less contention. Back to zero dependency (#128)
tjake authored Oct 18, 2023
1 parent 0a2ecb5 commit d31f1c8
Showing 9 changed files with 121 additions and 119 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,7 +1,7 @@
# JVector
<a href="https://trendshift.io/repositories/2946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2946" alt="jbellis%2Fjvector | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

JVector is a pure Java embedded vector search engine, used by [DataStax Astra DB](https://www.datastax.com/products/datastax-astra) and (soon) Apache Cassandra.
JVector is a pure Java, zero dependency, embedded vector search engine, used by [DataStax Astra DB](https://www.datastax.com/products/datastax-astra) and (soon) Apache Cassandra.

What is JVector?
- Algorithmic-fast. JVector uses state of the art graph algorithms inspired by DiskANN and related research that offer high recall and low latency.
8 changes: 0 additions & 8 deletions jvector-base/pom.xml
@@ -11,12 +11,4 @@
</parent>
<artifactId>jvector-base</artifactId>
<name>Base</name>

<dependencies>
<dependency>
<groupId>org.jctools</groupId>
<artifactId>jctools-core</artifactId>
<version>4.0.1</version>
</dependency>
</dependencies>
</project>
GraphIndexBuilder.java
@@ -155,7 +155,7 @@ public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {

// do this before adding to in-progress, so a concurrent writer checking
// the in-progress set doesn't have to worry about uninitialized neighbor sets
graph.addNode(node);
ConcurrentNeighborSet newNodeNeighbors = graph.addNode(node);

insertionsInProgress.add(node);
ConcurrentSkipListSet<Integer> inProgressBefore = insertionsInProgress.clone();
@@ -174,7 +174,7 @@ public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {
// Update neighbors with these candidates.
var natural = getNaturalCandidates(candidates.getNodes(), naturalScratchPooled.get());
var concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled.get(), vectors, vc.get());
updateNeighbors(node, natural, concurrent);
updateNeighbors(newNodeNeighbors, natural, concurrent);
graph.markComplete(node);
} finally {
insertionsInProgress.remove(node);
@@ -219,10 +219,9 @@ private int approximateMedioid() {
}
}

private void updateNeighbors(int node, NeighborArray natural, NeighborArray concurrent) {
ConcurrentNeighborSet neighbors = graph.getNeighbors(node);
neighbors.insertDiverse(natural, concurrent);
neighbors.backlink(graph::getNeighbors, neighborOverflow);
private void updateNeighbors(ConcurrentNeighborSet nodeNeighbors, NeighborArray natural, NeighborArray concurrent) {
nodeNeighbors.insertDiverse(natural, concurrent);
nodeNeighbors.backlink(graph::getNeighbors, neighborOverflow);
}

private NeighborArray getNaturalCandidates(SearchResult.NodeScore[] candidates, NeighborArray scratch) {
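The pattern in this hunk — addNode now returns the freshly created ConcurrentNeighborSet so updateNeighbors never re-reads it from the node map — is sketched below with simplified stand-in types (not JVector's real classes):

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

// Simplified stand-ins for OnHeapGraphIndex / ConcurrentNeighborSet, for illustration only.
class NeighborSet {
    void insertDiverse(int... candidates) { /* selection logic omitted */ }
}

class TinyGraph {
    private final ConcurrentMap<Integer, NeighborSet> nodes = new ConcurrentHashMap<>();

    // Returning the set that was just created lets callers skip a second map lookup,
    // which is the same idea as addNode/updateNeighbors in this hunk.
    NeighborSet addNode(int node) {
        NeighborSet fresh = new NeighborSet();
        nodes.put(node, fresh);
        return fresh;
    }

    void insert(int node, int... candidates) {
        NeighborSet neighbors = addNode(node); // no nodes.get(node) round-trip
        neighbors.insertDiverse(candidates);
    }
}
```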
OnHeapGraphIndex.java
@@ -25,12 +25,12 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.util.Accountable;
import io.github.jbellis.jvector.util.DenseIntMap;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import org.jctools.maps.NonBlockingHashMapLong;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

/**
* An {@link GraphIndex} that offers concurrent access; for typical graphs you will get significant
@@ -42,9 +42,9 @@
public final class OnHeapGraphIndex<T> implements GraphIndex<T>, Accountable {

// the current graph entry node on the top level. -1 if not set
private final AtomicLong entryPoint = new AtomicLong(-1);
private final AtomicInteger entryPoint = new AtomicInteger(-1);

private final NonBlockingHashMapLong<ConcurrentNeighborSet> nodes;
private final DenseIntMap<ConcurrentNeighborSet> nodes;

// max neighbors/edges per node
final int nsize0;
@@ -54,8 +54,7 @@ public final class OnHeapGraphIndex<T> implements GraphIndex<T>, Accountable {
int M, BiFunction<Integer, Integer, ConcurrentNeighborSet> neighborFactory) {
this.neighborFactory = neighborFactory;
this.nsize0 = 2 * M;

this.nodes = new NonBlockingHashMapLong<>(1024);
this.nodes = new DenseIntMap<>(1024);
}

/**
@@ -67,6 +66,7 @@ public ConcurrentNeighborSet getNeighbors(int node) {
return nodes.get(node);
}


@Override
public int size() {
return nodes.size();
@@ -84,9 +84,12 @@ public int size() {
* <p>It is also the responsibility of the caller to ensure that each node is only added once.
*
* @param node the node to add, represented as an ordinal on the level 0.
* @return the neighbor set for this node
*/
public void addNode(int node) {
nodes.put(node, neighborFactory.apply(node, maxDegree()));
public ConcurrentNeighborSet addNode(int node) {
ConcurrentNeighborSet newNeighborSet = neighborFactory.apply(node, maxDegree());
nodes.put(node, newNeighborSet);
return newNeighborSet;
}

/** must be called after addNode once neighbors are linked in all levels. */
@@ -112,19 +115,20 @@ public int maxDegree() {
}

int entry() {
return (int) entryPoint.get();
return entryPoint.get();
}

@Override
public NodesIterator getNodes() {
// We avoid the temptation to optimize this by using ArrayNodesIterator.
// This is because, while the graph will contain sequential ordinals once the graph is complete,
// we should not assume that that is the only time it will be called.
var keysInts = Arrays.stream(nodes.keySetLong()).iterator();
return new NodesIterator(nodes.size()) {
int size = nodes.size();
var keysInts = IntStream.range(0, size).iterator();
return new NodesIterator(size) {
@Override
public int nextInt() {
return keysInts.next().intValue();
return keysInts.next();
}

@Override
@@ -137,20 +141,14 @@ public boolean hasNext() {
@Override
public long ramBytesUsed() {
// the main graph structure
long total = concurrentHashMapRamUsed(size());
long chmSize = concurrentHashMapRamUsed(size());
long total = (long) size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
long neighborSize = neighborsRamUsed(maxDegree()) * size();

total += chmSize + neighborSize;

return total;
return total + neighborSize + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
}

public long ramBytesUsedOneNode(int nodeLevel) {
int entryCount = (int) (nodeLevel / CHM_LOAD_FACTOR);
var graphBytesUsed =
chmEntriesRamUsed(entryCount)
+ neighborsRamUsed(maxDegree())
neighborsRamUsed(maxDegree())
+ nodeLevel * neighborsRamUsed(maxDegree());
var clockBytesUsed = Integer.BYTES;
return graphBytesUsed + clockBytesUsed;
@@ -171,42 +169,6 @@ private static long neighborsRamUsed(int count) {
return neighborSetBytes + (long) count * (Integer.BYTES + Float.BYTES);
}

private static final float CHM_LOAD_FACTOR = 0.75f; // this is hardcoded inside ConcurrentHashMap

/**
* caller's responsibility to divide number of entries by load factor to get internal node count
*/
private static long chmEntriesRamUsed(int internalEntryCount) {
long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
long chmNodeBytes =
REF_BYTES // node itself in Node[]
+ 3L * REF_BYTES
+ Integer.BYTES; // node internals

return internalEntryCount * chmNodeBytes;
}

private static long concurrentHashMapRamUsed(int externalSize) {
long REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
long AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
long CORES = Runtime.getRuntime().availableProcessors();

// CHM has a striped counter Cell implementation, we expect at most one per core
long chmCounters = AH_BYTES + CORES * (REF_BYTES + Long.BYTES);

int nodeCount = (int) (externalSize / CHM_LOAD_FACTOR);

long chmSize =
chmEntriesRamUsed(nodeCount) // nodes
+ nodeCount * REF_BYTES
+ AH_BYTES // nodes array
+ Long.BYTES
+ 3 * Integer.BYTES
+ 3 * REF_BYTES // extra internal fields
+ chmCounters
+ REF_BYTES; // the Map reference itself
return chmSize;
}

@Override
public String toString() {
@@ -233,7 +195,7 @@ void validateEntryNode() {
return;
}
var en = entryPoint.get();
if (!(en >= 0 && nodes.containsKey(en))) {
if (!(en >= 0 && getNeighbors(en) != null)) {
throw new IllegalStateException("Entry node was incompletely added! " + en);
}
}
DenseIntMap.java (new file)
@@ -0,0 +1,76 @@
package io.github.jbellis.jvector.util;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.StampedLock;

/**
* A Map of int -> T where the int keys are dense and start at zero, but the
* size of the map is not known in advance. This provides fast, concurrent
* updates and minimizes contention when the map is resized.
*/
public class DenseIntMap<T> {
private volatile AtomicReferenceArray<T> objects;
private final AtomicInteger size;
private final StampedLock sl = new StampedLock();

public DenseIntMap(int initialSize) {
objects = new AtomicReferenceArray<>(initialSize);
size = new AtomicInteger();
}

/**
* @param key ordinal
*/
public void put(int key, T value) {
ensureCapacity(key);
long stamp;
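// Optimistic write: if a concurrent resize (which takes the write lock) runs while we
// set the value, the write may land in the array being replaced; validate() then fails
// and we retry against the newly published array.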
do {
stamp = sl.tryOptimisticRead();
objects.set(key, value);
} while (!sl.validate(stamp));

size.incrementAndGet();
}

/**
* @return number of items that have been added
*/
public int size() {
return size.get();
}

/**
* @param key ordinal
* @return the value of the key, or null if not set
*/
public T get(int key) {
// since objects is volatile, we don't need to lock
var ref = objects;
if (key >= ref.length()) {
return null;
}
return ref.get(key);
}

private void ensureCapacity(int node) {
if (node < objects.length()) {
return;
}

long stamp = sl.writeLock();
try {
var oldArray = objects;
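// Re-check under the write lock: another thread may have already grown the array.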
if (node >= oldArray.length()) {
int newSize = ArrayUtil.oversize(node + 1, RamUsageEstimator.NUM_BYTES_OBJECT_REF);
var newArray = new AtomicReferenceArray<T>(newSize);
for (int i = 0; i < oldArray.length(); i++) {
newArray.set(i, oldArray.get(i));
}
objects = newArray;
}
} finally {
sl.unlockWrite(stamp);
}
}
}
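A minimal usage sketch of the new DenseIntMap, assuming the dense-ordinal contract described in its javadoc (keys start at zero and each key is put once):

```java
import io.github.jbellis.jvector.util.DenseIntMap;

import java.util.stream.IntStream;

public class DenseIntMapExample {
    public static void main(String[] args) {
        // The constructor argument is only an initial capacity; the map grows as needed.
        DenseIntMap<String> labels = new DenseIntMap<>(4);

        // Concurrent puts of distinct ordinals only contend during the occasional resize,
        // not on every insertion.
        IntStream.range(0, 1_000).parallel().forEach(i -> labels.put(i, "node-" + i));

        System.out.println(labels.size());     // 1000
        System.out.println(labels.get(42));    // node-42
        System.out.println(labels.get(5_000)); // null: never set
    }
}
```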
PhysicalCoreExecutor.java
@@ -35,4 +35,8 @@ public void execute(Runnable run) {
public <T> T submit(Supplier<T> run) {
return pool.submit(run::get).join();
}

public static int getPhysicalCoreCount() {
return physicalCoreCount;
}
}
PoolingSupport.java
@@ -1,11 +1,10 @@
package io.github.jbellis.jvector.util;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;

import org.jctools.queues.MpmcArrayQueue;

/**
* Allows any object to be pooled and released when work is done.
* This is an alternative to using {@link ThreadLocal}.
@@ -95,7 +94,7 @@ static class ThreadPooling<T> extends PoolingSupport<T>
{
private final int limit;
private final AtomicInteger created;
private final MpmcArrayQueue<T> queue;
private final LinkedBlockingQueue<T> queue;
private final Supplier<T> initialValue;

private ThreadPooling(Supplier<T> initialValue) {
@@ -106,7 +105,7 @@ private ThreadPooling(Supplier<T> initialValue) {
private ThreadPooling(int threadLimit, Supplier<T> initialValue) {
this.limit = threadLimit;
this.created = new AtomicInteger(0);
this.queue = new MpmcArrayQueue<>(threadLimit);
this.queue = new LinkedBlockingQueue<>(threadLimit);
this.initialValue = initialValue;
}

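The swap from MpmcArrayQueue to LinkedBlockingQueue keeps the same bounded-pool behavior with a JDK-only structure. A minimal sketch of that pattern, using illustrative names rather than PoolingSupport's actual API:

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical standalone pool; field names mirror ThreadPooling above, but the
// acquire/release methods are illustrative, not PoolingSupport's real API.
class SimplePool<T> {
    private final int limit;
    private final AtomicInteger created = new AtomicInteger(0);
    private final LinkedBlockingQueue<T> queue;
    private final Supplier<T> initialValue;

    SimplePool(int limit, Supplier<T> initialValue) {
        this.limit = limit;
        this.queue = new LinkedBlockingQueue<>(limit);
        this.initialValue = initialValue;
    }

    T acquire() throws InterruptedException {
        T pooled = queue.poll();
        if (pooled != null) {
            return pooled;
        }
        // Create lazily up to the limit; past it, block until another thread releases one.
        if (created.incrementAndGet() <= limit) {
            return initialValue.get();
        }
        created.decrementAndGet();
        return queue.take();
    }

    void release(T value) {
        queue.offer(value); // bounded queue with capacity == limit, so this succeeds
    }
}
```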