Skip to content

Commit f1e0860

Browse files
authored
Merge branch 'branch-24.02' into imp-2402-update_cagra_build_constraint
2 parents 897daf6 + 88e9a55 commit f1e0860

16 files changed

+403
-85
lines changed

conda/recipes/libraft/conda_build_config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,9 @@ cuda11_cuda_profiler_api_host_version:
7171

7272
cuda11_cuda_profiler_api_run_version:
7373
- ">=11.4.240,<12"
74+
75+
spdlog_version:
76+
- ">=1.11.0,<1.12"
77+
78+
fmt_version:
79+
- ">=9.1.0,<10"

conda/recipes/libraft/meta.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,16 @@ outputs:
6363
{% endif %}
6464
- cuda-version ={{ cuda_version }}
6565
- librmm ={{ minor_version }}
66+
- spdlog {{ spdlog_version }}
67+
- fmt {{ fmt_version }}
6668
run:
6769
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
6870
{% if cuda_major == "11" %}
6971
- cudatoolkit
7072
{% endif %}
7173
- librmm ={{ minor_version }}
74+
- spdlog {{ spdlog_version }}
75+
- fmt {{ fmt_version }}
7276
about:
7377
home: https://rapids.ai/
7478
license: Apache-2.0

conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,9 @@ h5py_version:
1818

1919
nlohmann_json_version:
2020
- ">=3.11.2"
21+
22+
spdlog_version:
23+
- ">=1.11.0,<1.12"
24+
25+
fmt_version:
26+
- ">=9.1.0,<10"

conda/recipes/raft-ann-bench-cpu/meta.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ requirements:
4848
- glog {{ glog_version }}
4949
- matplotlib
5050
- nlohmann_json {{ nlohmann_json_version }}
51+
- spdlog {{ spdlog_version }}
52+
- fmt {{ fmt_version }}
5153
- python
5254
- pyyaml
5355
- pandas

cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class RaftIvfFlatGpu : public ANN<T> {
7878
// Builds and returns an AlgoProperty describing the memory types this wrapper
// asks for: one for the dataset and one for the queries.
AlgoProperty get_preference() const override
7979
{
8080
AlgoProperty property;
// Diff hunk: the removed line (81-) requested MemoryType::Device for the
// dataset; the added line (81+) requests MemoryType::HostMmap instead.
81-
property.dataset_memory_type = MemoryType::Device;
81+
property.dataset_memory_type = MemoryType::HostMmap;
8282
// The query memory type is untouched by this hunk and stays MemoryType::Device.
property.query_memory_type = MemoryType::Device;
8383
return property;
8484
}

cpp/include/raft/neighbors/detail/ivf_flat_build.cuh

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
120120
uint32_t* list_sizes_ptr,
121121
IdxT n_rows,
122122
uint32_t dim,
123-
uint32_t veclen)
123+
uint32_t veclen,
124+
IdxT batch_offset = 0)
124125
{
125126
const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x;
126127
if (i >= n_rows) { return; }
@@ -131,7 +132,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
131132
auto* list_data = list_data_ptrs[list_id];
132133

133134
// Record the source vector id in the index
134-
list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i];
135+
list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i];
135136

