Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mergeNeighbors to not add duplicate nodes, and fix test to check for duplicates #119

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.github.jbellis.jvector.util.DocIdSetIterator;
import io.github.jbellis.jvector.util.FixedBitSet;

import java.util.HashSet;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

Expand Down Expand Up @@ -208,16 +209,44 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2
ConcurrentNeighborArray merged = new ConcurrentNeighborArray(a1.size() + a2.size());
int i = 0, j = 0;

// since nodes are only guaranteed to be sorted by score -- ties can appear in any node order --
// we need to remember all the nodes with the current score to avoid adding duplicates
var nodesWithLastScore = new HashSet<>();
float lastAddedScore = Float.NaN;

// loop through both source arrays, adding the highest score element to the merged array,
// until we reach the end of one of the sources
while (i < a1.size() && j < a2.size()) {
if (a1.score()[i] < a2.score[j]) {
merged.addInOrder(a2.node[j], a2.score[j]);
// add from a2
if (a2.score[j] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a2.score[j];
}
if (nodesWithLastScore.add(a2.node[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
j++;
} else if (a1.score()[i] > a2.score[j]) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
// add from a1
if (a1.score()[i] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a1.score()[i];
}
if (nodesWithLastScore.add(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
i++;
} else {
merged.addInOrder(a1.node()[i], a1.score()[i]);
if (a2.node[j] != a1.node()[i]) {
// same score -- add both
if (a1.score()[i] != lastAddedScore) {
nodesWithLastScore.clear();
lastAddedScore = a1.score()[i];
}
if (nodesWithLastScore.add(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
if (nodesWithLastScore.add(a2.node()[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
i++;
Expand All @@ -226,25 +255,33 @@ static ConcurrentNeighborArray mergeNeighbors(NeighborArray a1, NeighborArray a2
}

// If elements remain in a1, add them
while (i < a1.size()) {
// Skip duplicates between the remaining elements in a1 and the last added element in a2
if (j > 0 && i < a1.size() && a1.node()[i] == a2.node[j - 1]) {
if (i < a1.size()) {
// avoid duplicates while adding nodes with the same score
while (i < a1.size && a1.score()[i] == lastAddedScore) {
if (!nodesWithLastScore.contains(a1.node()[i])) {
merged.addInOrder(a1.node()[i], a1.score()[i]);
}
i++;
continue;
}
merged.addInOrder(a1.node()[i], a1.score()[i]);
i++;
// the remaining nodes have a different score, so we can bulk-add them
System.arraycopy(a1.node, i, merged.node, merged.size, a1.size - i);
System.arraycopy(a1.score, i, merged.score, merged.size, a1.size - i);
merged.size += a1.size - i;
}

// If elements remain in a2, add them
while (j < a2.size()) {
// Skip duplicates between the remaining elements in a2 and the last added element in a1
if (i > 0 && j < a2.size() && a2.node[j] == a1.node()[i - 1]) {
if (j < a2.size()) {
// avoid duplicates while adding nodes with the same score
while (j < a2.size && a2.score[j] == lastAddedScore) {
if (!nodesWithLastScore.contains(a2.node[j])) {
merged.addInOrder(a2.node[j], a2.score[j]);
}
j++;
continue;
}
merged.addInOrder(a2.node[j], a2.score[j]);
j++;
// the remaining nodes have a different score, so we can bulk-add them
System.arraycopy(a2.node, j, merged.node, merged.size, a2.size - j);
System.arraycopy(a2.score, j, merged.score, merged.size, a2.size - j);
merged.size += a2.size - j;
}

return merged;
Expand Down Expand Up @@ -372,6 +409,7 @@ private boolean duplicateExistsNear(int insertionPoint, int newNode, float newSc
* This modifies the array in place, preserving the relative order of the elements retained.
* <p>
* @param selected A BitSet where the bit at index i is set if the i-th element should be retained.
* (Thus, the elements of selected represent positions in the NeighborArray, NOT node ids.)
*/
public void retain(BitSet selected) {
int writeIdx = 0; // index for where to write the next retained element
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import org.junit.Test;

import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.IntStream;

import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class TestConcurrentNeighborSet extends RandomizedTest {
private static final NeighborSimilarity simpleScore = a -> {
Expand Down Expand Up @@ -140,18 +143,28 @@ public void testNoDuplicatesSameScores() {

@Test
public void testMergeCandidatesSimple() {
NeighborArray arr1 = new NeighborArray(3);
var arr1 = new NeighborArray(1);
arr1.addInOrder(1, 1.0f);

var arr2 = new NeighborArray(1);
arr2.addInOrder(0, 2.0f);

var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);
// Expected result: [0, 1]
assertEquals(2, merged.size());
assertArrayEquals(new int[] {0, 1}, Arrays.copyOf(merged.node(), 2));

arr1 = new NeighborArray(3);
arr1.addInOrder(3, 3.0f);
arr1.addInOrder(2, 2.0f);
arr1.addInOrder(1, 1.0f);

NeighborArray arr2 = new NeighborArray(3);
arr2 = new NeighborArray(3);
arr2.addInOrder(4, 4.0f);
arr2.addInOrder(2, 2.0f);
arr2.addInOrder(1, 1.0f);

NeighborArray merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);

merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);
// Expected result: [4, 3, 2, 1]
assertEquals(4, merged.size());
assertArrayEquals(new int[] {4, 3, 2, 1}, Arrays.copyOf(merged.node(), 4));
Expand All @@ -166,7 +179,6 @@ public void testMergeCandidatesSimple() {
arr2.addInOrder(2, 2.0f);

merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);

// Expected result: [3, 2]
assertEquals(2, merged.size());
assertArrayEquals(new int[] {3, 2}, Arrays.copyOf(merged.node(), 2));
Expand All @@ -175,28 +187,24 @@ public void testMergeCandidatesSimple() {
}

private void testMergeCandidatesOnce() {
// test merge emphasizing dealing with tied scores
int maxSize = 1 + getRandom().nextInt(5);

// fill arr1 with nodes from 0..size, with random scores assigned (so random order of nodes)
NeighborArray arr1 = new NeighborArray(maxSize);
int a1Size;
if (getRandom().nextBoolean()) {
a1Size = maxSize;
} else {
a1Size = 1 + getRandom().nextInt(maxSize);
}
int a1Size = getRandom().nextBoolean() ? maxSize : 1 + getRandom().nextInt(maxSize);
for (int i = 0; i < a1Size; i++) {
arr1.insertSorted(i, getRandom().nextFloat());
}

// arr2 contains either
// -- an exact duplicate of the corresponding node in arr1, or
// -- a random score chosen from arr1
// this is designed to maximize the need for correct handling of corner cases in the merge
NeighborArray arr2 = new NeighborArray(maxSize);
int a2Size;
if (getRandom().nextBoolean()) {
a2Size = maxSize;
} else {
a2Size = 1 + getRandom().nextInt(maxSize);
}
int a2Size = getRandom().nextBoolean() ? maxSize : 1 + getRandom().nextInt(maxSize);
for (int i = 0; i < a2Size; i++) {
if (getRandom().nextBoolean()) {
if (i < a1Size && getRandom().nextBoolean()) {
// duplicate entry
int j = getRandom().nextInt(a1Size);
if (!arr2.contains(arr1.node[j])) {
Expand All @@ -214,14 +222,41 @@ private void testMergeCandidatesOnce() {
}
}

// merge!
var merged = ConcurrentNeighborSet.mergeNeighbors(arr1, arr2);

// sanity check
assert merged.size <= arr1.size() + arr2.size();
assert merged.size >= Math.max(arr1.size(), arr2.size());
var uniqueNodes = new HashSet<>();

// results should be sorted by score, and not contain duplicates
for (int i = 0; i < merged.size - 1; i++) {
assert merged.score[i] >= merged.score[i + 1];
assertTrue(merged.score[i] >= merged.score[i + 1]);
assertTrue(uniqueNodes.add(merged.node[i]));
}
assertTrue(uniqueNodes.add(merged.node[merged.size - 1]));

// results should contain all the nodes that were in the source arrays
for (int i = 0; i < arr1.size(); i++) {
assertTrue(String.format("%s missing%na1: %s%na2: %s%nmerged: %s%n",
arr1.node[i],
Arrays.toString(arr1.node),
Arrays.toString(arr2.node),
Arrays.toString(merged.node)),
uniqueNodes.contains(arr1.node[i]));
}
for (int i = 0; i < arr2.size(); i++) {
assertTrue(String.format("%s missing%na1: %s%na2: %s%nmerged: %s%n",
arr2.node[i],
Arrays.toString(arr1.node),
Arrays.toString(arr2.node),
Arrays.toString(merged.node)),
uniqueNodes.contains(arr2.node[i]));
}
}

@Test
public void testMergeCandidatesRandom() {
for (int i = 0; i < 10000; i++) {
testMergeCandidatesOnce();
Expand Down
Loading