@@ -120,7 +120,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
120
120
uint32_t * list_sizes_ptr,
121
121
IdxT n_rows,
122
122
uint32_t dim,
123
- uint32_t veclen)
123
+ uint32_t veclen,
124
+ IdxT batch_offset = 0 )
124
125
{
125
126
const IdxT i = IdxT (blockDim .x ) * IdxT (blockIdx .x ) + threadIdx .x ;
126
127
if (i >= n_rows) { return ; }
@@ -131,7 +132,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
131
132
auto * list_data = list_data_ptrs[list_id];
132
133
133
134
// Record the source vector id in the index
134
- list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i];
135
+ list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i];
135
136
136
137
// The data is written in interleaved groups of `index::kGroupSize` vectors
137
138
using interleaved_group = Pow2<kIndexGroupSize >;
@@ -180,16 +181,33 @@ void extend(raft::resources const& handle,
180
181
181
182
auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
182
183
raft::cluster::kmeans_balanced_params kmeans_params;
183
- kmeans_params.metric = index ->metric ();
184
- auto new_vectors_view = raft::make_device_matrix_view<const T, IdxT>(new_vectors, n_rows, dim);
184
+ kmeans_params.metric = index ->metric ();
185
185
auto orig_centroids_view =
186
186
raft::make_device_matrix_view<const float , IdxT>(index ->centers ().data_handle (), n_lists, dim);
187
- raft::cluster::kmeans_balanced::predict (handle,
188
- kmeans_params,
189
- new_vectors_view,
190
- orig_centroids_view,
191
- new_labels.view (),
192
- utils::mapping<float >{});
187
+ // Calculate the batch size for the input data if it's not accessible directly from the device
188
+ constexpr size_t kReasonableMaxBatchSize = 65536 ;
189
+ size_t max_batch_size = std::min<size_t >(n_rows, kReasonableMaxBatchSize );
190
+
191
+ // Predict the cluster labels for the new data, in batches if necessary
192
+ utils::batch_load_iterator<T> vec_batches (new_vectors,
193
+ n_rows,
194
+ index ->dim (),
195
+ max_batch_size,
196
+ stream,
197
+ resource::get_workspace_resource (handle));
198
+
199
+ for (const auto & batch : vec_batches) {
200
+ auto batch_data_view =
201
+ raft::make_device_matrix_view<const T, IdxT>(batch.data (), batch.size (), index ->dim ());
202
+ auto batch_labels_view = raft::make_device_vector_view<LabelT, IdxT>(
203
+ new_labels.data_handle () + batch.offset (), batch.size ());
204
+ raft::cluster::kmeans_balanced::predict (handle,
205
+ kmeans_params,
206
+ batch_data_view,
207
+ orig_centroids_view,
208
+ batch_labels_view,
209
+ utils::mapping<float >{});
210
+ }
193
211
194
212
auto * list_sizes_ptr = index ->list_sizes ().data_handle ();
195
213
auto old_list_sizes_dev = raft::make_device_vector<uint32_t , IdxT>(handle, n_lists);
@@ -202,14 +220,19 @@ void extend(raft::resources const& handle,
202
220
auto list_sizes_view =
203
221
raft::make_device_vector_view<std::remove_pointer_t <decltype (list_sizes_ptr)>, IdxT>(
204
222
list_sizes_ptr, n_lists);
205
- auto const_labels_view = make_const_mdspan (new_labels.view ());
206
- raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes (handle,
207
- new_vectors_view,
208
- const_labels_view,
209
- centroids_view,
210
- list_sizes_view,
211
- false ,
212
- utils::mapping<float >{});
223
+ for (const auto & batch : vec_batches) {
224
+ auto batch_data_view =
225
+ raft::make_device_matrix_view<const T, IdxT>(batch.data (), batch.size (), index ->dim ());
226
+ auto batch_labels_view = raft::make_device_vector_view<const LabelT, IdxT>(
227
+ new_labels.data_handle () + batch.offset (), batch.size ());
228
+ raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes (handle,
229
+ batch_data_view,
230
+ batch_labels_view,
231
+ centroids_view,
232
+ list_sizes_view,
233
+ false ,
234
+ utils::mapping<float >{});
235
+ }
213
236
} else {
214
237
raft::stats::histogram<uint32_t , IdxT>(raft::stats::HistTypeAuto,
215
238
reinterpret_cast <int32_t *>(list_sizes_ptr),
@@ -244,20 +267,39 @@ void extend(raft::resources const& handle,
244
267
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
245
268
raft::copy (list_sizes_ptr, old_list_sizes_dev.data_handle (), n_lists, stream);
246
269
247
- // Kernel to insert the new vectors
248
- const dim3 block_dim (256 );
249
- const dim3 grid_dim (raft::ceildiv<IdxT>(n_rows, block_dim.x ));
250
- build_index_kernel<<<grid_dim, block_dim, 0 , stream>>> (new_labels.data_handle (),
251
- new_vectors,
252
- new_indices,
253
- index ->data_ptrs ().data_handle (),
254
- index ->inds_ptrs ().data_handle (),
255
- list_sizes_ptr,
256
- n_rows,
257
- dim,
258
- index ->veclen ());
259
- RAFT_CUDA_TRY (cudaPeekAtLastError ());
260
-
270
+ utils::batch_load_iterator<IdxT> vec_indices (
271
+ new_indices, n_rows, 1 , max_batch_size, stream, resource::get_workspace_resource (handle));
272
+ utils::batch_load_iterator<IdxT> idx_batch = vec_indices.begin ();
273
+ size_t next_report_offset = 0 ;
274
+ size_t d_report_offset = n_rows * 5 / 100 ;
275
+ for (const auto & batch : vec_batches) {
276
+ auto batch_data_view =
277
+ raft::make_device_matrix_view<const T, IdxT>(batch.data (), batch.size (), index ->dim ());
278
+ // Kernel to insert the new vectors
279
+ const dim3 block_dim (256 );
280
+ const dim3 grid_dim (raft::ceildiv<IdxT>(batch.size (), block_dim.x ));
281
+ build_index_kernel<T, IdxT, LabelT>
282
+ <<<grid_dim, block_dim, 0 , stream>>> (new_labels.data_handle () + batch.offset (),
283
+ batch_data_view.data_handle (),
284
+ idx_batch->data (),
285
+ index ->data_ptrs ().data_handle (),
286
+ index ->inds_ptrs ().data_handle (),
287
+ list_sizes_ptr,
288
+ batch.size (),
289
+ dim,
290
+ index ->veclen (),
291
+ batch.offset ());
292
+ RAFT_CUDA_TRY (cudaPeekAtLastError ());
293
+
294
+ if (batch.offset () > next_report_offset) {
295
+ float progress = batch.offset () * 100 .0f / n_rows;
296
+ RAFT_LOG_DEBUG (" ivf_flat::extend added vectors %zu, %6.1f%% complete" ,
297
+ static_cast <size_t >(batch.offset ()),
298
+ progress);
299
+ next_report_offset += d_report_offset;
300
+ }
301
+ ++idx_batch;
302
+ }
261
303
// Precompute the centers vector norms for L2Expanded distance
262
304
if (!index ->center_norms ().has_value ()) {
263
305
index ->allocate_center_norms (handle);
0 commit comments