From 7fd74891f29f26b760c60e94882359d5d8dce6b1 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Sun, 8 Oct 2023 16:16:24 -0500 Subject: [PATCH] fix mergeNeighbors to not add duplicate nodes, and fix test to check for duplicates --- .../jvector/graph/ConcurrentNeighborSet.java | 70 ++++++++++++++----- .../graph/TestConcurrentNeighborSet.java | 30 ++++++-- 2 files changed, 77 insertions(+), 23 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java index 3c779dc02..59516f608 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java @@ -20,6 +20,7 @@ import io.github.jbellis.jvector.util.DocIdSetIterator; import io.github.jbellis.jvector.util.FixedBitSet; +import java.util.HashSet; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -208,16 +209,44 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2 ConcurrentNeighborArray merged = new ConcurrentNeighborArray(a1.size() + a2.size()); int i = 0, j = 0; + // since nodes are only guaranteed to be sorted by score -- ties can appear in any node order -- + // we need to remember all the nodes with the current score to avoid adding duplicates + var nodesWithLastScore = new HashSet<>(); + float lastAddedScore = Float.NaN; + + // loop through both source arrays, adding the highest score element to the merged array, + // until we reach the end of one of the sources while (i < a1.size() && j < a2.size()) { if (a1.score()[i] < a2.score[j]) { - merged.addInOrder(a2.node[j], a2.score[j]); + // add from a2 + if (a2.score[j] != lastAddedScore) { + nodesWithLastScore.clear(); + lastAddedScore = a2.score[j]; + } + if (nodesWithLastScore.add(a2.node[j])) { + merged.addInOrder(a2.node[j], a2.score[j]); + } j++; } else if (a1.score()[i] > a2.score[j]) { - merged.addInOrder(a1.node()[i], a1.score()[i]); + // add from a1 + if (a1.score()[i] != lastAddedScore) { + nodesWithLastScore.clear(); + lastAddedScore = a1.score()[i]; + } + if (nodesWithLastScore.add(a1.node()[i])) { + merged.addInOrder(a1.node()[i], a1.score()[i]); + } i++; } else { - merged.addInOrder(a1.node()[i], a1.score()[i]); - if (a2.node[j] != a1.node()[i]) { + // same score -- add both + if (a1.score()[i] != lastAddedScore) { + nodesWithLastScore.clear(); + lastAddedScore = a1.score()[i]; + } + if (nodesWithLastScore.add(a1.node()[i])) { + merged.addInOrder(a1.node()[i], a1.score()[i]); + } + if (nodesWithLastScore.add(a2.node()[j])) { merged.addInOrder(a2.node[j], a2.score[j]); } i++; @@ -226,25 +255,33 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2 } // If elements remain in a1, add them - while (i < a1.size()) { - // Skip duplicates between the remaining elements in a1 and the last added element in a2 - if (j > 0 && i < a1.size() && a1.node()[i] == a2.node[j - 1]) { + if (i < a1.size()) { + // avoid duplicates while adding nodes with the same score + while (i < a1.size && a1.score()[i] == lastAddedScore) { + if (!nodesWithLastScore.contains(a1.node()[i])) { + merged.addInOrder(a1.node()[i], a1.score()[i]); + } i++; - continue; } - merged.addInOrder(a1.node()[i], a1.score()[i]); - i++; + // the remaining nodes have a different score, so we can bulk-add them + System.arraycopy(a1.node, i, merged.node, merged.size, a1.size - i); + System.arraycopy(a1.score, i, merged.score, merged.size, a1.size - i); + merged.size += a1.size - i; } // If elements remain in a2, add them - while (j < a2.size()) { - // Skip duplicates between the remaining elements in a2 and the last added element in a1 - if (i > 0 && j < a2.size() && a2.node[j] == a1.node()[i - 1]) { + if (j < a2.size()) { + // avoid duplicates while adding nodes with the same score + while (j < a2.size && a2.score[j] == lastAddedScore) { + if (!nodesWithLastScore.contains(a2.node[j])) { + merged.addInOrder(a2.node[j], a2.score[j]); + } j++; - continue; } - merged.addInOrder(a2.node[j], a2.score[j]); - j++; + // the remaining nodes have a different score, so we can bulk-add them + System.arraycopy(a2.node, j, merged.node, merged.size, a2.size - j); + System.arraycopy(a2.score, j, merged.score, merged.size, a2.size - j); + merged.size += a2.size - j; } return merged; @@ -372,6 +409,7 @@ private boolean duplicateExistsNear(int insertionPoint, int newNode, float newSc * This modifies the array in place, preserving the relative order of the elements retained. *