136137
// The data is written in interleaved groups of `index::kGroupSize` vectors
137138
using interleaved_group = Pow2<kIndexGroupSize>;
@@ -180,16 +181,33 @@ void extend(raft::resources const& handle,
180181

181182
auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
182183
raft::cluster::kmeans_balanced_params kmeans_params;
183-
kmeans_params.metric = index->metric();
184-
auto new_vectors_view = raft::make_device_matrix_view<const T, IdxT>(new_vectors, n_rows, dim);
184+
kmeans_params.metric = index->metric();
185185
auto orig_centroids_view =
186186
raft::make_device_matrix_view<const float, IdxT>(index->centers().data_handle(), n_lists, dim);
187-
raft::cluster::kmeans_balanced::predict(handle,
188-
kmeans_params,
189-
new_vectors_view,
190-
orig_centroids_view,
191-
new_labels.view(),
192-
utils::mapping<float>{});
187+
// Calculate the batch size for the input data if it's not accessible directly from the device
188+
constexpr size_t kReasonableMaxBatchSize = 65536;
189+
size_t max_batch_size = std::min<size_t>(n_rows, kReasonableMaxBatchSize);
190+
191+
// Predict the cluster labels for the new data, in batches if necessary
192+
utils::batch_load_iterator<T> vec_batches(new_vectors,
193+
n_rows,
194+
index->dim(),
195+
max_batch_size,
196+
stream,
197+
resource::get_workspace_resource(handle));
198+
199+
for (const auto& batch : vec_batches) {
200+
auto batch_data_view =
201+
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
202+
auto batch_labels_view = raft::make_device_vector_view<LabelT, IdxT>(
203+
new_labels.data_handle() + batch.offset(), batch.size());
204+
raft::cluster::kmeans_balanced::predict(handle,
205+
kmeans_params,
206+
batch_data_view,
207+
orig_centroids_view,
208+
batch_labels_view,
209+
utils::mapping<float>{});
210+
}
193211

194212
auto* list_sizes_ptr = index->list_sizes().data_handle();
195213
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
@@ -202,14 +220,19 @@ void extend(raft::resources const& handle,
202220
auto list_sizes_view =
203221
raft::make_device_vector_view<std::remove_pointer_t<decltype(list_sizes_ptr)>, IdxT>(
204222
list_sizes_ptr, n_lists);
205-
auto const_labels_view = make_const_mdspan(new_labels.view());
206-
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
207-
new_vectors_view,
208-
const_labels_view,
209-
centroids_view,
210-
list_sizes_view,
211-
false,
212-
utils::mapping<float>{});
223+
for (const auto& batch : vec_batches) {
224+
auto batch_data_view =
225+
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
226+
auto batch_labels_view = raft::make_device_vector_view<const LabelT, IdxT>(
227+
new_labels.data_handle() + batch.offset(), batch.size());
228+
raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle,
229+
batch_data_view,
230+
batch_labels_view,
231+
centroids_view,
232+
list_sizes_view,
233+
false,
234+
utils::mapping<float>{});
235+
}
213236
} else {
214237
raft::stats::histogram<uint32_t, IdxT>(raft::stats::HistTypeAuto,
215238
reinterpret_cast<int32_t*>(list_sizes_ptr),
@@ -244,20 +267,39 @@ void extend(raft::resources const& handle,
244267
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
245268
raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream);
246269

247-
// Kernel to insert the new vectors
248-
const dim3 block_dim(256);
249-
const dim3 grid_dim(raft::ceildiv<IdxT>(n_rows, block_dim.x));
250-
build_index_kernel<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle(),
251-
new_vectors,
252-
new_indices,
253-
index->data_ptrs().data_handle(),
254-
index->inds_ptrs().data_handle(),
255-
list_sizes_ptr,
256-
n_rows,
257-
dim,
258-
index->veclen());
259-
RAFT_CUDA_TRY(cudaPeekAtLastError());
260-
270+
utils::batch_load_iterator<IdxT> vec_indices(
271+
new_indices, n_rows, 1, max_batch_size, stream, resource::get_workspace_resource(handle));
272+
utils::batch_load_iterator<IdxT> idx_batch = vec_indices.begin();
273+
size_t next_report_offset = 0;
274+
size_t d_report_offset = n_rows * 5 / 100;
275+
for (const auto& batch : vec_batches) {
276+
auto batch_data_view =
277+
raft::make_device_matrix_view<const T, IdxT>(batch.data(), batch.size(), index->dim());
278+
// Kernel to insert the new vectors
279+
const dim3 block_dim(256);
280+
const dim3 grid_dim(raft::ceildiv<IdxT>(batch.size(), block_dim.x));
281+
build_index_kernel<T, IdxT, LabelT>
282+
<<<grid_dim, block_dim, 0, stream>>>(new_labels.data_handle() + batch.offset(),
283+
batch_data_view.data_handle(),
284+
idx_batch->data(),
285+
index->data_ptrs().data_handle(),
286+
index->inds_ptrs().data_handle(),
287+
list_sizes_ptr,
288+
batch.size(),
289+
dim,
290+
index->veclen(),
291+
batch.offset());
292+
RAFT_CUDA_TRY(cudaPeekAtLastError());
293+
294+
if (batch.offset() > next_report_offset) {
295+
float progress = batch.offset() * 100.0f / n_rows;
296+
RAFT_LOG_DEBUG("ivf_flat::extend added vectors %zu, %6.1f%% complete",
297+
static_cast<size_t>(batch.offset()),
298+
progress);
299+
next_report_offset += d_report_offset;
300+
}
301+
++idx_batch;
302+
}
261303
// Precompute the centers vector norms for L2Expanded distance
262304
if (!index->center_norms().has_value()) {
263305
index->allocate_center_norms(handle);

cpp/include/raft/neighbors/ivf_flat-ext.cuh

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ void build(raft::resources const& handle,
4848
raft::device_matrix_view<const T, IdxT, row_major> dataset,
4949
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;
5050

51+
// Declaration of the `build` overload that takes the dataset as a row-major
// `raft::host_matrix_view<const T, IdxT>` and returns an `index<T, IdxT>` by value.
// It mirrors the device_matrix_view overload declared just above it in this header.
//   handle  - raft::resources passed by const reference
//   params  - index_params passed by const reference
//   dataset - host matrix view over const T elements
// NOTE(review): RAFT_EXPLICIT is defined outside this hunk; presumably it marks the
// template as available only through explicit instantiation - confirm.
template <typename T, typename IdxT>
52+
auto build(raft::resources const& handle,
53+
const index_params& params,
54+
raft::host_matrix_view<const T, IdxT, row_major> dataset)
55+
-> index<T, IdxT> RAFT_EXPLICIT;
56+
57+
// Declaration of the `build` overload that takes the dataset as a row-major
// `raft::host_matrix_view<const T, IdxT>` and returns void; the index is passed
// as a non-const reference `idx`.
//   handle  - raft::resources passed by const reference
//   params  - index_params passed by const reference
//   dataset - host matrix view over const T elements
//   idx     - non-const index reference; presumably the index that gets populated -
//             confirm against the definition, which is not part of this hunk
// NOTE(review): RAFT_EXPLICIT is defined outside this hunk; presumably it marks the
// template as available only through explicit instantiation - confirm.
template <typename T, typename IdxT>
58+
void build(raft::resources const& handle,
59+
const index_params& params,
60+
raft::host_matrix_view<const T, IdxT, row_major> dataset,
61+
raft::neighbors::ivf_flat::index<T, IdxT>& idx) RAFT_EXPLICIT;
62+
5163
template <typename T, typename IdxT>
5264
auto extend(raft::resources const& handle,
5365
const index<T, IdxT>& orig_index,
@@ -74,6 +86,19 @@ void extend(raft::resources const& handle,
7486
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,
7587
index<T, IdxT>* index) RAFT_EXPLICIT;
7688

89+
// Declaration of the `extend` overload that takes the new vectors as a row-major
// `raft::host_matrix_view<const T, IdxT>` and the optional new indices as a
// `raft::host_vector_view<const IdxT, IdxT>`, and returns an index by value.
//   handle      - raft::resources passed by const reference
//   new_vectors - host matrix view over const T elements
//   new_indices - std::optional host vector view over const IdxT elements
//   orig_index  - const reference to an existing index; it cannot be modified
//                 through this parameter
// NOTE(review): presumably the returned index holds the contents of `orig_index`
// plus the new vectors - confirm against the definition, which is not in this hunk.
// NOTE(review): RAFT_EXPLICIT is defined outside this hunk; presumably it marks the
// template as available only through explicit instantiation - confirm.
template <typename T, typename IdxT>
90+
auto extend(raft::resources const& handle,
91+
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
92+
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
93+
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index)
94+
-> raft::neighbors::ivf_flat::index<T, IdxT> RAFT_EXPLICIT;
95+
96+
// Declaration of the `extend` overload that takes the new vectors as a row-major
// `raft::host_matrix_view<const T, IdxT>` and the optional new indices as a
// `raft::host_vector_view<const IdxT, IdxT>`, and returns void; the index is
// passed as a non-const pointer.
//   handle      - raft::resources passed by const reference
//   new_vectors - host matrix view over const T elements
//   new_indices - std::optional host vector view over const IdxT elements
//   index       - pointer to a non-const index; presumably updated in place -
//                 confirm against the definition, which is not part of this hunk
// NOTE(review): RAFT_EXPLICIT is defined outside this hunk; presumably it marks the
// template as available only through explicit instantiation - confirm.
template <typename T, typename IdxT>
97+
void extend(raft::resources const& handle,
98+
raft::host_matrix_view<const T, IdxT, row_major> new_vectors,
99+
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices,
100+
index<T, IdxT>* index) RAFT_EXPLICIT;
101+
77102
template <typename T, typename IdxT, typename IvfSampleFilterT>
78103
void search_with_filtering(raft::resources const& handle,
79104
const search_params& params,
@@ -137,6 +162,18 @@ void search(raft::resources const& handle,
137162
raft::resources const& handle, \
138163
const raft::neighbors::ivf_flat::index_params& params, \
139164
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
165+
raft::neighbors::ivf_flat::index<T, IdxT>& idx); \
166+
\
167+
extern template auto raft::neighbors::ivf_flat::build<T, IdxT>( \
168+
raft::resources const& handle, \
169+
const raft::neighbors::ivf_flat::index_params& params, \
170+
raft::host_matrix_view<const T, IdxT, row_major> dataset) \
171+
->raft::neighbors::ivf_flat::index<T, IdxT>; \
172+
\
173+
extern template void raft::neighbors::ivf_flat::build<T, IdxT>( \
174+
raft::resources const& handle, \
175+
const raft::neighbors::ivf_flat::index_params& params, \
176+
raft::host_matrix_view<const T, IdxT, row_major> dataset, \
140177
raft::neighbors::ivf_flat::index<T, IdxT>& idx);
141178

142179
instantiate_raft_neighbors_ivf_flat_build(float, int64_t);
@@ -171,7 +208,20 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t);
171208
raft::resources const& handle, \
172209
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
173210
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
174-
raft::neighbors::ivf_flat::index<T, IdxT>* index);
211+
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
212+
\
213+
extern template void raft::neighbors::ivf_flat::extend<T, IdxT>( \
214+
raft::resources const& handle, \
215+
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
216+
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
217+
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
218+
\
219+
extern template auto raft::neighbors::ivf_flat::extend<T, IdxT>( \
220+
const raft::resources& handle, \
221+
raft::host_matrix_view<const T, IdxT, row_major> new_vectors, \
222+
std::optional<raft::host_vector_view<const IdxT, IdxT>> new_indices, \
223+
const raft::neighbors::ivf_flat::index<T, IdxT>& idx) \
224+
->raft::neighbors::ivf_flat::index<T, IdxT>;
175225

176226
instantiate_raft_neighbors_ivf_flat_extend(float, int64_t);
177227
instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t);

0 commit comments

Comments
 (0)