Skip to content

Commit

Permalink
fix mergeNeighbors to not add duplicate nodes, and fix test to check for duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Oct 9, 2023
1 parent bfc0fcf commit 7fd7489
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.github.jbellis.jvector.util.DocIdSetIterator;
import io.github.jbellis.jvector.util.FixedBitSet;

import java.util.HashSet;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

Expand Down Expand Up @@ -208,16 +209,44 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2
ConcurrentNeighborArray merged = new ConcurrentNeighborArray(a1.size() + a2.size());
int i = 0, j = 0;

// since nodes are only guaranteed to be sorted by score -- ties can appear in any node order --
// we need to remember all the nodes with the current score to avoid adding duplicates
var nodesWithLastScore = new HashSet<>();
float lastAddedScore = Float.NaN;

// loop through both source arrays, adding the highest score element to the merged array,
// until we reach the end of one of the sources
while (i < a1.size() && j < a2.size()) {
if (a1.score()[i] < a2.score[j]) {
merged.addInOrder(a2.node[j], a2.score[j]);
// add from a2
if (a2.score[j] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a2.score[j];
}
if (nodesWithLastScore.add(a2.node[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
j++;
} else if (a1.score()[i] > a2.score[j]) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
// add from a1
if (a1.score()[i] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a1.score()[i];
}
if (nodesWithLastScore.add(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
i++;
} else {
merged.addInOrder(a1.node()[i], a1.score()[i]);
if (a2.node[j] != a1.node()[i]) {
// same score -- add both
if (a1.score()[i] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a1.score()[i];
}
if (nodesWithLastScore.add(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
if (nodesWithLastScore.add(a2.node()[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
i++;
Expand All @@ -226,25 +255,33 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2
}

// If elements remain in a1, add them
while (i < a1.size()) {
// Skip duplicates between the remaining elements in a1 and the last added element in a2
if (j > 0 && i < a1.size() && a1.node()[i] == a2.node[j - 1]) {
if (i < a1.size()) {
// avoid duplicates while adding nodes with the same score
while (i < a1.size && a1.score()[i] == lastAddedScore) {
if (!nodesWithLastScore.contains(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
i++;
continue;
}
merged.addInOrder(a1.node()[i], a1.score()[i]);
i++;
// the remaining nodes have a different score, so we can bulk-add them
System.arraycopy(a1.node, i, merged.node, merged.size, a1.size - i);
System.arraycopy(a1.score, i, merged.score, merged.size, a1.size - i);
merged.size += a1.size - i;
}

// If elements remain in a2, add them
while (j < a2.size()) {
// Skip duplicates between the remaining elements in a2 and the last added element in a1
if (i > 0 && j < a2.size() && a2.node[j] == a1.node()[i - 1]) {
if (j < a2.size()) {
// avoid duplicates while adding nodes with the same score
while (j < a2.size && a2.score[j] == lastAddedScore) {
if (!nodesWithLastScore.contains(a2.node[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
j++;
continue;
}
merged.addInOrder(a2.node[j], a2.score[j]);
j++;
// the remaining nodes have a different score, so we can bulk-add them
System.arraycopy(a2.node, j, merged.node, merged.size, a2.size - j);
System.arraycopy(a2.score, j, merged.score, merged.size, a2.size - j);
merged.size += a2.size - j;
}

return merged;
Expand Down Expand Up @@ -372,6 +409,7 @@ private boolean duplicateExistsNear(int insertionPoint, int newNode, float newSc
* This modifies the array in place, preserving the relative order of the elements retained.
* <p>
* @param selected A BitSet where the bit at index i is set if the i-th element should be retained.
* (Thus, the elements of selected represent positions in the NeighborArray, NOT node ids.)
*/
public void retain(BitSet selected) {
int writeIdx = 0; // index for where to write the next retained element
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import org.junit.Test;

import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.IntStream;

import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class TestConcurrentNeighborSet extends RandomizedTest {
private static final NeighborSimilarity simpleScore = a -> {
Expand Down Expand Up @@ -140,18 +143,28 @@ public void testNoDuplicatesSameScores() {

@Test
public void testMergeCandidatesSimple() {
NeighborArray arr1 = new NeighborArray(3);
var arr1 = new NeighborArray(1);
arr1.addInOrder(1, 1.0f);

var arr2 = new NeighborArray(1);
arr2.addInOrder(0, 2.0f);

var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);
// Expected result: [0, 1]
assertEquals(2, merged.size());
assertArrayEquals(new int[] {0, 1}, Arrays.copyOf(merged.node(), 2));

arr1 = new NeighborArray(3);
arr1.addInOrder(3, 3.0f);
arr1.addInOrder(2, 2.0f);
arr1.addInOrder(1, 1.0f);

NeighborArray arr2 = new NeighborArray(3);
arr2 = new NeighborArray(3);
arr2.addInOrder(4, 4.0f);
arr2.addInOrder(2, 2.0f);
arr2.addInOrder(1, 1.0f);

NeighborArray merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);

merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);
// Expected result: [4, 3, 2, 1]
assertEquals(4, merged.size());
assertArrayEquals(new int[] {4, 3, 2, 1}, Arrays.copyOf(merged.node(), 4));
Expand All @@ -166,7 +179,6 @@ public void testMergeCandidatesSimple() {
arr2.addInOrder(2, 2.0f);

merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);

// Expected result: [3, 2]
assertEquals(2, merged.size());
assertArrayEquals(new int[] {3, 2}, Arrays.copyOf(merged.node(), 2));
Expand All @@ -175,6 +187,7 @@ public void testMergeCandidatesSimple() {
}

private void testMergeCandidatesOnce() {
// test merge where one array contains either exact duplicates, or duplicate scores, of the other
int maxSize = 1 + getRandom().nextInt(5);

NeighborArray arr1 = new NeighborArray(maxSize);
Expand Down Expand Up @@ -217,11 +230,14 @@ private void testMergeCandidatesOnce() {
var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);
assert merged.size <= arr1.size() + arr2.size();
assert merged.size >= Math.max(arr1.size(), arr2.size());
var uniqueNodes = new HashSet<>();
for (int i = 0; i < merged.size - 1; i++) {
assert merged.score[i] >= merged.score[i + 1];
assertTrue(merged.score[i] >= merged.score[i + 1]);
assertTrue(uniqueNodes.add(merged.node[i]));
}
}

@Test
public void testMergeCandidatesRandom() {
for (int i = 0; i < 10000; i++) {
testMergeCandidatesOnce();
Expand Down

0 comments on commit 7fd7489

Please sign in to comment.