Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Dec 14, 2023
1 parent 1da21e4 commit e5cd5f6
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 128 deletions.
7 changes: 7 additions & 0 deletions cpp/include/raft/neighbors/cagra_hnswlib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

namespace raft::neighbors::cagra_hnswlib {

/**
* @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib
* @{
*/

/**
* @brief Search hnswlib base layer only index constructed from a CAGRA index
*
Expand Down Expand Up @@ -85,4 +90,6 @@ void search(raft::resources const& res,
detail::search(res, params, idx, queries, neighbors, distances);
}

/**@}*/

} // namespace raft::neighbors::cagra_hnswlib
10 changes: 9 additions & 1 deletion cpp/include/raft/neighbors/cagra_hnswlib_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@

namespace raft::neighbors::cagra_hnswlib {

/**
* @defgroup cagra_hnswlib Build CAGRA index and search with hnswlib
* @{
*/

/**
 * @brief Search-time parameters for querying the hnswlib index.
 *
 * Extends the generic ann::search_params with the hnswlib-specific knobs below.
 */
struct search_params : ann::search_params {
  int ef;               // size of the candidate list
  int num_threads = 0;  // number of host threads to use for concurrent searches. Value of 0
                        // automatically maximizes parallelism
};

template <typename T>
Expand Down Expand Up @@ -59,4 +65,6 @@ struct index : ann::index {
raft::distance::DistanceType metric_;
};

/**@}*/

} // namespace raft::neighbors::cagra_hnswlib
149 changes: 22 additions & 127 deletions cpp/include/raft/neighbors/detail/cagra_hnswlib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,129 +18,16 @@

#include "../hnswlib_types.hpp"

#include <cstdint>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

#include <atomic>
#include <future>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <utility>
#include <omp.h>

#include <hnswlib.h>

namespace raft::neighbors::cagra_hnswlib::detail {

/**
 * @brief Fixed-size pool of host worker threads that runs an index-based callable over [0, len).
 *
 * Each worker owns one task slot (mutex + condition variable + packaged task). `submit`
 * hands every worker a contiguous range of indices plus a shared atomic counter from which
 * workers that finish early pull the remaining indices one at a time.
 *
 * When constructed with `num_threads == 1` no threads are created and `submit` runs the
 * callable inline on the calling thread.
 *
 * The pool is neither copyable nor movable: the worker threads hold a pointer to it.
 */
class FixedThreadPool {
 public:
  /**
   * @brief Start the worker threads.
   *
   * @param num_threads number of workers; 1 means "run inline, start no threads"
   * @throws std::runtime_error if `num_threads < 1`
   */
  explicit FixedThreadPool(int num_threads)
  {
    if (num_threads < 1) {
      throw std::runtime_error("num_threads must >= 1");
    } else if (num_threads == 1) {
      return;
    }

    tasks_ = std::make_unique<Task_[]>(num_threads);

    threads_.reserve(num_threads);
    try {
      for (int i = 0; i < num_threads; ++i) {
        threads_.emplace_back([this, i] { worker_loop(tasks_[i]); });
      }
    } catch (...) {
      // A thread failed to start: stop and join the ones already running, otherwise
      // destroying the joinable std::thread objects would call std::terminate.
      shutdown();
      throw;
    }
  }

  FixedThreadPool(const FixedThreadPool&)            = delete;
  FixedThreadPool& operator=(const FixedThreadPool&) = delete;

  /** @brief Signal every worker to stop and join it. */
  ~FixedThreadPool() { shutdown(); }

  /**
   * @brief Invoke `f(i)` once for every `i` in [0, len) and block until all calls finished.
   *
   * @tparam Func callable taking a single index of type `IdxT`
   * @tparam IdxT integral index type
   * @param f   callable to run; it is invoked concurrently from several threads
   * @param len number of indices to process
   *
   * If `f` throws on a worker thread, the first stored exception (in worker order) is
   * rethrown on the calling thread after all workers have finished.
   */
  template <typename Func, typename IdxT>
  void submit(Func f, IdxT len)
  {
    // Run functions in main thread if thread pool has no threads
    if (threads_.empty()) {
      for (IdxT i = 0; i < len; ++i) {
        f(i);
      }
      return;
    }

    const int num_threads = static_cast<int>(threads_.size());
    // one extra part for competition among threads
    const IdxT items_per_thread = len / (num_threads + 1);
    // Indices at or beyond this counter are not statically assigned; workers claim them
    // one by one once their own range is done.
    std::atomic<IdxT> cnt(items_per_thread * num_threads);

    // Wrap function: static range first, then compete for the leftover indices.
    auto wrapped_f = [&](IdxT start, IdxT end) {
      for (IdxT i = start; i < end; ++i) {
        f(i);
      }

      while (true) {
        IdxT i = cnt.fetch_add(1, std::memory_order_relaxed);
        if (i >= len) { break; }
        f(i);
      }
    };

    std::vector<std::future<void>> futures;
    futures.reserve(num_threads);
    for (int i = 0; i < num_threads; ++i) {
      const IdxT start = i * items_per_thread;
      auto& slot       = tasks_[i];
      {
        std::lock_guard<std::mutex> lock(slot.mtx);
        (void)lock;  // stop nvcc warning
        slot.task = std::packaged_task<void()>([=] { wrapped_f(start, start + items_per_thread); });
        futures.push_back(slot.task.get_future());
        slot.has_task = true;
      }
      slot.cv.notify_one();
    }

    // Wait for every worker before touching results: the tasks reference locals of this
    // frame (f, cnt, len), so none may still be running when we leave.
    for (auto& fut : futures) {
      fut.wait();
    }
    // Surface an exception thrown by `f` on a worker instead of silently dropping it.
    for (auto& fut : futures) {
      fut.get();
    }
  }

