Skip to content

Commit

Permalink
Add ANN.copy support to cagra_hnswlib wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Nov 24, 2023
1 parent 9804dcd commit 8855197
Showing 1 changed file with 22 additions and 34 deletions.
56 changes: 22 additions & 34 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,12 @@ class RaftCagraHnswlib : public ANN<T> {

RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
: ANN<T>(metric, dim),
metric_(metric),
index_params_(param),
dimension_(dim),
handle_(cudaStreamPerThread)
cagra_build_{metric, dim, param, concurrent_searches},
// HnswLib param values don't matter since we don't build with HnswLib
hnswlib_search_{metric, dim, typename HnswLib<T>::BuildParam{50, 100}}
{
}

~RaftCagraHnswlib() noexcept {}

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

void set_search_param(const AnnSearchParam& param) override;
Expand All @@ -60,62 +57,53 @@ class RaftCagraHnswlib : public ANN<T> {
property.query_memory_type = MemoryType::Host;
return property;
}

void save(const std::string& file) const override;
void load(const std::string&) override;
std::unique_ptr<ANN<T>> copy() override;
std::unique_ptr<ANN<T>> copy() override
{
return std::make_unique<RaftCagraHnswlib<T, IdxT>>(*this);
}

private:
raft::device_resources handle_;
Metric metric_;
BuildParam index_params_;
int dimension_;

std::unique_ptr<RaftCagra<T, IdxT>> cagra_build_;
std::unique_ptr<HnswLib<T>> hnswlib_search_;

Objective metric_objective_;
RaftCagra<T, IdxT> cagra_build_;
HnswLib<T> hnswlib_search_;
};

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
if (not cagra_build_) {
cagra_build_ = std::make_unique<RaftCagra<T, IdxT>>(metric_, dimension_, index_params_);
}
cagra_build_->build(dataset, nrow, stream);
cagra_build_.build(dataset, nrow, stream);
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::set_search_param(const AnnSearchParam& param_)
{
hnswlib_search_->set_search_param(param_);
hnswlib_search_.set_search_param(param_);
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::save(const std::string& file) const
{
cagra_build_->save_to_hnswlib(file);
cagra_build_.save_to_hnswlib(file);
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::load(const std::string& file)
{
typename HnswLib<T>::BuildParam param;
// these values don't matter since we don't build with HnswLib
param.M = 50;
param.ef_construction = 100;
if (not hnswlib_search_) {
hnswlib_search_ = std::make_unique<HnswLib<T>>(metric_, dimension_, param);
}
hnswlib_search_->load(file);
hnswlib_search_->set_base_layer_only();
hnswlib_search_.load(file);
hnswlib_search_.set_base_layer_only();
}

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const
void RaftCagraHnswlib<T, IdxT>::search(const T* queries,
int batch_size,
int k,
size_t* neighbors,
float* distances,
cudaStream_t stream) const
{
hnswlib_search_->search(queries, batch_size, k, neighbors, distances);
hnswlib_search_.search(queries, batch_size, k, neighbors, distances, stream);
}

} // namespace raft::bench::ann

0 comments on commit 8855197

Please sign in to comment.