Skip to content

Commit 31a54e9

Browse files
author
Jonathan Ellis
committed
instead of renumbering implicitly, let caller provide remapper
1 parent cc33203 commit 31a54e9

File tree

9 files changed

+86
-185
lines changed

9 files changed

+86
-185
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/disk/OnDiskGraphIndex.java

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import java.io.DataOutput;
2727
import java.io.IOException;
2828
import java.io.UncheckedIOException;
29+
import java.util.HashSet;
30+
import java.util.function.Function;
2931
import java.util.stream.IntStream;
3032

3133
public class OnDiskGraphIndex<T> implements GraphIndex<T>, AutoCloseable, Accountable
@@ -147,19 +149,28 @@ public void close() throws IOException {
147149
readerSupplier.close();
148150
}
149151

150-
// takes Graph and Vectors separately since I'm reluctant to introduce a Vectors reference
151-
// to OnHeapGraphIndex just for this method. Maybe that will end up the best solution,
152-
// but I'm not sure yet.
153-
public static <T> void write(GraphIndex<T> graph, RandomAccessVectorValues<T> vectors, DataOutput out) throws IOException {
152+
/**
153+
* @param graph the graph to write
154+
* @param vectors the vectors associated with each node
155+
* @param ordinalMapper A function that maps from the graph's ordinals to the ordinals in the output.
156+
* For simple use cases this can be the identity function. To deal with deleted nodes,
157+
* or to renumber the nodes to match rows or documents elsewhere, you may provide
158+
* something custom. The mapper must map from the graph's ordinals
159+
* [0..getMaxNodeId()], to the output's ordinals [0..size()), with no gaps and
160+
* no duplicates.
161+
* @param out the output to write to
162+
*/
163+
public static <T> void write(GraphIndex<T> graph,
164+
RandomAccessVectorValues<T> vectors,
165+
Function<Integer, Integer> ordinalMapper,
166+
DataOutput out)
167+
throws IOException
168+
{
154169
if (graph instanceof OnHeapGraphIndex) {
155170
var ohgi = (OnHeapGraphIndex<T>) graph;
156171
if (ohgi.getDeletedNodes().cardinality() > 0) {
157172
throw new IllegalArgumentException("Run builder.cleanup() before writing the graph");
158173
}
159-
if (ohgi.hasPurgedNodes()) {
160-
vectors = new RenumberingVectorValues<>(ohgi, vectors);
161-
graph = new RenumberingGraphIndex<>(ohgi);
162-
}
163174
}
164175
var view = graph.getView();
165176

@@ -170,15 +181,22 @@ public static <T> void write(GraphIndex<T> graph, RandomAccessVectorValues<T> ve
170181
out.writeInt(graph.maxDegree());
171182

172183
// for each graph node, write the associated vector and its neighbors
173-
for (int node = 0; node < graph.size(); node++) {
174-
out.writeInt(node); // unnecessary, but a reasonable sanity check
175-
Io.writeFloats(out, (float[]) vectors.vectorValue(node));
184+
var newOrdinals = new HashSet<Integer>();
185+
for (int originalOrdinal = 0; originalOrdinal <= graph.getMaxNodeId(); originalOrdinal++) {
186+
if (!graph.containsNode(originalOrdinal)) {
187+
continue;
188+
}
189+
190+
int newOrdinal = ordinalMapper.apply(originalOrdinal);
191+
newOrdinals.add(newOrdinal);
192+
out.writeInt(newOrdinal); // unnecessary, but a reasonable sanity check
193+
Io.writeFloats(out, (float[]) vectors.vectorValue(originalOrdinal));
176194

177-
var neighbors = view.getNeighborsIterator(node);
195+
var neighbors = view.getNeighborsIterator(originalOrdinal);
178196
out.writeInt(neighbors.size());
179197
int n = 0;
180198
for ( ; n < neighbors.size(); n++) {
181-
out.writeInt(neighbors.nextInt());
199+
out.writeInt(ordinalMapper.apply(neighbors.nextInt()));
182200
}
183201
assert !neighbors.hasNext();
184202

@@ -187,5 +205,16 @@ public static <T> void write(GraphIndex<T> graph, RandomAccessVectorValues<T> ve
187205
out.writeInt(-1);
188206
}
189207
}
208+
209+
// verify that the provided mapper was well-behaved
210+
if (newOrdinals.size() > graph.size()) {
211+
throw new IllegalArgumentException("graph modified during write");
212+
}
213+
if (newOrdinals.size() < graph.size()) {
214+
throw new IllegalArgumentException("ordinalMapper resulted in duplicate entries");
215+
}
216+
if (graph.size() > 0 && newOrdinals.stream().mapToInt(i -> i).max().getAsInt() != graph.size() - 1) {
217+
throw new IllegalArgumentException("ordinalMapper produced out-of-range entries");
218+
}
190219
}
191220
}

jvector-base/src/main/java/io/github/jbellis/jvector/disk/RenumberingGraphIndex.java

Lines changed: 0 additions & 99 deletions
This file was deleted.

jvector-base/src/main/java/io/github/jbellis/jvector/disk/RenumberingVectorValues.java

Lines changed: 0 additions & 48 deletions
This file was deleted.

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndex.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ default int getMaxNodeId() {
6868
return size();
6969
}
7070

71+
/**
72+
* @return true iff the graph contains the node with the given ordinal id
73+
*/
74+
default boolean containsNode(int nodeId) {
75+
return nodeId >= 0 && nodeId < size();
76+
}
77+
7178
@Override
7279
void close() throws IOException;
7380

jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,8 @@ int[] rawNodes() {
261261
return nodes.keySet().stream().mapToInt(i -> (int) (long) i).toArray();
262262
}
263263

264-
public boolean hasPurgedNodes() {
265-
return hasPurgedNodes;
266-
}
267-
268-
public boolean containsNode(int i) {
269-
return nodes.containsKey(i);
264+
public boolean containsNode(int nodeId) {
265+
return nodes.containsKey(nodeId);
270266
}
271267

272268
public double getAverageShortEdges() {

jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.List;
3939
import java.util.Set;
4040
import java.util.concurrent.atomic.LongAdder;
41+
import java.util.function.Function;
4142
import java.util.stream.Collectors;
4243
import java.util.stream.IntStream;
4344

@@ -59,7 +60,7 @@ private static void testRecall(int M, int efConstruction, List<Boolean> diskOpti
5960
var graphPath = testDirectory.resolve("graph" + M + efConstruction + ds.name);
6061
try {
6162
try (var outputStream = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(graphPath)))) {
62-
OnDiskGraphIndex.write(onHeapGraph, floatVectors, outputStream);
63+
OnDiskGraphIndex.write(onHeapGraph, floatVectors, Function.identity(), outputStream);
6364
}
6465
try (var onDiskGraph = new CachingGraphIndex(new OnDiskGraphIndex<>(ReaderSupplierFactory.open(graphPath), 0))) {
6566
int queryRuns = 2;

jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.HashSet;
3636
import java.util.concurrent.ExecutionException;
3737
import java.util.concurrent.atomic.AtomicInteger;
38+
import java.util.function.Function;
3839
import java.util.stream.IntStream;
3940

4041
public class SiftSmall {
@@ -62,7 +63,7 @@ public static void testRecall(ArrayList<float[]> baseVectors, ArrayList<float[]>
6263
var graphPath = testDirectory.resolve("graph_test");
6364
try {
6465
DataOutputStream outputFile = new DataOutputStream(new FileOutputStream(graphPath.toFile()));
65-
OnDiskGraphIndex.write(onHeapGraph, ravv, outputFile);
66+
OnDiskGraphIndex.write(onHeapGraph, ravv, Function.identity(), outputFile);
6667

6768
var onDiskGraph = new CachingGraphIndex(new OnDiskGraphIndex<>(ReaderSupplierFactory.open(graphPath), 0));
6869

jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.nio.file.attribute.BasicFileAttributes;
3636
import java.util.*;
3737
import java.util.concurrent.ConcurrentHashMap;
38+
import java.util.function.Function;
3839
import java.util.function.Supplier;
3940
import java.util.stream.Collectors;
4041
import java.util.stream.IntStream;
@@ -117,7 +118,7 @@ public static byte[] randomVector8(Random random, int dim) {
117118
public static <T> void writeGraph(GraphIndex<T> graph, RandomAccessVectorValues<T> vectors, Path outputPath) throws IOException {
118119
try (var indexOutputWriter = openFileForWriting(outputPath))
119120
{
120-
OnDiskGraphIndex.write(graph, vectors, indexOutputWriter);
121+
OnDiskGraphIndex.write(graph, vectors, Function.identity(), indexOutputWriter);
121122
indexOutputWriter.flush();
122123
}
123124
}

jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestOnDiskGraphIndex.java

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
import java.io.IOException;
3333
import java.nio.file.Files;
3434
import java.nio.file.Path;
35+
import java.util.HashMap;
3536
import java.util.List;
37+
import java.util.Map;
38+
import java.util.function.Function;
3639

3740
import static io.github.jbellis.jvector.TestUtil.getNeighborNodes;
3841
import static org.junit.Assert.assertArrayEquals;
@@ -98,24 +101,34 @@ public void testRenumberingOnDelete() throws IOException {
98101
assertEquals(1, getNeighborNodes(originalView, 2).size());
99102
assertTrue(getNeighborNodes(originalView, 2).contains(1));
100103

101-
// check renumbered
102-
var renumbered = new RenumberingGraphIndex<>(original);
103-
var renumberedView = renumbered.getView();
104-
assertEquals(2, renumbered.size());
105-
// 0 -> 1
106-
assertEquals(1, getNeighborNodes(renumberedView, 0).size());
107-
assertTrue(getNeighborNodes(renumberedView, 0).contains(1));
108-
// 1 -> 0
109-
assertEquals(1, getNeighborNodes(renumberedView, 1).size());
110-
assertTrue(getNeighborNodes(renumberedView, 1).contains(0));
111-
112-
// writing to disk should be the same as the renumbered
113-
var outputPath = testDirectory.resolve("large_graph");
114-
TestUtil.writeGraph(original, ravv, outputPath);
104+
// create renumbering map
105+
Map<Integer, Integer> oldToNewMap = new HashMap<>();
106+
int nextOrdinal = 0;
107+
for (int i = 0; i <= originalView.getMaxNodeId(); i++) {
108+
if (original.containsNode(i)) {
109+
oldToNewMap.put(i, nextOrdinal++);
110+
}
111+
}
112+
assertEquals(2, oldToNewMap.size());
113+
assertEquals(0, (int) oldToNewMap.get(1));
114+
assertEquals(1, (int) oldToNewMap.get(2));
115+
116+
// write the graph
117+
var outputPath = testDirectory.resolve("renumbered_graph");
118+
try (var indexOutputWriter = TestUtil.openFileForWriting(outputPath))
119+
{
120+
OnDiskGraphIndex.write(original, ravv, oldToNewMap::get, indexOutputWriter);
121+
indexOutputWriter.flush();
122+
}
123+
// check that written graph ordinals match the new ones
115124
try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString());
116-
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0))
125+
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0);
126+
var onDiskView = onDiskGraph.getView())
117127
{
118-
TestUtil.assertGraphEquals(renumbered, onDiskGraph);
128+
// 0 -> 1
129+
assertTrue(getNeighborNodes(onDiskView, 0).contains(1));
130+
// 1 -> 0
131+
assertTrue(getNeighborNodes(onDiskView, 1).contains(0));
119132
}
120133
}
121134

0 commit comments

Comments
 (0)