Skip to content

Commit

Permalink
Satisfy pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Dec 14, 2023
1 parent 080b221 commit c9bfc9a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions cpp/include/raft/neighbors/detail/refine_host-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
// If the number of queries is small, separete the distance calculation and
// the top-k calculation into separete loops, and apply finer-grained thread
// parallelism to the distance calculation loop.
if ( n_queries < size_t(suggested_n_threads) ) {
std::vector<std::vector<std::tuple<DistanceT, IdxT>>>
refined_pairs(n_queries, std::vector<std::tuple<DistanceT, IdxT>>(orig_k));
if (n_queries < size_t(suggested_n_threads)) {
std::vector<std::vector<std::tuple<DistanceT, IdxT>>> refined_pairs(
n_queries, std::vector<std::tuple<DistanceT, IdxT>>(orig_k));

// For efficiency, each thread should read a certain amount of array
// elements. The number of threads for distance computation is determined
// taking this into account.
auto n_elements = std::max(size_t(512), dim);
auto n_elements = std::max(size_t(512), dim);
auto max_n_threads = raft::div_rounding_up_safe<size_t>(n_queries * orig_k * dim, n_elements);
auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads);

Expand All @@ -65,9 +65,9 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance)
for (size_t i = 0; i < n_queries; i++) {
for (size_t j = 0; j < orig_k; j++) {
const DataT* query = queries.data_handle() + dim * i;
IdxT id = neighbor_candidates(i, j);
DistanceT distance = 0.0;
const DataT* query = queries.data_handle() + dim * i;
IdxT id = neighbor_candidates(i, j);
DistanceT distance = 0.0;
if (static_cast<size_t>(id) >= n_rows) {
distance = std::numeric_limits<DistanceT>::max();
} else {
Expand All @@ -76,7 +76,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
distance += DC::template eval<DistanceT>(query[k], row[k]);
}
}
refined_pairs[i][j] = std::make_tuple(distance, id);
refined_pairs[i][j] = std::make_tuple(distance, id);
}
}

Expand All @@ -86,10 +86,10 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
std::sort(refined_pairs[i].begin(), refined_pairs[i].end());
// Store first refined_k neighbors
for (size_t j = 0; j < refined_k; j++) {
indices(i, j) = std::get<1>(refined_pairs[i][j]);
if (distances.data_handle() != nullptr) {
distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j]));
}
indices(i, j) = std::get<1>(refined_pairs[i][j]);
if (distances.data_handle() != nullptr) {
distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j]));
}
}
}
return;
Expand Down

0 comments on commit c9bfc9a

Please sign in to comment.