23
23
#include < benchmark/benchmark.h>
24
24
25
25
#include < algorithm>
26
+ #include < atomic>
26
27
#include < chrono>
27
28
#include < cmath>
28
29
#include < condition_variable>
@@ -39,6 +40,7 @@ namespace raft::bench::ann {
39
40
40
41
std::mutex init_mutex;
41
42
std::condition_variable cond_var;
43
+ std::atomic_int processed_threads{0 };
42
44
43
45
static inline std::unique_ptr<AnnBase> current_algo{nullptr };
44
46
static inline std::shared_ptr<AlgoProperty> current_algo_props{nullptr };
@@ -198,7 +200,8 @@ void bench_search(::benchmark::State& state,
198
200
* Make sure the first thread loads the algo and dataset
199
201
*/
200
202
if (state.thread_index () == 0 ) {
201
- std::lock_guard lk (init_mutex);
203
+ std::unique_lock lk (init_mutex);
204
+ cond_var.wait (lk, [] { return processed_threads.load (std::memory_order_acquire) == 0 ; });
202
205
// algo is static to cache it between close search runs to save time on index loading
203
206
static std::string index_file = " " ;
204
207
if (index.file != index_file) {
@@ -247,11 +250,14 @@ void bench_search(::benchmark::State& state,
247
250
}
248
251
249
252
query_set = dataset->query_set (current_algo_props->query_memory_type );
253
+ processed_threads.store (state.threads (), std::memory_order_acq_rel);
250
254
cond_var.notify_all ();
251
255
} else {
252
- // All other threads will wait for the first thread to initialize the algo.
253
256
std::unique_lock lk (init_mutex);
254
- cond_var.wait (lk, [] { return current_algo_props.get () != nullptr ; });
257
+ // All other threads will wait for the first thread to initialize the algo.
258
+ cond_var.wait (lk, [&state] {
259
+ return processed_threads.load (std::memory_order_acquire) == state.threads ();
260
+ });
255
261
// gbench ensures that all threads are synchronized at the start of the benchmark loop.
256
262
// We are accessing shared variables (like current_algo, current_algo_probs) before the
257
263
// benchmark loop, therefore the synchronization here is necessary.
@@ -292,6 +298,7 @@ void bench_search(::benchmark::State& state,
292
298
293
299
// advance to the next batch
294
300
batch_offset = (batch_offset + n_queries) % query_set_size;
301
+
295
302
queries_processed += n_queries;
296
303
}
297
304
}
@@ -312,6 +319,10 @@ void bench_search(::benchmark::State& state,
312
319
313
320
if (state.skipped ()) { return ; }
314
321
322
+ // assume thread has finished processing successfully at this point
323
+ // last thread to finish processing notifies all
324
+ if (processed_threads-- == 0 ) { cond_var.notify_all (); }
325
+
315
326
// Use the last thread as a sanity check that all the threads are working.
316
327
if (state.thread_index () == state.threads () - 1 ) {
317
328
// evaluate recall
@@ -410,7 +421,6 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
410
421
auto * b = ::benchmark::RegisterBenchmark (
411
422
index.name + suf, bench_search<T>, index, i, dataset, metric_objective)
412
423
->Unit (benchmark::kMillisecond )
413
- ->ThreadRange (threads[0 ], threads[1 ])
414
424
/* *
415
425
* The following are important for getting accuracy QPS measurements on both CPU
416
426
* and GPU These make sure that
@@ -420,6 +430,8 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
420
430
*/
421
431
->MeasureProcessCPUTime ()
422
432
->UseRealTime ();
433
+
434
+ if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange (threads[0 ], threads[1 ]); }
423
435
}
424
436
}
425
437
}
0 commit comments