Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance: use quick select algorithm to compute kth largest document count #603

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.klibisz.elastiknn.jmhbenchmarks

import com.klibisz.elastiknn.search.QuickSelect
import org.apache.lucene.search.KthGreatest
import org.openjdk.jmh.annotations._

import scala.util.Random

@State(Scope.Benchmark)
class KthGreatestBenchmarkFixtures {
val rng = new Random(0)
val k = 1000
val numDocs = 60000
val shortCounts: Array[Short] = (0 until numDocs).map(_ => rng.nextInt(Short.MaxValue).toShort).toArray
val copy = new Array[Short](shortCounts.length)
val expected = shortCounts.sorted.reverse.apply(k)
}

class KthGreatestBenchmarks {

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def sortBaseline(f: KthGreatestBenchmarkFixtures): Unit = {
val sorted = f.shortCounts.sorted
val actual = sorted.apply(f.shortCounts.length - f.k)
require(actual == f.expected, (actual, f.expected))
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def kthGreatest(f: KthGreatestBenchmarkFixtures): Unit = {
val actual = KthGreatest.kthGreatest(f.shortCounts, f.k)
require(actual.kthGreatest == f.expected, (actual.kthGreatest, f.expected))
}

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = {
System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length)
val actual = QuickSelect.selectRecursive(f.copy, f.k)
require(actual == f.expected, (actual, f.expected))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.klibisz.elastiknn.search;

import java.util.Random;

/**
* Based on https://gist.github.com/unnikked/14c19ba13f6a4bfd00a3
*/
public class QuickSelect {

private static final Random rng = new Random(0);

public static short selectRecursive(short[] array, int n) {
return recursive(array, 0, array.length - 1, n);
}

private static short recursive(short[] array, int left, int right, int k) {
if (left == right) { // If the list contains only one element,
return array[left]; // return that element
}

// select a pivotIndex between left and right
int pivotIndex = left + rng.nextInt(right - left);
pivotIndex = partition(array, left, right, pivotIndex);
// The pivot is in its final sorted position
if (k == pivotIndex) {
return array[k];
} else if (k < pivotIndex) {
return recursive(array, left, pivotIndex - 1, k);
} else {
return recursive(array, pivotIndex + 1, right, k);
}
}

private static int partition(short[] array, int left, int right, int pivotIndex) {
int pivotValue = array[pivotIndex];
swap(array, pivotIndex, right); // move pivot to end
int storeIndex = left;
for(int i = left; i < right; i++) {
if(array[i] > pivotValue) {
swap(array, storeIndex, i);
storeIndex++;
}
}
swap(array, right, storeIndex); // Move pivot to its final place
return storeIndex;
}

private static void swap(short[] array, int a, int b) {
short tmp = array[a];
array[a] = array[b];
array[b] = tmp;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ private HitCounter countHits(LeafReader reader) throws IOException {
}

private DocIdSetIterator buildDocIdSetIterator(HitCounter counter) {
// TODO: Add back this logging once log4j mess has settled.
// if (counter.numHits() < candidates) {
// logger.warn(String.format(
// "Found fewer approximate matches [%d] than the requested number of candidates [%d]",
// counter.numHits(), candidates));
// }
if (counter.isEmpty()) return DocIdSetIterator.empty();
else {

Expand Down Expand Up @@ -114,18 +108,22 @@ public int nextDoc() {
docID = counter.minKey() - 1;
}

// Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
// than `candidates` docs with count > kgr.kthGreatest.
while (true) {
// We've emitted `candidates` docs or the next doc would exceed the max key,
// so we stop emitting.
if (numEmitted == candidates || docID + 1 > counter.maxKey()) {
docID = DocIdSetIterator.NO_MORE_DOCS;
return docID();
} else {
// Increment and check the next doc.
docID++;
if (counter.get(docID) > kgr.kthGreatest) {
numEmitted++;
return docID();
} else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
}
// Ensure that docs with count = kgr.kthGreatest are only emitted when there are
// fewer than `candidates` docs with count > kgr.kthGreatest.
else if (counter.get(docID) == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
numEq++;
numEmitted++;
return docID();
Expand Down