* @param selected A BitSet where the bit at index i is set if the i-th element should be retained. + * (Thus, the elements of selected represent positions in the NeighborArray, NOT node ids.) */ public void retain(BitSet selected) { int writeIdx = 0; // index for where to write the next retained element diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentNeighborSet.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentNeighborSet.java index 2a62d7a1b..5226119dd 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentNeighborSet.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentNeighborSet.java @@ -23,9 +23,12 @@ import org.junit.Test; import java.util.Arrays; +import java.util.HashSet; import java.util.stream.IntStream; -import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class TestConcurrentNeighborSet extends RandomizedTest { private static final NeighborSimilarity simpleScore = a -> { @@ -140,18 +143,28 @@ public void testNoDuplicatesSameScores() { @Test public void testMergeCandidatesSimple() { - NeighborArray arr1 = new NeighborArray(3); + var arr1 = new NeighborArray(1); + arr1.addInOrder(1, 1.0f); + + var arr2 = new NeighborArray(1); + arr2.addInOrder(0, 2.0f); + + var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2); + // Expected result: [0, 1] + assertEquals(2, merged.size()); + assertArrayEquals(new int[] {0, 1}, Arrays.copyOf(merged.node(), 2)); + + arr1 = new NeighborArray(3); arr1.addInOrder(3, 3.0f); arr1.addInOrder(2, 2.0f); arr1.addInOrder(1, 1.0f); - NeighborArray arr2 = new NeighborArray(3); + arr2 = new NeighborArray(3); arr2.addInOrder(4, 4.0f); arr2.addInOrder(2, 2.0f); arr2.addInOrder(1, 1.0f); - NeighborArray merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2); - + merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2); // Expected result: [4, 3, 2, 1] assertEquals(4, merged.size()); assertArrayEquals(new int[] {4, 3, 2, 1}, Arrays.copyOf(merged.node(), 4)); @@ -166,7 +179,6 @@ public void testMergeCandidatesSimple() { arr2.addInOrder(2, 2.0f); merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2); - // Expected result: [3, 2] assertEquals(2, merged.size()); assertArrayEquals(new int[] {3, 2}, Arrays.copyOf(merged.node(), 2)); @@ -175,6 +187,7 @@ public void testMergeCandidatesSimple() { } private void testMergeCandidatesOnce() { + // test merge where one array contains either exact duplicates, or duplicate scores, of the other int maxSize = 1 + getRandom().nextInt(5); NeighborArray arr1 = new NeighborArray(maxSize); @@ -217,11 +230,14 @@ private void testMergeCandidatesOnce() { var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2); assert merged.size <= arr1.size() + arr2.size(); assert merged.size >= Math.max(arr1.size(), arr2.size()); + var uniqueNodes = new HashSet<>(); for (int i = 0; i < merged.size - 1; i++) { - assert merged.score[i] >= merged.score[i + 1]; + assertTrue(merged.score[i] >= merged.score[i + 1]); + assertTrue(uniqueNodes.add(merged.node[i])); } } + @Test public void testMergeCandidatesRandom() { for (int i = 0; i < 10000; i++) { testMergeCandidatesOnce();