diff --git a/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala b/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala new file mode 100644 index 000000000..315b64102 --- /dev/null +++ b/elastiknn-jmh-benchmarks/src/main/scala/com/klibisz/elastiknn/jmhbenchmarks/KthGreatestBenchmarks.scala @@ -0,0 +1,52 @@ +package com.klibisz.elastiknn.jmhbenchmarks + +import com.klibisz.elastiknn.search.QuickSelect +import org.apache.lucene.search.KthGreatest +import org.openjdk.jmh.annotations._ + +import scala.util.Random + +@State(Scope.Benchmark) +class KthGreatestBenchmarkFixtures { + val rng = new Random(0) + val k = 1000 + val numDocs = 60000 + val shortCounts: Array[Short] = (0 until numDocs).map(_ => rng.nextInt(Short.MaxValue).toShort).toArray + val copy = new Array[Short](shortCounts.length) + val expected = shortCounts.sorted.reverse.apply(k) +} + +class KthGreatestBenchmarks { + + @Benchmark + @BenchmarkMode(Array(Mode.Throughput)) + @Fork(value = 1) + @Warmup(time = 5, iterations = 5) + @Measurement(time = 5, iterations = 5) + def sortBaseline(f: KthGreatestBenchmarkFixtures): Unit = { + val sorted = f.shortCounts.sorted + val actual = sorted.apply(f.shortCounts.length - f.k) + require(actual == f.expected, (actual, f.expected)) + } + + @Benchmark + @BenchmarkMode(Array(Mode.Throughput)) + @Fork(value = 1) + @Warmup(time = 5, iterations = 5) + @Measurement(time = 5, iterations = 5) + def kthGreatest(f: KthGreatestBenchmarkFixtures): Unit = { + val actual = KthGreatest.kthGreatest(f.shortCounts, f.k) + require(actual.kthGreatest == f.expected, (actual.kthGreatest, f.expected)) + } + + @Benchmark + @BenchmarkMode(Array(Mode.Throughput)) + @Fork(value = 1) + @Warmup(time = 5, iterations = 5) + @Measurement(time = 5, iterations = 5) + def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = { + System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length) + val actual = QuickSelect.selectRecursive(f.copy, f.k) + require(actual == f.expected, (actual, f.expected)) + } +} diff --git a/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java new file mode 100644 index 000000000..02d4d3455 --- /dev/null +++ b/elastiknn-lucene/src/main/java/com/klibisz/elastiknn/search/QuickSelect.java @@ -0,0 +1,53 @@ +package com.klibisz.elastiknn.search; + +import java.util.Random; + +/** + * Based on https://gist.github.com/unnikked/14c19ba13f6a4bfd00a3 + */ +public class QuickSelect { + + private static final Random rng = new Random(0); + + public static short selectRecursive(short[] array, int n) { + return recursive(array, 0, array.length - 1, n); + } + + private static short recursive(short[] array, int left, int right, int k) { + if (left == right) { // If the list contains only one element, + return array[left]; // return that element + } + + // select a pivotIndex between left and right + int pivotIndex = left + rng.nextInt(right - left); + pivotIndex = partition(array, left, right, pivotIndex); + // The pivot is in its final sorted position + if (k == pivotIndex) { + return array[k]; + } else if (k < pivotIndex) { + return recursive(array, left, pivotIndex - 1, k); + } else { + return recursive(array, pivotIndex + 1, right, k); + } + } + + private static int partition(short[] array, int left, int right, int pivotIndex) { + int pivotValue = array[pivotIndex]; + swap(array, pivotIndex, right); // move pivot to end + int storeIndex = left; + for(int i = left; i < right; i++) { + if(array[i] > pivotValue) { + swap(array, storeIndex, i); + storeIndex++; + } + } + swap(array, right, storeIndex); // Move pivot to its final place + return storeIndex; + } + + private static void swap(short[] array, int a, int b) { + short tmp = array[a]; + array[a] = array[b]; + array[b] = tmp; + } +} diff --git a/elastiknn-lucene/src/main/java/org/apache/lucene/search/MatchHashesAndScoreQuery.java b/elastiknn-lucene/src/main/java/org/apache/lucene/search/MatchHashesAndScoreQuery.java index 5132f10d0..9f21b6060 100644 --- a/elastiknn-lucene/src/main/java/org/apache/lucene/search/MatchHashesAndScoreQuery.java +++ b/elastiknn-lucene/src/main/java/org/apache/lucene/search/MatchHashesAndScoreQuery.java @@ -79,12 +79,6 @@ private HitCounter countHits(LeafReader reader) throws IOException { } private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) { - // TODO: Add back this logging once log4j mess has settled. -// if (counter.numHits() < candidates) { -// logger.warn(String.format( -// "Found fewer approximate matches [%d] than the requested number of candidates [%d]", -// counter.numHits(), candidates)); -// } if (counter.isEmpty()) return DocIdSetIterator.empty(); else { @@ -114,18 +108,22 @@ public int nextDoc() { docID = counter.minKey() - 1; } - // Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer - // than `candidates` docs with count > kgr.kthGreatest. while (true) { + // We've emitted `candidates` docs or the next doc would exceed the max key, + // so we stop emitting. if (numEmitted == candidates || docID + 1 > counter.maxKey()) { docID = DocIdSetIterator.NO_MORE_DOCS; return docID(); } else { + // Increment and check the next doc. docID++; if (counter.get(docID) > kgr.kthGreatest) { numEmitted++; return docID(); - } else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) { + } + // Ensure that docs with count = kgr.kthGreatest are only emitted when there are + // fewer than `candidates` docs with count > kgr.kthGreatest. + else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) { numEq++; numEmitted++; return docID();