diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h index 600061dead..1b8e332df1 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h @@ -24,6 +24,112 @@ namespace cuvs::bench { +class s_res { + public: + using large_mr_type = rmm::mr::managed_memory_resource; + using mr_type = rmm::mr::pool_memory_resource; + + s_res() + try + : orig_resource_{rmm::mr::get_current_device_resource()}, + large_mr_{}, + resource_(&large_mr_, 1024 * 1024 * 1024ull) { + rmm::mr::set_current_device_resource(&large_mr_); + } catch (const std::exception& e) { + auto cuda_status = cudaGetLastError(); + size_t free = 0; + size_t total = 0; + RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free, &total)); + RAFT_FAIL( + "Failed to initialize shared raft resources (NB: latest cuda status = %s, free memory = %zu, " + "total memory = %zu): %s", + cudaGetErrorName(cuda_status), + free, + total, + e.what()); + } + + s_res(s_res&&) = delete; + auto operator=(s_res&&) -> s_res& = delete; + s_res(const s_res& res) = delete; + auto operator=(const s_res& other) -> s_res& = delete; + + ~s_res() noexcept { rmm::mr::set_current_device_resource(orig_resource_); } + + auto get_small_memory_resource() noexcept + { + return static_cast(&resource_); + } + + auto get_large_memory_resource() noexcept + { + return static_cast(&large_mr_); + } + + private: + rmm::mr::device_memory_resource* orig_resource_; + large_mr_type large_mr_; + mr_type resource_; +}; + +class cagra_build_raft_resources { + public: + /** + * This constructor has the shared state passed unmodified but creates the local state anew. + * It's used by the copy constructor. + */ + explicit cagra_build_raft_resources(const std::shared_ptr& shared_res) + : shared_res_{shared_res}, + res_{std::make_unique( + rmm::cuda_stream_view(get_stream_from_global_pool()))} + { + // set the large workspace resource to the raft handle, but without the deleter + // (this resource is managed by the shared_res). + raft::resource::set_workspace_resource( + *res_, + std::shared_ptr(shared_res_->get_small_memory_resource(), + raft::void_op{})); + raft::resource::set_large_workspace_resource( + *res_, + std::shared_ptr(shared_res_->get_large_memory_resource(), + raft::void_op{})); + } + + /** Default constructor creates all resources anew. */ + cagra_build_raft_resources() : cagra_build_raft_resources{std::make_shared()} {} + + cagra_build_raft_resources(cagra_build_raft_resources&&); + auto operator=(cagra_build_raft_resources&&) -> cagra_build_raft_resources&; + ~cagra_build_raft_resources() = default; + cagra_build_raft_resources(const cagra_build_raft_resources& res) + : cagra_build_raft_resources{res.shared_res_} + { + } + auto operator=(const cagra_build_raft_resources& other) -> cagra_build_raft_resources& + { + this->shared_res_ = other.shared_res_; + return *this; + } + + operator raft::resources&() noexcept { return *res_; } // NOLINT + operator const raft::resources&() const noexcept { return *res_; } // NOLINT + + /** Get the main stream */ + [[nodiscard]] auto get_sync_stream() const noexcept + { + return raft::resource::get_cuda_stream(*res_); + } + + private: + /** The resources shared among multiple raft handles / threads. */ + std::shared_ptr shared_res_; + /** + * Until we make the use of copies of raft::resources thread-safe, each benchmark wrapper must + * have its own copy of it. + */ + std::unique_ptr res_ = std::make_unique(); +}; + template class cuvs_cagra_hnswlib : public algo, public algo_gpu { public: @@ -82,7 +188,7 @@ class cuvs_cagra_hnswlib : public algo, public algo_gpu { } private: - configured_raft_resources handle_{}; + cagra_build_raft_resources handle_{}; build_param build_param_; search_param search_param_; std::shared_ptr> hnsw_index_;