diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7f1ce7666b..08b5d60316 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -526,6 +526,7 @@ if(NOT BUILD_CPU_ONLY) src/distance/detail/pairwise_matrix/dispatch_rbf.cu src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu + src/distance/detail/pairwise_matrix/dispatch_bitwise_hamming_uint8_t_uint32_t_uint32_t_int64_t.cu src/distance/distance.cu src/distance/pairwise_distance.cu src/distance/sparse_distance.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a839cecf56..23135def99 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -125,6 +125,15 @@ struct balanced_params : base_params { * Number of training iterations */ uint32_t n_iters = 20; + + /** + * If true, treats uint8_t input data as bit-packed binary data where each byte contains 8 bits. + * Bits are expanded on-the-fly to {-1, +1} floats during training. + * When enabled: + * - Input data dimension represents packed dimension (actual_dim / 8) + * - Output centroids dimension is expanded (packed_dim * 8) + */ + bool is_packed_binary = false; }; /** diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 23c6dd4944..e1d72320a8 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -204,6 +204,11 @@ struct index : cuvs::neighbors::index { raft::device_matrix_view centers() noexcept; raft::device_matrix_view centers() const noexcept; + /** packed k-means cluster centers corresponding to the lists [n_lists, dim] when the + * BitwiseHamming metric is selected */ + raft::device_matrix_view binary_centers() noexcept; + raft::device_matrix_view binary_centers() const noexcept; + /** * (Optional) Precomputed norms of the `centers` w.r.t. 
the chosen distance metric [n_lists]. * @@ -229,7 +234,10 @@ struct index : cuvs::neighbors::index { /** Total length of the index. */ IdxT size() const noexcept; - /** Dimensionality of the data. */ + /** Dimensionality of the data. + * @note For binary index, this returns the dimensionality of the byte dataset, which is the + * number of bits / 8. + */ uint32_t dim() const noexcept; /** Number of clusters/inverted lists. */ @@ -255,6 +263,8 @@ struct index : cuvs::neighbors::index { void check_consistency(); + bool binary_index() const noexcept; + private: /** * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum @@ -267,7 +277,9 @@ struct index : cuvs::neighbors::index { std::vector>> lists_; raft::device_vector list_sizes_; raft::device_matrix centers_; + raft::device_matrix binary_centers_; std::optional> center_norms_; + bool binary_index_; // Computed members raft::device_vector data_ptrs_; diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index f5dc759725..7e74fac099 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -39,9 +39,11 @@ #include #include +#include #include #include +#include "../../neighbors/detail/ann_utils.cuh" #include #include #include @@ -51,6 +53,28 @@ namespace cuvs::cluster::kmeans::detail { constexpr static inline float kAdjustCentersWeight = 7.0f; +/** + * @brief Create a transform iterator for on-the-fly bit expansion + * + * This helper function creates a thrust transform iterator that expands packed + * uint8_t data into float values on-the-fly (bit 1 → +1.0f, bit 0 → -1.0f). + * + * @tparam IdxT index type + * + * @param packed_data Pointer to packed uint8_t data [n_rows, packed_dim] + * @param packed_dim Number of bytes per packed row (the expanded bit dimension is + * packed_dim * 8) + * @return A transform iterator that yields float values for each bit + */ +template +auto 
make_bitwise_expanded_iterator(const uint8_t* packed_data, IdxT packed_dim) +{ + auto counting_iter = thrust::make_counting_iterator(0); + auto decoder = + cuvs::spatial::knn::detail::utils::bitwise_decode_op(packed_data, packed_dim); + return thrust::make_transform_iterator(counting_iter, decoder); +} + /** * @brief Predict labels for the dataset; floating-point types only. * @@ -204,6 +228,78 @@ inline std::enable_if_t> predict_core( } } +/** + * @brief Predict labels for the dataset; uint8_t only (specialization for BitwiseHamming). + */ +template +inline void predict_bitwise_hamming(const raft::resources& handle, + const cuvs::cluster::kmeans::balanced_params& params, + const uint8_t* centers, + IdxT n_clusters, + IdxT dim, + const uint8_t* dataset, + const uint8_t* dataset_norm, + IdxT n_rows, + LabelT* labels, + rmm::device_async_resource_ref mr) +{ + RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::BitwiseHamming, + "uint8_t data only supports BitwiseHamming distance"); + + auto stream = raft::resource::get_cuda_stream(handle); + + auto workspace = raft::make_device_mdarray( + handle, mr, raft::make_extents((sizeof(int)) * n_rows)); + + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, raft::make_extents(n_rows)); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); + + cuvs::distance::fusedDistanceNNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + nullptr, + nullptr, + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + false, + false, + true, + params.metric, + 0.0f, + stream); + + raft::linalg::map(handle, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_device_vector_view(labels, n_rows), + raft::compose_op, raft::key_op>()); +} + +template +inline void predict_bitwise_hamming(const raft::resources& handle, + raft::device_matrix_view dataset, + 
raft::device_matrix_view centers, + raft::device_vector_view labels) +{ + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::BitwiseHamming; + + predict_bitwise_hamming(handle, + params, + centers.data_handle(), + centers.extent(0), + centers.extent(1), + dataset.data_handle(), + nullptr, + dataset.extent(0), + labels.data_handle(), + raft::resource::get_workspace_resource(handle)); +} + /** * @brief Suggest a minibatch size for kmeans prediction. * @@ -257,6 +353,12 @@ constexpr auto calc_minibatch_size(IdxT n_clusters, /** * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. * + * This function supports two modes: + * 1. Regular mode: Works with any data type T with optional type conversion via mapping_op + * 2. Packed binary mode: When T=uint8_t and is_packed_binary=true, treats data as bit-packed + * and expands bits on-the-fly (bit 1 → +1, bit 0 → -1) into float centers. + * In this mode, dim represents the packed dimension (dim_expanded / 8). + * * @note all pointers must be accessible on the device. * * @tparam T element type @@ -267,10 +369,10 @@ constexpr auto calc_minibatch_size(IdxT n_clusters, * @tparam MappingOpT type of the mapping operation * * @param[in] handle The raft handle. - * @param[inout] centers Pointer to the output [n_clusters, dim] + * @param[inout] centers Pointer to the output [n_clusters, dim] or [n_clusters, dim*8] if packed * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data + * @param[in] dim Dimensionality of the data (or packed dim if is_packed_binary=true) * @param[in] dataset Pointer to the data [n_rows, dim] * @param[in] n_rows Number of samples in the `dataset` * @param[in] labels Output predictions [n_rows] @@ -279,6 +381,8 @@ constexpr auto calc_minibatch_size(IdxT n_clusters, * the weighted average principle. 
* @param[in] mapping_op Mapping operation from T to MathT * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device + * @param[in] is_packed_binary If true and T=uint8_t, treats data as bit-packed and expands + * on-the-fly */ template (centers, n_clusters, dim); + // For packed binary, dim is packed dimension, centers are in expanded dimension (dim * 8) + IdxT centers_dim = is_packed_binary ? (dim * 8) : dim; + + auto centersView = raft::make_device_matrix_view(centers, n_clusters, centers_dim); auto clusterSizesView = raft::make_device_vector_view(cluster_sizes, n_clusters); if (!reset_counters) { @@ -319,8 +427,27 @@ void calc_centers_and_sizes(const raft::resources& handle, temp_sizes = temp_cluster_sizes.data(); } + // Handle packed binary data with on-the-fly bit expansion + if (is_packed_binary) { + if constexpr (std::is_same_v) { + RAFT_EXPECTS(dim * 8 == centers_dim, "dim must be the packed dimension"); + auto decoded_dataset_iter = make_bitwise_expanded_iterator(dataset, dim); + raft::linalg::reduce_rows_by_key(decoded_dataset_iter, + centers_dim, + labels, + nullptr, + n_rows, + centers_dim, + n_clusters, + centers, + stream, + reset_counters); + } else { + RAFT_FAIL("Packed binary mode is only supported for uint8_t data type"); + } + } // Apply mapping only when the data and math types are different. - if constexpr (std::is_same_v) { + else if constexpr (std::is_same_v) { raft::linalg::reduce_rows_by_key( dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); } else { @@ -429,14 +556,22 @@ void predict(const raft::resources& handle, auto mem_res = mr.value_or(raft::resource::get_workspace_resource(handle)); auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); + IdxT transformed_dim = params.is_packed_binary ? dim * 8 : dim; rmm::device_uvector cur_dataset( - std::is_same_v ? 
0 : max_minibatch_size * dim, stream, mem_res); - bool need_compute_norm = - dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded); + std::is_same_v ? 0 : max_minibatch_size * transformed_dim, stream, mem_res); + bool need_compute_norm = dataset_norm == nullptr && + (params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded) && + !params.is_packed_binary; rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mem_res); + need_compute_norm || params.is_packed_binary ? max_minibatch_size : 0, stream, mem_res); + if (params.is_packed_binary) { + raft::matrix::fill( + handle, + raft::make_device_matrix_view(cur_dataset_norm.data(), max_minibatch_size, 1), + static_cast(transformed_dim)); + } const MathT* dataset_norm_ptr = nullptr; auto cur_dataset_ptr = cur_dataset.data(); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { @@ -444,6 +579,14 @@ void predict(const raft::resources& handle, if constexpr (std::is_same_v) { cur_dataset_ptr = const_cast(dataset + offset * dim); + } else if (params.is_packed_binary) { + if constexpr (std::is_same_v) { + raft::linalg::map_offset(handle, + raft::make_device_matrix_view( + cur_dataset_ptr, minibatch_size, transformed_dim), + cuvs::spatial::knn::detail::utils::bitwise_decode_op( + dataset + offset * dim, dim)); + } } else { raft::linalg::map( handle, @@ -481,7 +624,7 @@ void predict(const raft::resources& handle, params, centers, n_clusters, - dim, + transformed_dim, cur_dataset_ptr, dataset_norm_ptr, minibatch_size, @@ -491,7 +634,7 @@ void predict(const raft::resources& handle, } template bool + rmm::device_async_resource_ref device_memory, + bool is_packed_binary = false) -> 
bool { raft::common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); @@ -618,18 +762,38 @@ auto adjust_centers(MathT* centers, const dim3 block_dim(raft::WarpSize, kBlockDimY, 1); const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1, 1); rmm::device_scalar update_count(0, stream, device_memory); - adjust_centers_kernel<<>>(centers, - n_clusters, - dim, - dataset, - n_rows, - labels, - cluster_sizes, - threshold, - average, - ofst, - update_count.data(), - mapping_op); + if (is_packed_binary) { + if constexpr (std::is_same_v) { + IdxT transformed_dim = dim * 8; + auto dataset_iterator = make_bitwise_expanded_iterator(dataset, dim); + adjust_centers_kernel<<>>(centers, + n_clusters, + transformed_dim, + dataset_iterator, + n_rows, + labels, + cluster_sizes, + threshold, + average, + ofst, + update_count.data(), + mapping_op); + } + } else { + adjust_centers_kernel<<>>(centers, + n_clusters, + dim, + dataset, + n_rows, + labels, + cluster_sizes, + threshold, + average, + ofst, + update_count.data(), + mapping_op); + } + adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync return adjusted; @@ -696,20 +860,43 @@ void balancing_em_iters(const raft::resources& handle, { auto stream = raft::resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; + IdxT transformed_dim = params.is_packed_binary ? 
dim * 8 : dim; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - cluster_sizes, - balancing_threshold, - mapping_op, - stream, - device_memory)) { + bool did_adjust = false; + if (iter > 0) { + if (params.is_packed_binary) { + if constexpr (std::is_same_v) { + did_adjust = adjust_centers(cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + cluster_sizes, + balancing_threshold, + raft::identity_op{}, + stream, + device_memory, + true); + } + } else { + did_adjust = adjust_centers(cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + cluster_sizes, + balancing_threshold, + mapping_op, + stream, + device_memory, + false); + } + } + if (did_adjust) { if (balancing_counter++ >= balancing_pullback) { balancing_counter -= balancing_pullback; n_iters++; @@ -722,9 +909,9 @@ void balancing_em_iters(const raft::resources& handle, case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::CorrelationExpanded: { auto clusters_in_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); + cluster_centers, n_clusters, transformed_dim); auto clusters_out_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); + cluster_centers, n_clusters, transformed_dim); raft::linalg::row_normalize( handle, clusters_in_view, clusters_out_view); break; @@ -753,6 +940,7 @@ void balancing_em_iters(const raft::resources& handle, n_rows, cluster_labels, true, + params.is_packed_binary, mapping_op, device_memory); } @@ -797,6 +985,7 @@ void build_clusters(const raft::resources& handle, n_rows, cluster_labels, true, + params.is_packed_binary, mapping_op, device_memory); @@ -920,19 +1109,23 @@ auto build_fine_clusters(const raft::resources& handle, 
rmm::device_async_resource_ref managed_memory, rmm::device_async_resource_ref device_memory) -> IdxT { - auto stream = raft::resource::get_cuda_stream(handle); + auto stream = raft::resource::get_cuda_stream(handle); + IdxT transformed_dim = params.is_packed_binary ? dim * 8 : dim; rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); - rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); + // For packed binary: use uint8_t buffer. For non-packed: use MathT buffer + rmm::device_uvector mc_trainset_packed_buf( + params.is_packed_binary ? mesocluster_size_max * dim : 0, stream, device_memory); + rmm::device_uvector mc_trainset_buf( + !params.is_packed_binary ? mesocluster_size_max * dim : 0, stream, device_memory); rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); auto mc_trainset_ids = mc_trainset_ids_buf.data(); - auto mc_trainset = mc_trainset_buf.data(); auto mc_trainset_norm = mc_trainset_norm_buf.data(); // label (cluster ID) of each vector rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); rmm::device_uvector mc_trainset_ccenters( - fine_clusters_nums_max * dim, stream, device_memory); + fine_clusters_nums_max * transformed_dim, stream, device_memory); // number of vectors in each cluster rmm::device_uvector mc_trainset_csizes_tmp( fine_clusters_nums_max, stream, device_memory); @@ -960,36 +1153,75 @@ auto build_fine_clusters(const raft::resources& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - thrust::transform_iterator mapping_itr(dataset_mptr, mapping_op); - raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); + // Gather data based on mode + if (params.is_packed_binary) { + // Packed binary: gather raw uint8_t without transformation + if constexpr (std::is_same_v) { + raft::matrix::gather( + dataset_mptr, dim, n_rows, mc_trainset_ids, k, 
mc_trainset_packed_buf.data(), stream); + } else { + RAFT_FAIL("Packed binary mode requires uint8_t data type"); + } + } else { + thrust::transform_iterator mapping_itr(dataset_mptr, mapping_op); + raft::matrix::gather( + mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset_buf.data(), stream); + } + if (params.metric == cuvs::distance::DistanceType::L2Expanded || params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || params.metric == cuvs::distance::DistanceType::CosineExpanded) { - thrust::gather(raft::resource::get_thrust_policy(handle), - mc_trainset_ids, - mc_trainset_ids + k, - dataset_norm_mptr, - mc_trainset_norm); + if (params.is_packed_binary) { + // For packed binary, norm is constant = transformed_dim + raft::matrix::fill(handle, + raft::make_device_matrix_view(mc_trainset_norm, k, 1), + static_cast(transformed_dim)); + } else { + thrust::gather(raft::resource::get_thrust_policy(handle), + mc_trainset_ids, + mc_trainset_ids + k, + dataset_norm_mptr, + mc_trainset_norm); + } } - build_clusters(handle, - params, - dim, - mc_trainset, - k, - fine_clusters_nums[i], - mc_trainset_ccenters.data(), - mc_trainset_labels.data(), - mc_trainset_csizes_tmp.data(), - mapping_op, - device_memory, - mc_trainset_norm); - + if (params.is_packed_binary) { + // Packed binary: pass uint8_t*, build_clusters will expand on-the-fly + if constexpr (std::is_same_v) { + build_clusters(handle, + params, + dim, + mc_trainset_packed_buf.data(), + k, + fine_clusters_nums[i], + mc_trainset_ccenters.data(), + mc_trainset_labels.data(), + mc_trainset_csizes_tmp.data(), + raft::identity_op{}, // For packed binary, no additional mapping + device_memory, + mc_trainset_norm); + } else { + RAFT_FAIL("Packed binary mode requires uint8_t data type"); + } + } else { + build_clusters(handle, + params, + dim, + mc_trainset_buf.data(), + k, + fine_clusters_nums[i], + mc_trainset_ccenters.data(), + mc_trainset_labels.data(), + mc_trainset_csizes_tmp.data(), + mapping_op, // Passed 
but not used since T=MathT + device_memory, + mc_trainset_norm); + } raft::copy(handle, - raft::make_device_vector_view(cluster_centers + (dim * fine_clusters_csum[i]), - fine_clusters_nums[i] * dim), - raft::make_device_vector_view(mc_trainset_ccenters.data(), - fine_clusters_nums[i] * dim)); + raft::make_device_vector_view(cluster_centers + (transformed_dim * fine_clusters_csum[i]), + fine_clusters_nums[i] * transformed_dim), + raft::make_device_vector_view(mc_trainset_ccenters.data(), + fine_clusters_nums[i] * transformed_dim)); // center rows are transformed_dim wide (dim * 8 when packed binary, dim otherwise) raft::resource::sync_stream(handle, stream); n_clusters_done += fine_clusters_nums[i]; } @@ -1027,8 +1259,9 @@ void build_hierarchical(const raft::resources& handle, MappingOpT mapping_op, MathT* inertia = nullptr) { - auto stream = raft::resource::get_cuda_stream(handle); - using LabelT = uint32_t; + auto stream = raft::resource::get_cuda_stream(handle); + using LabelT = uint32_t; + IdxT transformed_dim = params.is_packed_binary ? dim * 8 : dim; raft::common::nvtx::range fun_scope( "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); @@ -1047,14 +1280,15 @@ void build_hierarchical(const raft::resources& handle, const MathT* dataset_norm = nullptr; if ((params.metric == cuvs::distance::DistanceType::L2Expanded || params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded)) { + params.metric == cuvs::distance::DistanceType::CosineExpanded) && + !params.is_packed_binary) { dataset_norm_buf.resize(n_rows, stream); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); if (params.metric == cuvs::distance::DistanceType::CosineExpanded) compute_norm(handle, dataset_norm_buf.data() + offset, - dataset + dim * offset, + dataset + offset * dim, dim, minibatch_size, mapping_op, @@ -1063,14 +1297,21 @@ void build_hierarchical(const raft::resources& handle, else compute_norm(handle, dataset_norm_buf.data() + offset, - 
dataset + dim * offset, + dataset + offset * dim, dim, minibatch_size, mapping_op, raft::identity_op{}, device_memory); } - dataset_norm = (const MathT*)dataset_norm_buf.data(); + dataset_norm = dataset_norm_buf.data(); + } else if (params.is_packed_binary) { + dataset_norm_buf.resize(n_rows, stream); + raft::matrix::fill( + handle, + raft::make_device_matrix_view(dataset_norm_buf.data(), n_rows, 1), + static_cast(transformed_dim)); + dataset_norm = (const MathT*)dataset_norm_buf.data(); } /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively @@ -1082,7 +1323,8 @@ void build_hierarchical(const raft::resources& handle, rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); { - rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); + rmm::device_uvector mesocluster_centers_buf( + n_mesoclusters * transformed_dim, stream, device_memory); build_clusters(handle, params, dim, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 0c0df03397..17df855deb 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -72,7 +72,8 @@ void fit(const raft::resources& handle, MappingOpT mapping_op = raft::identity_op(), std::optional> inertia = std::nullopt) { - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + RAFT_EXPECTS(X.extent(1) == centroids.extent(1) || + (params.is_packed_binary && X.extent(1) * 8 == centroids.extent(1)), "Number of features in dataset and centroids are different"); RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= static_cast(std::numeric_limits::max()), @@ -285,14 +286,16 @@ void calc_centers_and_sizes(const raft::resources& handle, raft::device_matrix_view centroids, raft::device_vector_view cluster_sizes, bool reset_counters = true, + bool is_packed_binary = false, MappingOpT mapping_op 
= raft::identity_op()) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); + RAFT_EXPECTS( + is_packed_binary ? X.extent(1) * 8 == centroids.extent(1) : X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); + "Number of rows in centroids and cluster_sizes are different"); cuvs::cluster::kmeans::detail::calc_centers_and_sizes( handle, @@ -304,6 +307,7 @@ void calc_centers_and_sizes(const raft::resources& handle, X.extent(0), labels.data_handle(), reset_counters, + is_packed_binary, mapping_op, raft::resource::get_workspace_resource(handle)); } diff --git a/cpp/src/distance/detail/distance_ops/all_ops.cuh b/cpp/src/distance/detail/distance_ops/all_ops.cuh index f0a3984eb6..93573fffff 100644 --- a/cpp/src/distance/detail/distance_ops/all_ops.cuh +++ b/cpp/src/distance/detail/distance_ops/all_ops.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,7 @@ #include "cutlass.cuh" // The distance operations: +#include "../distance_ops/bitwise_hamming.cuh" #include "../distance_ops/canberra.cuh" #include "../distance_ops/correlation.cuh" #include "../distance_ops/cosine.cuh" diff --git a/cpp/src/distance/detail/distance_ops/bitwise_hamming.cuh b/cpp/src/distance/detail/distance_ops/bitwise_hamming.cuh new file mode 100644 index 0000000000..1c543e2379 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/bitwise_hamming.cuh @@ -0,0 +1,60 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::distance::detail::ops { + +/** + * @brief the Bitwise Hamming distance matrix calculation + * It computes the following equation: + * + * c_ij = sum_k popcount(x_ik XOR y_kj) + * + * where x and y are binary data packed as uint8_t + */ +template +struct bitwise_hamming_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + IdxT k; + + bitwise_hamming_distance_op(IdxT k_) noexcept : k(k_) {} + + static constexpr bool use_norms = false; + static constexpr bool expensive_inner_loop = false; + + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + __device__ __forceinline__ void core(AccT& acc, DataT& x, DataT& y) const + { + static_assert(std::is_same_v, "BitwiseHamming only supports uint8_t"); + // Ensure proper masking and casting to avoid undefined behavior + uint32_t xor_val = static_cast(static_cast(x ^ y)); + uint32_t masked_val = xor_val & 0xffu; + int popcount = __popc(masked_val); + acc += static_cast(popcount); + } + + template + __device__ __forceinline__ void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + AccT* regxn, + AccT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index f9dbd968ec..1838f1d29d 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -7,6 +7,7 @@ #include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op #include "fused_distance_nn/cutlass_base.cuh" +#include "fused_distance_nn/fused_bitwise_hamming_nn.cuh" #include "fused_distance_nn/fused_cosine_nn.cuh" #include "fused_distance_nn/fused_l2_nn.cuh" #include "fused_distance_nn/helper_structs.cuh" @@ -68,16 +69,31 @@ void fusedDistanceNNImpl(OutT* min, switch (metric) { case cuvs::distance::DistanceType::CosineExpanded: - fusedCosineNN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + if constexpr (std::is_same_v || std::is_same_v) { + RAFT_FAIL("Cosine distance is not supported for uint8_t/int8_t data types"); + } else { + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + } break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: - // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. 
- fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + if constexpr (std::is_same_v || std::is_same_v) { + RAFT_FAIL("L2 distance is not supported for uint8_t/int8_t data types"); + } else { + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + } break; - default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; + case cuvs::distance::DistanceType::BitwiseHamming: + if constexpr (std::is_same_v) { + fusedBitwiseHammingNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + } else { + RAFT_FAIL("BitwiseHamming distance only supports uint8_t data type"); + } + break; + default: RAFT_FAIL("only cosine/l2/bitwise hamming metric is supported with fusedDistanceNN"); } } diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_bitwise_hamming_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_bitwise_hamming_nn.cuh new file mode 100644 index 0000000000..50796b5737 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/fused_bitwise_hamming_nn.cuh @@ -0,0 +1,82 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../distance_ops/bitwise_hamming.cuh" // ops::bitwise_hamming_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include "helper_structs.cuh" +#include "simt_kernel.cuh" + +namespace cuvs { +namespace distance { +namespace detail { + +/** + * @brief Fused BitwiseHamming distance and 1-nearest-neighbor + * + * This implementation is only meaningful for uint8_t data type. + * The if constexpr in fusedDistanceNNImpl ensures it's only called for uint8_t. 
+ */ +template +void fusedBitwiseHammingNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + using kv_pair_type = raft::KeyValuePair; + using distance_op_type = ops::bitwise_hamming_distance_op; + distance_op_type distance_op{k}; + auto kernel = fusedDistanceNNkernel; + + constexpr size_t shmemSize = P::SmemSize; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>(min, + x, + y, + nullptr, + nullptr, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + distance_op, + raft::identity_op{}); + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh index 762c720568..bdac833d7b 100644 --- a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -27,8 +27,15 @@ namespace detail { template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + // Use index as tiebreaker for consistent behavior when distances are equal + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) + { + return (b.value < a.value || (b.value == a.value && b.key < a.key)) ? b : a; + } + DI KVP operator()(const KVP& a, const KVP& b) + { + return (b.value < a.value || (b.value == a.value && b.key < a.key)) ? 
b : a; + } }; // KVPMinReduce @@ -38,14 +45,16 @@ struct MinAndDistanceReduceOpImpl { DI void operator()(LabelT rid, KVP* out, const KVP& other) const { - if (other.value < out->value) { + // Use index as tiebreaker for consistent behavior when distances are equal + if (other.value < out->value || (other.value == out->value && other.key < out->key)) { out->key = other.key; out->value = other.value; } } DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const { - if (other.value < out->value) { + // Use index as tiebreaker for consistent behavior when distances are equal + if (other.value < out->value || (other.value == out->value && other.key < out->key)) { out->key = other.key; out->value = other.value; } @@ -123,7 +132,11 @@ struct kvp_cg_min_reduce_op { using AccTypeT = AccType; using IndexT = Index; // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + // Use index as tiebreaker for consistent behavior when distances are equal + __host__ __device__ KVP operator()(KVP a, KVP b) const + { + return (a.value < b.value || (a.value == b.value && a.key < b.key)) ? a : b; + } __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } diff --git a/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh index 4211ff653d..b378829a49 100644 --- a/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -1,14 +1,15 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op -#include "../pairwise_distance_base.cuh" // PairwiseDistances -#include // raft::KeyValuePair -#include // Policy +#include "../distance_ops/bitwise_hamming.cuh" // ops::bitwise_hamming_distance_op +#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include // raft::KeyValuePair +#include // Policy #include // size_t #include // std::numeric_limits @@ -71,104 +72,115 @@ __launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, OpT distance_op, FinalLambda fin_op) { -// compile only if below non-ampere arch. -#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; + // For hamming-like distances, we need this kernel on all architectures + // For other distances, only use for pre-ampere architectures + +#if __CUDA_ARCH__ >= 800 + static constexpr bool compile = + std::is_same_v>; +#else + static constexpr bool compile = true; +#endif - typedef raft::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } + if constexpr (compile) { + extern __shared__ char smem[]; - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; + using AccT = std::conditional_t, uint32_t, DataT>; + typedef raft::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; 
- KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } + val[i] = {0, maxVal}; } - }; - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + AccT acc[P::AccRowsPerTh][P::AccColsPerTh], + AccT * regxn, + AccT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } } } + }; - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); - // reset the val array. 
+ const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + using AccT = std::conditional_t, uint32_t, DataT>; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + reinterpret_cast(xn), + reinterpret_cast(yn), + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); + } } } // namespace detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh index c93a2f3f2b..6ec63a8ad4 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. 
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -67,6 +67,11 @@ void pairwise_matrix_dispatch(OpT distance_op, instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ OpT, half, float, float, FinOpT, IdxT); +#define instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_bitwise_hamming(OpT, \ + IdxT) \ + instantiate_cuvs_distance_detail_pairwise_matrix_dispatch( \ + OpT, uint8_t, uint32_t, uint32_t, raft::identity_op, IdxT); + /* * Hierarchy of instantiations: * @@ -112,5 +117,8 @@ instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo( instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_default( cuvs::distance::detail::ops::l2_exp_distance_op, int64_t); +instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo_bitwise_hamming( + cuvs::distance::detail::ops::bitwise_hamming_distance_op, int64_t); + #undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch_by_algo #undef instantiate_cuvs_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py index 0cfa0c2c2a..5ee4e128b3 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -221,3 +221,36 @@ def arch_headers(archs): "\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n" ) print(f"src/distance/detail/pairwise_matrix/{path}") + +# Bitwise Hamming with uint8_t/uint32_t types +bitwise_hamming_instances = [ + dict( + DataT="uint8_t", + AccT="uint32_t", + OutT="uint32_t", + IdxT="int64_t", + ), +] + +for dt in bitwise_hamming_instances: + DataT, AccT, OutT, IdxT = ( + dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"] + ) + path = f"dispatch_bitwise_hamming_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write( + 
'#include "../distance_ops/bitwise_hamming.cuh" // bitwise_hamming_distance_op\n' + ) + f.write(arch_headers([60])) # SM60 architecture + f.write(macro) + + OpT = "cuvs::distance::detail::ops::bitwise_hamming_distance_op" + FinOpT = "raft::identity_op" + f.write( + f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n" + ) + f.write( + "\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n" + ) + print(f"src/distance/detail/pairwise_matrix/{path}") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_bitwise_hamming_uint8_t_uint32_t_uint32_t_int64_t.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_bitwise_hamming_uint8_t_uint32_t_uint32_t_int64_t.cu new file mode 100644 index 0000000000..a6c5d21ae8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_bitwise_hamming_uint8_t_uint32_t_uint32_t_int64_t.cu @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "../distance_ops/bitwise_hamming.cuh" // bitwise_hamming_distance_op +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const OutT* x_norm, \ + const OutT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::bitwise_hamming_distance_op, + uint8_t, + uint32_t, + uint32_t, + raft::identity_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/fused_distance_nn-inl.cuh b/cpp/src/distance/fused_distance_nn-inl.cuh index 3fa80a9b60..e51ce04d77 100644 --- a/cpp/src/distance/fused_distance_nn-inl.cuh +++ b/cpp/src/distance/fused_distance_nn-inl.cuh @@ -98,99 +98,99 @@ void fusedDistanceNN(OutT* min, auto py = reinterpret_cast(y); if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + constexpr int max_veclen = std::min(4, 16 / sizeof(DataT)); + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, 
+ (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); } else { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + constexpr int max_veclen = std::min(4, 16 / sizeof(DataT)); + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); } } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + constexpr int max_veclen = std::min(4, 8 / sizeof(DataT)); + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); } else { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + constexpr int max_veclen = std::min(4, 8 / sizeof(DataT)); + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); } } else { if (is_skinny) { @@ -289,8 +289,9 @@ void fusedDistanceNNMinReduce(OutT* min, float metric_arg, cudaStream_t stream) { - 
MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; + using AccT = std::conditional_t, uint32_t, DataT>; + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; fusedDistanceNN(min, x, diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 82bd6e755a..c40b107878 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -199,6 +199,26 @@ HDI constexpr auto mapping::operator()(const float& x) const -> int8_t return static_cast(std::clamp(x * 128.0f, -128.0f, 127.0f)); } +template +struct bitwise_decode_op { + bitwise_decode_op(const uint8_t* const binary_vecs, IdxT compressed_dim) + : binary_vecs(binary_vecs), compressed_dim(compressed_dim) + { + uncompressed_dim = compressed_dim << 3; + } + const uint8_t* binary_vecs; + IdxT compressed_dim; + IdxT uncompressed_dim; + HDI constexpr auto operator()(const IdxT& i) -> OutT + { + IdxT row_id = i / uncompressed_dim; + IdxT col_id = i % uncompressed_dim; + return static_cast( + -1 + 2 * static_cast( + (binary_vecs[row_id * compressed_dim + (col_id >> 3)] >> (col_id & 7)) & 1)); + }; +}; + /** * @brief Sets the first num bytes of the block of memory pointed by ptr to the specified value. 
* diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 06862c083d..55a7ceea71 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -12,6 +12,9 @@ #include #include #include +#include +#include +#include #include "../../cluster/kmeans_balanced.cuh" #include "../detail/ann_utils.cuh" @@ -58,7 +61,11 @@ auto clone(const raft::resources& res, const index& source) -> index(n_rows)); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); - auto orig_centroids_view = - raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); // Calculate the batch size for the input data if it's not accessible directly from the device constexpr size_t kReasonableMaxBatchSize = 65536; size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); @@ -199,6 +204,7 @@ void extend(raft::resources const& handle, copy_stream = raft::resource::get_stream_from_stream_pool(handle); } } + // Predict the cluster labels for the new data, in batches if necessary utils::batch_load_iterator vec_batches(new_vectors, n_rows, @@ -214,8 +220,27 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - cuvs::cluster::kmeans::predict( - handle, kmeans_params, batch_data_view, orig_centroids_view, batch_labels_view); + auto centroids_view = raft::make_device_matrix_view( + index->binary_centers().data_handle(), n_lists, dim); + + if (index->binary_index()) { + if constexpr (std::is_same_v) { + cuvs::cluster::kmeans::detail::predict_bitwise_hamming( + handle, batch_data_view, centroids_view, batch_labels_view); + } else { + RAFT_FAIL("BitwiseHamming distance is only supported with uint8_t data type, got %s", + typeid(T).name()); + } + } else { + auto 
orig_centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), n_lists, dim); + cuvs::cluster::kmeans_balanced::predict(handle, + kmeans_params, + batch_data_view, + orig_centroids_view, + batch_labels_view, + utils::mapping{}); + } vec_batches.prefetch_next_batch(); // User needs to make sure kernel finishes its work before we overwrite batch in the next // iteration if different streams are used for kernel and copy. @@ -231,23 +256,67 @@ void extend(raft::resources const& handle, // Calculate the centers and sizes on the new data, starting from the original values if (index->adaptive_centers()) { - auto centroids_view = raft::make_device_matrix_view( - index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_labels.data_handle() + batch.offset(), batch.size()); - cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, - batch_data_view, - batch_labels_view, - centroids_view, - list_sizes_view, - false, - utils::mapping{}); + + if (index->binary_index()) { + if constexpr (std::is_same_v) { + // For binary data, we need to work in the expanded space and then convert back + rmm::device_uvector temp_expanded_centers( + n_lists * dim * 8, stream, raft::resource::get_workspace_resource(handle)); + auto expanded_centers_view = raft::make_device_matrix_view( + temp_expanded_centers.data(), n_lists, dim * 8); + + raft::linalg::map_offset( + handle, + expanded_centers_view, + utils::bitwise_decode_op(index->binary_centers().data_handle(), dim)); + + vec_batches.reset(); + for (const auto& batch : vec_batches) { + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), 
batch.size()); + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, + batch_data_view, + batch_labels_view, + expanded_centers_view, + list_sizes_view, + false, + true, + raft::identity_op{}); + } + + // Convert updated centroids back to binary format + cuvs::preprocessing::quantize::binary::quantizer temp_quantizer(handle); + cuvs::preprocessing::quantize::binary::transform( + handle, temp_quantizer, expanded_centers_view, index->binary_centers()); + + } else { + // Error: BitwiseHamming with non-uint8_t type + RAFT_FAIL("BitwiseHamming distance is only supported with uint8_t data type, got %s", + typeid(T).name()); + } + } else { + auto centroids_view = raft::make_device_matrix_view( + index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); + vec_batches.reset(); + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), batch.size()); + cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, + batch_data_view, + batch_labels_view, + centroids_view, + list_sizes_view, + false, + false, + utils::mapping{}); + } } } else { raft::stats::histogram(raft::stats::HistTypeAuto, @@ -389,14 +458,22 @@ inline auto build(raft::resources const& handle, auto stream = raft::resource::get_cuda_stream(handle); cuvs::common::nvtx::range fun_scope( "ivf_flat::build(%zu, %u)", size_t(n_rows), dim); + + if (params.metric == cuvs::distance::DistanceType::BitwiseHamming && + !std::is_same_v) { + RAFT_FAIL("BitwiseHamming distance is only supported with uint8_t input type, got %s", + typeid(T).name()); + } static_assert(std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, "unsupported data type"); + RAFT_EXPECTS(n_rows > 
0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, "Cosine metric requires more than one dim"); index index(handle, params, dim); + utils::memzero( index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); @@ -408,6 +485,7 @@ inline auto build(raft::resources const& handle, auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; + rmm::device_uvector trainset( n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle)); // TODO: a proper sampling @@ -421,12 +499,39 @@ inline auto build(raft::resources const& handle, stream)); auto trainset_const_view = raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = raft::make_device_matrix_view( - index.centers().data_handle(), index.n_lists(), index.dim()); + cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = index.metric(); - cuvs::cluster::kmeans::fit(handle, kmeans_params, trainset_const_view, centers_view); + kmeans_params.metric = + index.binary_index() ? 
cuvs::distance::DistanceType::L2Expanded : index.metric(); + kmeans_params.is_packed_binary = index.binary_index(); + if constexpr (std::is_same_v) { + if (index.binary_index()) { + rmm::device_uvector decoded_centers(index.n_lists() * index.dim() * 8, + stream, + raft::resource::get_workspace_resource(handle)); + auto decoded_centers_view = raft::make_device_matrix_view( + decoded_centers.data(), index.n_lists(), index.dim() * 8); + + cuvs::cluster::kmeans_balanced::fit( + handle, kmeans_params, trainset_const_view, decoded_centers_view, raft::identity_op{}); + + // Convert decoded centers back to binary format + cuvs::preprocessing::quantize::binary::quantizer temp_quantizer(handle); + cuvs::preprocessing::quantize::binary::transform( + handle, temp_quantizer, decoded_centers_view, index.binary_centers()); + } else { + auto centers_view = raft::make_device_matrix_view( + index.centers().data_handle(), index.n_lists(), index.dim()); + cuvs::cluster::kmeans_balanced::fit( + handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); + } + } else { + auto centers_view = raft::make_device_matrix_view( + index.centers().data_handle(), index.n_lists(), index.dim()); + cuvs::cluster::kmeans_balanced::fit( + handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); + } } // add the data if necessary diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index 4c0bb3644a..0553619de2 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -23,6 +23,9 @@ #include +#include +#include + namespace cuvs::neighbors::ivf_flat::detail { using namespace cuvs::spatial::knn::detail; // NOLINT @@ -1130,6 +1133,20 @@ struct inner_prod_dist { } }; +template +struct hamming_dist { + static_assert(std::is_same_v, "hamming_dist only supports uint8_t data type"); + __device__ __forceinline__ void 
operator()(AccT& acc, AccT x, AccT y) + { + if constexpr (Veclen > 1) { + // x and y are uint32_t, so no static_cast is needed. + acc += __popc(x ^ y); + } else { + acc += __popc(static_cast(x ^ y) & 0xffu); + } + } +}; + /** Select the distance computation function and forward the rest of the arguments. */ template {1.0f}, raft::mul_const_op{-1.0f}), std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when // adding here a new metric. + case cuvs::distance::DistanceType::BitwiseHamming: + if constexpr (std::is_same_v) { + return launch_kernel>( + {}, raft::identity_op{}, std::forward(args)...); + } else { + RAFT_FAIL("BitwiseHamming distance only supports uint8_t data type"); + } + break; default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } } @@ -1309,6 +1342,11 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { + if (metric == cuvs::distance::DistanceType::BitwiseHamming && !std::is_same_v) { + RAFT_FAIL("BitwiseHamming distance is only supported with uint8_t data type, got %s", + typeid(T).name()); + } + const int capacity = raft::bound_by_power_of_two(k); auto filter_adapter = cuvs::neighbors::filtering::ivf_to_sample_filter( diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 3379e7b8dc..98cce118a8 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -26,6 +26,9 @@ #include +#include "../../distance/detail/distance_ops/bitwise_hamming.cuh" +#include "../../distance/detail/pairwise_matrix/dispatch.cuh" + namespace cuvs::neighbors::ivf_flat::detail { using namespace cuvs::spatial::knn::detail; // NOLINT @@ -90,87 +93,117 @@ void search_impl(raft::resources const& handle, if constexpr (std::is_same_v) { converted_queries_ptr = const_cast(queries); } else { - raft::linalg::map( - handle, - 
raft::make_device_vector_view(converted_queries_ptr, n_queries * index.dim()), - utils::mapping{}, - raft::make_const_mdspan( - raft::make_device_vector_view(queries, n_queries * index.dim()))); + raft::linalg::unaryOp( + converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); } - float alpha = 1.0f; - float beta = 0.0f; + if (index.metric() == cuvs::distance::DistanceType::BitwiseHamming) { + if constexpr (std::is_same_v) { + cuvs::distance::detail::ops::bitwise_hamming_distance_op distance_op{ + static_cast(index.dim())}; + + rmm::device_uvector uint32_distances( + n_queries * index.n_lists(), stream, search_mr); + + cuvs::distance::detail::pairwise_matrix_dispatch(distance_op, + static_cast(n_queries), + static_cast(index.n_lists()), + static_cast(index.dim()), + queries, + index.binary_centers().data_handle(), + nullptr, + nullptr, + uint32_distances.data(), + raft::identity_op{}, + stream, + true); + + // Convert uint32_t distances to float for compatibility with rest of pipeline + raft::linalg::unaryOp(distance_buffer_dev.data(), + uint32_distances.data(), + n_queries * index.n_lists(), + raft::cast_op{}, + stream); + } + } else { + float alpha = 1.0f; + float beta = 0.0f; + + // todo(lsugy): raft distance? 
(if performance is similar/better than gemm) + switch (index.metric()) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + stream); + utils::outer_add(query_norm_dev.data(), + (IdxT)n_queries, + index.center_norms()->data_handle(), + (IdxT)index.n_lists(), + distance_buffer_dev.data(), + stream); + RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), + std::min(20, index.dim())); + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + break; + } + case cuvs::distance::DistanceType::CosineExpanded: { + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + stream, + raft::sqrt_op{}); + alpha = -1.0f; + beta = 0.0f; + break; + } + default: { + alpha = 1.0f; + beta = 0.0f; + } + } - // todo(lsugy): raft distance? 
(if performance is similar/better than gemm) - switch (index.metric()) { - case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - converted_queries_ptr, static_cast(n_queries), static_cast(index.dim())), - raft::make_device_vector_view(query_norm_dev.data(), - static_cast(n_queries))); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), + raft::linalg::gemm(handle, + true, + false, + index.n_lists(), + n_queries, + index.dim(), + &alpha, + index.centers().data_handle(), + index.dim(), + converted_queries_ptr, + index.dim(), + &beta, distance_buffer_dev.data(), + index.n_lists(), stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - break; - } - case cuvs::distance::DistanceType::CosineExpanded: { - raft::linalg::norm( + + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + auto n_lists = index.n_lists(); + const auto* q_norm_ptr = query_norm_dev.data(); + const auto* index_center_norm_ptr = index.center_norms()->data_handle(); + raft::linalg::map_offset( handle, - raft::make_device_matrix_view( - converted_queries_ptr, static_cast(n_queries), static_cast(index.dim())), - raft::make_device_vector_view(query_norm_dev.data(), - static_cast(n_queries)), - raft::sqrt_op{}); - alpha = -1.0f; - beta = 0.0f; - break; - } - default: { - alpha = 1.0f; - beta = 0.0f; + distance_buffer_dev_view, + [=] __device__(const uint32_t idx, const float dist) { + const auto query = idx / n_lists; + const auto cluster = idx % n_lists; + return dist / (q_norm_ptr[query] * index_center_norm_ptr[cluster]); + }, + raft::make_const_mdspan(distance_buffer_dev_view)); } } - - raft::linalg::gemm(handle, - true, - false, - index.n_lists(), 
- n_queries, - index.dim(), - &alpha, - index.centers().data_handle(), - index.dim(), - converted_queries_ptr, - index.dim(), - &beta, - distance_buffer_dev.data(), - index.n_lists(), - stream); - - if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { - auto n_lists = index.n_lists(); - const auto* q_norm_ptr = query_norm_dev.data(); - const auto* index_center_norm_ptr = index.center_norms()->data_handle(); - raft::linalg::map_offset( - handle, - distance_buffer_dev_view, - [=] __device__(const uint32_t idx, const float dist) { - const auto query = idx / n_lists; - const auto cluster = idx % n_lists; - return dist / (q_norm_ptr[query] * index_center_norm_ptr[cluster]); - }, - raft::make_const_mdspan(distance_buffer_dev_view)); - } RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); cuvs::selection::select_k( diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp index 77b24d4690..f8d7b17209 100644 --- a/cpp/src/neighbors/ivf_flat_index.cpp +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -1,9 +1,13 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include +#include #include +#include +#include namespace cuvs::neighbors::ivf_flat { @@ -38,12 +42,23 @@ index::index(raft::resources const& res, conservative_memory_allocation_{conservative_memory_allocation}, lists_{n_lists}, list_sizes_{raft::make_device_vector(res, n_lists)}, - centers_(raft::make_device_matrix(res, n_lists, dim)), + centers_(metric != cuvs::distance::DistanceType::BitwiseHamming + ? raft::make_device_matrix(res, n_lists, dim) + : raft::make_device_matrix(res, 0, 0)), + binary_centers_(metric != cuvs::distance::DistanceType::BitwiseHamming + ? 
raft::make_device_matrix(res, 0, 0) + : raft::make_device_matrix(res, n_lists, dim)), center_norms_(std::nullopt), data_ptrs_{raft::make_device_vector(res, n_lists)}, inds_ptrs_{raft::make_device_vector(res, n_lists)}, - accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)} + accum_sorted_sizes_{raft::make_host_vector(n_lists + 1)}, + binary_index_(metric == cuvs::distance::DistanceType::BitwiseHamming) { + if (metric == cuvs::distance::DistanceType::BitwiseHamming && !std::is_same_v) { + RAFT_FAIL("BitwiseHamming distance is only supported with uint8_t data type, got %s", + typeid(T).name()); + } + check_consistency(); accum_sorted_sizes_(n_lists) = 0; } @@ -91,6 +106,19 @@ raft::device_matrix_view index: return centers_.view(); } +template +raft::device_matrix_view +index::binary_centers() noexcept +{ + return binary_centers_.view(); +} + +template +raft::device_matrix_view index::binary_centers() + const noexcept +{ + return binary_centers_.view(); +} template std::optional> index::center_norms() noexcept { @@ -135,7 +163,11 @@ IdxT index::size() const noexcept template uint32_t index::dim() const noexcept { - return centers_.extent(1); + if (binary_index_) { + return binary_centers_.extent(1); + } else { + return centers_.extent(1); + } } template @@ -209,10 +241,21 @@ void index::check_consistency() RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // - (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), - "inconsistent number of lists (clusters)"); + if (binary_index_) { + RAFT_EXPECTS(binary_centers_.extent(0) == list_sizes_.extent(0), + "inconsistent number of lists (clusters)"); + } else { + RAFT_EXPECTS( // + (centers_.extent(0) == list_sizes_.extent(0)) && // + (!center_norms_.has_value() 
|| centers_.extent(0) == center_norms_->extent(0)), + "inconsistent number of lists (clusters)"); + } +} + +template +bool index::binary_index() const noexcept +{ + return binary_index_; } template struct index; // Used for refine function diff --git a/cpp/tests/neighbors/ann_ivf_flat.cuh b/cpp/tests/neighbors/ann_ivf_flat.cuh index ffc1d03bf6..97dee0fa24 100644 --- a/cpp/tests/neighbors/ann_ivf_flat.cuh +++ b/cpp/tests/neighbors/ann_ivf_flat.cuh @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include #include @@ -66,6 +68,11 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { void testIVFFlat() { + if ((ps.metric == cuvs::distance::DistanceType::BitwiseHamming) && + !(std::is_same_v)) { + GTEST_SKIP(); + } + size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivfflat(queries_size); std::vector indices_naive(queries_size); @@ -205,41 +212,53 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { // Test the centroid invariants if (index_2.adaptive_centers()) { - // The centers must be up-to-date with the corresponding data - std::vector list_sizes(index_2.n_lists()); - std::vector list_indices(index_2.n_lists()); - rmm::device_uvector centroid(ps.dim, stream_); - raft::copy( - list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); - raft::copy( - list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_); - raft::resource::sync_stream(handle_); - for (uint32_t l = 0; l < index_2.n_lists(); l++) { - if (list_sizes[l] == 0) continue; - rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); - cuvs::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], - (IdxT)ps.dim, - database.data(), - list_indices[l], - (IdxT)ps.dim, - cluster_data.data(), - (IdxT)ps.dim, - stream_); - raft::stats::mean( - centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, stream_); - ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle() + ps.dim * l, - 
centroid.data(), - ps.dim, - cuvs::CompareApprox(0.001), - stream_)); + // Skip centroid verification for BitwiseHamming metric + if (ps.metric != cuvs::distance::DistanceType::BitwiseHamming) { + // The centers must be up-to-date with the corresponding data + std::vector list_sizes(index_2.n_lists()); + std::vector list_indices(index_2.n_lists()); + rmm::device_uvector centroid(ps.dim, stream_); + raft::copy( + list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); + raft::copy( + list_indices.data(), index_2.inds_ptrs().data_handle(), index_2.n_lists(), stream_); + raft::resource::sync_stream(handle_); + for (uint32_t l = 0; l < index_2.n_lists(); l++) { + if (list_sizes[l] == 0) continue; + rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); + cuvs::spatial::knn::detail::utils::copy_selected((IdxT)list_sizes[l], + (IdxT)ps.dim, + database.data(), + list_indices[l], + (IdxT)ps.dim, + cluster_data.data(), + (IdxT)ps.dim, + stream_); + raft::stats::mean( + centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, stream_); + ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + centroid.data(), + ps.dim, + cuvs::CompareApprox(0.001), + stream_)); + } } } else { // The centers must be immutable - ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle(), - idx.centers().data_handle(), - index_2.centers().size(), - cuvs::Compare(), - stream_)); + if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming) { + // For BitwiseHamming, compare binary centers + ASSERT_TRUE(cuvs::devArrMatch(index_2.binary_centers().data_handle(), + idx.binary_centers().data_handle(), + index_2.binary_centers().size(), + cuvs::Compare(), + stream_)); + } else { + ASSERT_TRUE(cuvs::devArrMatch(index_2.centers().data_handle(), + idx.centers().data_handle(), + index_2.centers().size(), + cuvs::Compare(), + stream_)); + } } } float eps = std::is_same_v ? 
0.005 : 0.001; @@ -256,6 +275,11 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { void testPacker() { + if ((ps.metric == cuvs::distance::DistanceType::BitwiseHamming) && + !(std::is_same_v)) { + GTEST_SKIP(); + } + ivf_flat::index_params index_params; ivf_flat::search_params search_params; index_params.n_lists = ps.nlist; @@ -388,6 +412,11 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { void testFilter() { + if ((ps.metric == cuvs::distance::DistanceType::BitwiseHamming) && + !(std::is_same_v)) { + GTEST_SKIP(); + } + size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivfflat(queries_size); std::vector indices_naive(queries_size); @@ -495,6 +524,12 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); raft::random::uniform( handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming && + std::is_same_v) { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0), DataT(255)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0), DataT(255)); } else { raft::random::uniformInt( handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); @@ -524,14 +559,19 @@ const std::vector> inputs = { {1000, 10000, 1, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, 
false}, {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, @@ -558,50 +598,70 @@ const std::vector> inputs = { // various random combinations {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, {100, 10000, 16, 10, 20, 512, 
cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::BitwiseHamming, false}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false}, // host input data {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {20, 100000, 16, 10, 20, 1024, 
cuvs::distance::DistanceType::L2Expanded, false, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true}, // // host input data with prefetching for kernel copy overlapping {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {100, 10000, 16, 10, 20, 512, 
cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::BitwiseHamming, false, true, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, @@ -624,10 +684,13 @@ const std::vector> inputs = { // test splitting the big query batches (> max gridDim.y) into smaller batches {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::BitwiseHamming, false}, {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::InnerProduct, false}, {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::BitwiseHamming, false}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, true}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, true}, + {98306, 1024, 32, 10, 64, 64, 
cuvs::distance::DistanceType::BitwiseHamming, false}, // test radix_sort for getting the cluster selection {1000, @@ -654,10 +717,24 @@ const std::vector> inputs = { raft::matrix::detail::select::warpsort::kMaxCapacity * 4, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, + 10000, + 16, + 10, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + cuvs::distance::DistanceType::BitwiseHamming, + false}, // The following two test cases should show very similar recall. // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers {20000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, - {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}}; + {100000, 8712, 3, 10, 51, 66, cuvs::distance::DistanceType::L2Expanded, false}, + + // BitwiseHamming with adaptive centers + {1000, 10000, 32, 16, 20, 80, cuvs::distance::DistanceType::BitwiseHamming, true}, + {1000, 10000, 64, 16, 20, 80, cuvs::distance::DistanceType::BitwiseHamming, true}, + {1000, 10000, 128, 16, 20, 80, cuvs::distance::DistanceType::BitwiseHamming, true}, +}; } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/tests/neighbors/ann_utils.cuh b/cpp/tests/neighbors/ann_utils.cuh index 8a908c0187..9cb5031b11 100644 --- a/cpp/tests/neighbors/ann_utils.cuh +++ b/cpp/tests/neighbors/ann_utils.cuh @@ -106,6 +106,7 @@ inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream& break; case cuvs::distance::DistanceType::DiceExpanded: os << "distance::DiceExpanded"; break; case cuvs::distance::DistanceType::Precomputed: os << "distance::Precomputed"; break; + case cuvs::distance::DistanceType::BitwiseHamming: os << "distance::BitwiseHamming"; break; default: RAFT_FAIL("unreachable code"); } return os;