 private:
  // One slot per worker, padded to 64 bytes so neighbouring slots do not share a line.
  struct alignas(64) Task_ {
    std::mutex mtx;
    std::condition_variable cv;
    bool has_task = false;
    std::packaged_task<void()> task;
  };

  // Body of every worker thread: sleep until a task is posted or the pool is shutting down.
  void worker_loop(Task_& slot)
  {
    while (true) {
      std::unique_lock<std::mutex> lock(slot.mtx);
      slot.cv.wait(lock,
                   [&] { return slot.has_task || finished_.load(std::memory_order_relaxed); });
      if (finished_.load(std::memory_order_relaxed)) { break; }

      // The slot mutex stays held while the task runs, so submit() cannot overwrite it.
      slot.task();
      slot.has_task = false;
    }
  }

  // Ask all started workers to exit and join them. Safe to call when no thread was started.
  void shutdown() noexcept
  {
    if (threads_.empty()) { return; }

    finished_.store(true, std::memory_order_relaxed);
    for (std::size_t i = 0; i < threads_.size(); ++i) {
      auto& slot = tasks_[i];
      {
        // Taking and releasing the slot mutex orders the store above with the worker's
        // predicate check: the worker either sees `finished_` or is already waiting and
        // receives the notification below.
        std::lock_guard<std::mutex> lock(slot.mtx);
        (void)lock;  // stop nvcc warning
      }
      slot.cv.notify_one();
      threads_[i].join();
    }
    threads_.clear();
  }

  std::unique_ptr<Task_[]> tasks_;
  std::vector<std::thread> threads_;
  std::atomic<bool> finished_{false};
};

template <typename T>
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
const T* query,
Expand Down Expand Up @@ -170,18 +57,26 @@ void search(raft::resources const& res,
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());

// no-op when num_threads == 1, no synchronization overhead
FixedThreadPool thread_pool{params.num_threads};

auto f = [&](auto const& i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
};

thread_pool.submit(f, queries.extent(0));
// when num_threads == 0, automatically maximize parallelism
if (params.num_threads) {
#pragma omp parallel for num_threads(params.num_threads)
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
} else {
#pragma omp parallel for
for (int64_t i = 0; i < queries.extent(0); ++i) {
get_search_knn_results(hnswlib_index,
queries.data_handle() + i * queries.extent(1),
neighbors.extent(1),
neighbors.data_handle() + i * neighbors.extent(1),
distances.data_handle() + i * distances.extent(1));
}
}
}

} // namespace raft::neighbors::cagra_hnswlib::detail
7 changes: 7 additions & 0 deletions cpp/include/raft/neighbors/hnswlib_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@

namespace raft::neighbors::cagra_hnswlib {

/**
* @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib
* @{
*/

template <typename T>
struct hnsw_dist_t {
using type = void;
Expand Down Expand Up @@ -91,4 +96,6 @@ struct hnswlib_index : index<T> {
std::unique_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
};

/**@}*/

} // namespace raft::neighbors::cagra_hnswlib
11 changes: 11 additions & 0 deletions docs/source/cpp_api/neighbors_cagra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ namespace *raft::neighbors::cagra*
:project: RAFT
:members:
:content-only:

CAGRA index build and hnswlib search
------------------------------------
``#include <raft/neighbors/cagra_hnswlib.hpp>``

namespace *raft::neighbors::cagra_hnswlib*

.. doxygengroup:: cagra_hnswlib
:project: RAFT
:members:
:content-only:
14 changes: 14 additions & 0 deletions docs/source/pylibraft_api/neighbors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ Serializer Methods

.. autofunction:: pylibraft.neighbors.cagra.load

CAGRA hnswlib
#############

.. autoclass:: pylibraft.neighbors.cagra_hnswlib.SearchParams
:members:

.. autofunction:: pylibraft.neighbors.cagra_hnswlib.search

Serializer Methods
------------------
.. autofunction:: pylibraft.neighbors.cagra_hnswlib.save

.. autofunction:: pylibraft.neighbors.cagra_hnswlib.load

IVF-Flat
########

Expand Down

0 comments on commit e5cd5f6

Please sign in to comment.