Try quick select algorithm for KthGreatest implementation #543

Closed
2 tasks
alexklibisz opened this issue Jul 28, 2023 · 4 comments


@alexklibisz
Owner

Background

I think there could be an opportunity to speed up approximate queries by re-implementing the kthGreatest method using the quick select algorithm.

At a high level, the kthGreatest method finds the kth greatest document frequency. We give it an array of counts, each one representing the number of times a distinct document was matched against a set of query terms, and it returns the kth greatest count. Then we perform exact similarity scoring on each document whose count matches or exceeds this kth greatest count.
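A straightforward way to compute this is to sort a copy of the counts and read off the kth greatest value, then keep every document at or above it. The minimal sketch below assumes k is 1-based (k = 1 is the maximum); the class and method names are hypothetical and this is not the elastiknn implementation.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical sketch: names and signatures are illustrative, not elastiknn's API.
class KthGreatestSketch {

    // Returns the kth greatest value in counts (k is 1-based: k = 1 is the maximum).
    // Sorts a copy so the caller's counts stay untouched: O(n log n).
    static short kthGreatestBySort(short[] counts, int k) {
        if (k < 1 || k > counts.length) {
            throw new IllegalArgumentException("k must be in [1, counts.length]");
        }
        short[] copy = Arrays.copyOf(counts, counts.length);
        Arrays.sort(copy); // ascending order
        return copy[copy.length - k];
    }

    // Returns the indices (document ids) whose count is >= the kth greatest count.
    // Ties at the threshold are all kept, so this can return more than k indices.
    static int[] candidates(short[] counts, int k) {
        short threshold = kthGreatestBySort(counts, k);
        return IntStream.range(0, counts.length)
                .filter(i -> counts[i] >= threshold)
                .toArray();
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 7, 3, 1, 5, 0, 3};
        System.out.println(kthGreatestBySort(counts, 3));          // prints 3
        System.out.println(Arrays.toString(candidates(counts, 3))); // prints [0, 2, 3, 5, 7]
    }
}
```

Because every tie at the threshold is kept, the candidate set can be larger than k: here k = 3 yields five documents.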

There are some good example implementations of quickselect on LeetCode:

Deliverables

  • Implement and benchmark kthGreatest method using quick select
  • Report the results on this ticket, or in a PR if it's good enough to merge

Related Issues

Blocked by #525

@alexklibisz
Owner Author

alexklibisz commented Nov 26, 2023

I've partially implemented this in #603. I based much of the quickselect implementation on this excellent gist: https://gist.github.com/unnikked/14c19ba13f6a4bfd00a3

My latest iteration at the time of writing is here:

package com.klibisz.elastiknn.search;
public class QuickSelect {
    // Returns the element that would sit at index n if the array were sorted in
    // descending order (n is 0-based, so n = 0 returns the maximum).
    // Note: this reorders the contents of array in place.
    public static short selectRecursive(short[] array, int n) {
        return recursive(array, 0, array.length - 1, n);
    }
    private static short recursive(short[] array, int left, int right, int k) {
        if (left == right) { // If the list contains only one element,
            return array[left]; // return that element
        }
        // select a pivotIndex between left and right (fixed: the middle element)
        int pivotIndex = left + (right - left) / 2;
        pivotIndex = partition(array, left, right, pivotIndex);
        // The pivot is in its final sorted position
        if (k == pivotIndex) {
            return array[k];
        } else if (k < pivotIndex) {
            return recursive(array, left, pivotIndex - 1, k);
        } else {
            return recursive(array, pivotIndex + 1, right, k);
        }
    }
    // Lomuto partition: moves every element greater than the pivot value to the
    // front of [left, right], then places the pivot right after them.
    private static int partition(short[] array, int left, int right, int pivotIndex) {
        int pivotValue = array[pivotIndex];
        swap(array, pivotIndex, right); // move pivot to end
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (array[i] > pivotValue) {
                swap(array, storeIndex, i);
                storeIndex++;
            }
        }
        swap(array, right, storeIndex); // Move pivot to its final place
        return storeIndex;
    }
    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }
}

The benchmark is here:

@Benchmark
@BenchmarkMode(Array(Mode.Throughput))
@Fork(value = 1)
@Warmup(time = 5, iterations = 5)
@Measurement(time = 5, iterations = 5)
def unnikedRecursive(f: KthGreatestBenchmarkFixtures): Unit = {
  // Copy first, since QuickSelect reorders the array it is given.
  System.arraycopy(f.shortCounts, 0, f.copy, 0, f.copy.length)
  QuickSelect.selectRecursive(f.copy, f.k)
  ()
}

Unfortunately, this particular implementation of quickselect is actually slower than just sorting (about 2,170 vs. 2,970 ops/s below), and roughly 5x slower than the existing kthGreatest. I suspect much of this comes from the fact that I have to make a full copy of the array on every invocation. The copy is necessary because quickselect modifies the array (swapping values around) to compute its result, whereas the ArrayHitCounter expects those values to be immutable.

[info] Benchmark                                Mode  Cnt      Score    Error  Units
[info] KthGreatestBenchmarks.kthGreatest       thrpt    5  10796.563 ±  0.931  ops/s
[info] KthGreatestBenchmarks.sortBaseline      thrpt    5   2965.035 ± 85.854  ops/s
[info] KthGreatestBenchmarks.unnikedRecursive  thrpt    5   2171.902 ± 40.547  ops/s
[success] Total time: 153 s (02:33), completed Nov 26, 2023, 11:35:28 PM
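For contrast, one way to find the kth greatest count without mutating (and therefore without copying) the array is a counting pass. The sketch below assumes hit counts are non-negative and bounded by a known small maxCount (for example, the number of query terms); a count above maxCount would throw. It takes k as 1-based, uses hypothetical names, and is not claimed to be how the existing kthGreatest works.

```java
// Hypothetical sketch: assumes 0 <= counts[i] <= maxCount for every i.
// Reads counts without modifying them, in O(n + maxCount) time.
class HistogramKthGreatest {

    // k is 1-based: k = 1 returns the maximum count.
    static short kthGreatest(short[] counts, int k, int maxCount) {
        if (k < 1 || k > counts.length) {
            throw new IllegalArgumentException("k must be in [1, counts.length]");
        }
        // hist[c] = number of documents whose count equals c.
        int[] hist = new int[maxCount + 1];
        for (short c : counts) {
            hist[c]++;
        }
        // Walk from the largest count downward until at least k documents are covered.
        int seen = 0;
        for (int c = maxCount; c >= 0; c--) {
            seen += hist[c];
            if (seen >= k) {
                return (short) c;
            }
        }
        throw new IllegalStateException("unreachable when 1 <= k <= counts.length");
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 7, 3, 1, 5, 0, 3};
        System.out.println(kthGreatest(counts, 3, 7)); // prints 3; counts is unchanged
    }
}
```

The trade-off is an extra int array of size maxCount + 1, which stays tiny as long as hit counts are small.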

@alexklibisz
Owner Author

alexklibisz commented Nov 26, 2023

Quickselect is about 30% faster (2,788 vs. 2,172 ops/s) when I switch from a fixed middle pivot to a random pivot (the pivotIndex assignment in recursive):

package com.klibisz.elastiknn.search;
import java.util.Random;
public class QuickSelect {
    // Fixed seed so benchmark runs are reproducible.
    private static final Random rng = new Random(0);
    // Returns the element at 0-based index n of the descending order; reorders array in place.
    public static short selectRecursive(short[] array, int n) {
        return recursive(array, 0, array.length - 1, n);
    }
    private static short recursive(short[] array, int left, int right, int k) {
        if (left == right) { // If the list contains only one element,
            return array[left]; // return that element
        }
        // Pick a random pivot. nextInt's bound is exclusive, so this draws from
        // [left, right - 1]; any index in the range is a valid pivot.
        int pivotIndex = left + rng.nextInt(right - left);
        pivotIndex = partition(array, left, right, pivotIndex);
        // The pivot is in its final sorted position
        if (k == pivotIndex) {
            return array[k];
        } else if (k < pivotIndex) {
            return recursive(array, left, pivotIndex - 1, k);
        } else {
            return recursive(array, pivotIndex + 1, right, k);
        }
    }
    // Lomuto partition: elements greater than the pivot value go to the front.
    private static int partition(short[] array, int left, int right, int pivotIndex) {
        int pivotValue = array[pivotIndex];
        swap(array, pivotIndex, right); // move pivot to end
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (array[i] > pivotValue) {
                swap(array, storeIndex, i);
                storeIndex++;
            }
        }
        swap(array, right, storeIndex); // Move pivot to its final place
        return storeIndex;
    }
    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }
}

[info] Benchmark                                Mode  Cnt     Score    Error  Units
[info] KthGreatestBenchmarks.unnikedRecursive  thrpt    5  2788.150 ± 16.839  ops/s
[success] Total time: 51 s, completed Nov 26, 2023, 11:48:26 PM

But it's still nowhere near kthGreatest (about 2,790 vs. 10,800 ops/s in the earlier run).
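Hit counts are small integers, so ties are likely (for example, many documents matched by exactly one query term). The partition above uses a strict > comparison, so every element equal to the pivot lands on the same side, and with heavy ties each step can shrink the range by very little. A three-way (Dutch national flag) partition groups all pivot-equal values together and stops as soon as k falls inside that group. The sketch below is a hypothetical variant that was not benchmarked in this thread; whether ties explain any of the gap is an open question. Like selectRecursive, it mutates its input and treats k as a 0-based index into descending order.

```java
import java.util.Random;

// Hypothetical variant: iterative quickselect with a three-way partition.
// Mutates array; k is a 0-based index into descending order (k = 0 is the maximum).
class QuickSelect3Way {
    private static final Random RNG = new Random(0);

    static short select(short[] array, int k) {
        if (k < 0 || k >= array.length) {
            throw new IllegalArgumentException("k must be in [0, array.length)");
        }
        int left = 0;
        int right = array.length - 1;
        while (true) {
            // Uniform random pivot value from [left, right] inclusive.
            short pivot = array[left + RNG.nextInt(right - left + 1)];
            // Afterwards: [left, lt) > pivot, [lt, gt] == pivot, (gt, right] < pivot.
            int lt = left;
            int gt = right;
            int i = left;
            while (i <= gt) {
                if (array[i] > pivot) {
                    swap(array, lt, i);
                    lt++;
                    i++;
                } else if (array[i] < pivot) {
                    swap(array, i, gt);
                    gt--;
                } else {
                    i++;
                }
            }
            if (k < lt) {
                right = lt - 1;
            } else if (k > gt) {
                left = gt + 1;
            } else {
                return pivot; // k falls inside the block of values equal to the pivot
            }
        }
    }

    private static void swap(short[] array, int a, int b) {
        short tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }

    public static void main(String[] args) {
        short[] counts = {3, 0, 7, 3, 1, 5, 0, 3};
        // Copy first, because select reorders its input.
        System.out.println(select(counts.clone(), 2)); // prints 3
    }
}
```

With all-equal input the first partition already puts every element in the equal block, so it returns after one pass instead of shrinking the range one element at a time.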

@alexklibisz
Owner Author

This feels similar to using hashmaps instead of arrays to count hits, summarized in this comment: #160 (comment)

Right now I'm benchmarking with a dataset of 60k vectors (Fashion-MNIST). Optimizations like quickselect and primitive hashmaps might have a positive impact when I'm working with far more vectors, but Fashion-MNIST is the benchmark I'm optimizing for now.
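To make the array-vs-hashmap trade-off concrete, here is a minimal sketch with hypothetical names of the two counting strategies: a dense array indexed by document id versus a map holding only the matched ids. It assumes document ids are dense integers in [0, numDocs). With 60k documents the dense array is only 60k shorts (about 120 KB), which is one plausible reason arrays win at this scale; a sparse map would only pay off when the index is large relative to the number of matched documents. java.util.HashMap stands in here for a primitive-specialized map, which would need a third-party library.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch contrasting two ways to count hits per document id.
class HitCountingSketch {

    // Dense: one slot per document in the index. Memory grows with numDocs,
    // but each increment is a plain array write with no hashing or boxing.
    static short[] countDense(int[] matchedDocIds, int numDocs) {
        short[] counts = new short[numDocs];
        for (int doc : matchedDocIds) {
            counts[doc]++;
        }
        return counts;
    }

    // Sparse: only documents that were actually matched get an entry.
    // java.util.HashMap boxes keys and values, which a primitive map would avoid.
    static Map<Integer, Integer> countSparse(int[] matchedDocIds) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int doc : matchedDocIds) {
            counts.merge(doc, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] matches = {4, 1, 4, 9, 4, 1};
        short[] dense = countDense(matches, 10);
        System.out.println(dense[4] + " " + dense[1] + " " + dense[9]); // prints 3 2 1
        System.out.println(countSparse(matches).get(4));                // prints 3
    }
}
```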

@alexklibisz
Owner Author

Closing this for now. Might re-open if/when I'm benchmarking on a larger dataset.
