Skip to content

Commit 1f8b421

Browse files
authored
Merge branch 'branch-23.12' into 23.10-ann-deletion
2 parents a63be75 + bafd2a8 commit 1f8b421

File tree

4 files changed

+27
-13
lines changed

4 files changed

+27
-13
lines changed

cpp/bench/ann/src/common/benchmark.hpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <benchmark/benchmark.h>
2424

2525
#include <algorithm>
26+
#include <atomic>
2627
#include <chrono>
2728
#include <cmath>
2829
#include <condition_variable>
@@ -39,6 +40,7 @@ namespace raft::bench::ann {
3940

4041
std::mutex init_mutex;
4142
std::condition_variable cond_var;
43+
std::atomic_int processed_threads{0};
4244

4345
static inline std::unique_ptr<AnnBase> current_algo{nullptr};
4446
static inline std::shared_ptr<AlgoProperty> current_algo_props{nullptr};
@@ -198,7 +200,8 @@ void bench_search(::benchmark::State& state,
198200
* Make sure the first thread loads the algo and dataset
199201
*/
200202
if (state.thread_index() == 0) {
201-
std::lock_guard lk(init_mutex);
203+
std::unique_lock lk(init_mutex);
204+
cond_var.wait(lk, [] { return processed_threads.load(std::memory_order_acquire) == 0; });
202205
// algo is static to cache it between close search runs to save time on index loading
203206
static std::string index_file = "";
204207
if (index.file != index_file) {
@@ -247,11 +250,14 @@ void bench_search(::benchmark::State& state,
247250
}
248251

249252
query_set = dataset->query_set(current_algo_props->query_memory_type);
253+
processed_threads.store(state.threads(), std::memory_order_acq_rel);
250254
cond_var.notify_all();
251255
} else {
252-
// All other threads will wait for the first thread to initialize the algo.
253256
std::unique_lock lk(init_mutex);
254-
cond_var.wait(lk, [] { return current_algo_props.get() != nullptr; });
257+
// All other threads will wait for the first thread to initialize the algo.
258+
cond_var.wait(lk, [&state] {
259+
return processed_threads.load(std::memory_order_acquire) == state.threads();
260+
});
255261
// gbench ensures that all threads are synchronized at the start of the benchmark loop.
256262
// We are accessing shared variables (like current_algo, current_algo_props) before the
257263
// benchmark loop, therefore the synchronization here is necessary.
@@ -292,6 +298,7 @@ void bench_search(::benchmark::State& state,
292298

293299
// advance to the next batch
294300
batch_offset = (batch_offset + n_queries) % query_set_size;
301+
295302
queries_processed += n_queries;
296303
}
297304
}
@@ -312,6 +319,10 @@ void bench_search(::benchmark::State& state,
312319

313320
if (state.skipped()) { return; }
314321

322+
// assume thread has finished processing successfully at this point
323+
// last thread to finish processing notifies all
324+
if (processed_threads-- == 0) { cond_var.notify_all(); }
325+
315326
// Use the last thread as a sanity check that all the threads are working.
316327
if (state.thread_index() == state.threads() - 1) {
317328
// evaluate recall
@@ -410,7 +421,6 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
410421
auto* b = ::benchmark::RegisterBenchmark(
411422
index.name + suf, bench_search<T>, index, i, dataset, metric_objective)
412423
->Unit(benchmark::kMillisecond)
413-
->ThreadRange(threads[0], threads[1])
414424
/**
415425
* The following are important for getting accurate QPS measurements on both CPU
416426
* and GPU. These make sure that
@@ -420,6 +430,8 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
420430
*/
421431
->MeasureProcessCPUTime()
422432
->UseRealTime();
433+
434+
if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(threads[0], threads[1]); }
423435
}
424436
}
425437
}

cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
147147
char buf[20];
148148
std::time_t now = std::time(nullptr);
149149
std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now));
150-
151150
printf("%s building %zu / %zu\n", buf, i, items_per_thread);
152151
fflush(stdout);
153152
}
@@ -163,13 +162,11 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)
163162
auto param = dynamic_cast<const SearchParam&>(param_);
164163
appr_alg_->ef_ = param.ef;
165164
metric_objective_ = param.metric_objective;
165+
num_threads_ = param.num_threads;
166166

167-
bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) &&
168-
(!thread_pool_ || num_threads_ != param.num_threads);
169-
if (use_pool) {
170-
num_threads_ = param.num_threads;
171-
thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
172-
}
167+
// Create a pool if multiple query threads have been set and the pool hasn't been created already
168+
bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_);
169+
if (create_pool) { thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_); }
173170
}
174171

175172
template <typename T>
@@ -180,7 +177,7 @@ void HnswLib<T>::search(
180177
// hnsw can only handle a single vector at a time.
181178
get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k);
182179
};
183-
if (metric_objective_ == Objective::LATENCY) {
180+
if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) {
184181
thread_pool_->submit(f, batch_size);
185182
} else {
186183
for (int i = 0; i < batch_size; i++) {

cpp/include/raft/core/detail/copy.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#pragma once
1818
#include <cstdio>
19-
#include <execution>
2019
#include <raft/core/cuda_support.hpp>
2120
#include <raft/core/device_mdspan.hpp>
2221
#include <raft/core/error.hpp>

python/raft-ann-bench/src/raft-ann-bench/run/conf/datasets.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,42 @@
3939
dims: 960
4040
base_file: gist-960-euclidean/base.fbin
4141
query_file: gist-960-euclidean/query.fbin
42+
groundtruth_neighbors_file: gist-960-euclidean/groundtruth.neighbors.ibin
4243
distance: euclidean
4344

4445
- name: glove-50-angular
4546
dims: 50
4647
base_file: glove-50-angular/base.fbin
4748
query_file: glove-50-angular/query.fbin
49+
groundtruth_neighbors_file: glove-50-angular/groundtruth.neighbors.ibin
4850
distance: euclidean
4951

5052
- name: glove-50-inner
5153
dims: 50
5254
base_file: glove-50-inner/base.fbin
5355
query_file: glove-50-inner/query.fbin
56+
groundtruth_neighbors_file: glove-50-inner/groundtruth.neighbors.ibin
5457
distance: euclidean
5558

5659
- name: glove-100-angular
5760
dims: 100
5861
base_file: glove-100-angular/base.fbin
5962
query_file: glove-100-angular/query.fbin
63+
groundtruth_neighbors_file: glove-100-angular/groundtruth.neighbors.ibin
6064
distance: euclidean
6165

6266
- name: glove-100-inner
6367
dims: 100
6468
base_file: glove-100-inner/base.fbin
6569
query_file: glove-100-inner/query.fbin
70+
groundtruth_neighbors_file: glove-100-inner/groundtruth.neighbors.ibin
6671
distance: euclidean
6772

6873
- name: lastfm-65-angular
6974
dims: 65
7075
base_file: lastfm-65-angular/base.fbin
7176
query_file: lastfm-65-angular/query.fbin
77+
groundtruth_neighbors_file: lastfm-65-angular/groundtruth.neighbors.ibin
7278
distance: euclidean
7379

7480
- name: mnist-784-euclidean

0 commit comments

Comments
 (0)