From c9bfc9ad4ce09f37337ff96deefa7626510a247a Mon Sep 17 00:00:00 2001 From: Akira Naruse Date: Thu, 14 Dec 2023 17:20:33 +0900 Subject: [PATCH] Satisfy pre-commit --- .../raft/neighbors/detail/refine_host-inl.hpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp index 5a9e22cc90..7b7e916dfa 100644 --- a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp +++ b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp @@ -47,14 +47,14 @@ template >> - refined_pairs(n_queries, std::vector>(orig_k)); + if (n_queries < size_t(suggested_n_threads)) { + std::vector>> refined_pairs( + n_queries, std::vector>(orig_k)); // For efficiency, each thread should read a certain amount of array // elements. The number of threads for distance computation is determined // taking this into account. - auto n_elements = std::max(size_t(512), dim); + auto n_elements = std::max(size_t(512), dim); auto max_n_threads = raft::div_rounding_up_safe(n_queries * orig_k * dim, n_elements); auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads); @@ -65,9 +65,9 @@ template (id) >= n_rows) { distance = std::numeric_limits::max(); } else { @@ -76,7 +76,7 @@ template (query[k], row[k]); } } - refined_pairs[i][j] = std::make_tuple(distance, id); + refined_pairs[i][j] = std::make_tuple(distance, id); } } @@ -86,10 +86,10 @@ template (refined_pairs[i][j]); - if (distances.data_handle() != nullptr) { - distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); - } + indices(i, j) = std::get<1>(refined_pairs[i][j]); + if (distances.data_handle() != nullptr) { + distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); + } } } return;