From 885519782cac6d7375b45c5e4494c524b1f9e4dd Mon Sep 17 00:00:00 2001 From: achirkin Date: Fri, 24 Nov 2023 15:26:40 +0100 Subject: [PATCH] Add ANN.copy support to cagra_hnswlib wrapper --- .../ann/src/raft/raft_cagra_hnswlib_wrapper.h | 56 ++++++++----------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h index e42cb5e7f2..3fd0a374b7 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h @@ -30,15 +30,12 @@ class RaftCagraHnswlib : public ANN { RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), - metric_(metric), - index_params_(param), - dimension_(dim), - handle_(cudaStreamPerThread) + cagra_build_{metric, dim, param, concurrent_searches}, + // HnswLib param values don't matter since we don't build with HnswLib + hnswlib_search_{metric, dim, typename HnswLib::BuildParam{50, 100}} { } - ~RaftCagraHnswlib() noexcept {} - void build(const T* dataset, size_t nrow, cudaStream_t stream) final; void set_search_param(const AnnSearchParam& param) override; @@ -60,62 +57,53 @@ class RaftCagraHnswlib : public ANN { property.query_memory_type = MemoryType::Host; return property; } + void save(const std::string& file) const override; void load(const std::string&) override; - std::unique_ptr> copy() override; + std::unique_ptr> copy() override + { + return std::make_unique>(*this); + } private: - raft::device_resources handle_; - Metric metric_; - BuildParam index_params_; - int dimension_; - - std::unique_ptr> cagra_build_; - std::unique_ptr> hnswlib_search_; - - Objective metric_objective_; + RaftCagra cagra_build_; + HnswLib hnswlib_search_; }; template void RaftCagraHnswlib::build(const T* dataset, size_t nrow, cudaStream_t stream) { - if (not cagra_build_) { - cagra_build_ = std::make_unique>(metric_, dimension_, index_params_); - } - cagra_build_->build(dataset, nrow, stream); + cagra_build_.build(dataset, nrow, stream); } template void RaftCagraHnswlib::set_search_param(const AnnSearchParam& param_) { - hnswlib_search_->set_search_param(param_); + hnswlib_search_.set_search_param(param_); } template void RaftCagraHnswlib::save(const std::string& file) const { - cagra_build_->save_to_hnswlib(file); + cagra_build_.save_to_hnswlib(file); } template void RaftCagraHnswlib::load(const std::string& file) { - typename HnswLib::BuildParam param; - // these values don't matter since we don't build with HnswLib - param.M = 50; - param.ef_construction = 100; - if (not hnswlib_search_) { - hnswlib_search_ = std::make_unique>(metric_, dimension_, param); - } - hnswlib_search_->load(file); - hnswlib_search_->set_base_layer_only(); + hnswlib_search_.load(file); + hnswlib_search_.set_base_layer_only(); } template -void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const +void RaftCagraHnswlib::search(const T* queries, + int batch_size, + int k, + size_t* neighbors, + float* distances, + cudaStream_t stream) const { - hnswlib_search_->search(queries, batch_size, k, neighbors, distances); + hnswlib_search_.search(queries, batch_size, k, neighbors, distances, stream); } } // namespace raft::bench::ann