diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index a2556dc651..e09af1d7b3 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -485,30 +485,33 @@ void train_per_cluster(raft::resources const& handle, } /** - * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes - * one-per-byte. That is, independent of the code width (pq_bits), one code uses - * the whole byte, hence one vectors uses pq_dim bytes. + * Unpack codes from the packed list data for aligned cases (pq_bits = 4, 8) */ -struct unpack_codes { - raft::device_matrix_view out_codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_codes the destination for the read codes. - */ - __device__ inline unpack_codes( - raft::device_matrix_view out_codes) - : out_codes{out_codes} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - out_codes(i, j) = code; +template +__device__ inline void unpack_codes_impl( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t in_ix, + uint8_t* out_codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(in_ix); + const uint32_t ingroup_ix = group_align::mod(in_ix); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // read the chunk + code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); + // extract the codes, one at a time + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + out_codes[j] = code_view[k]; + } } -}; +} template __launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( @@ -517,9 +520,17 @@ __launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( in_list_data, std::variant offset_or_indices) { + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= out_codes.extent(0)) return; + + const uint32_t src_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + const uint32_t pq_dim = out_codes.extent(1); - auto unpack_action = unpack_codes{out_codes}; - run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); + uint8_t* out_codes_ptr = &out_codes(ix, 0); + + unpack_codes_impl(in_list_data, src_ix, out_codes_ptr, pq_dim); } /** @@ -574,61 +585,52 @@ void unpack_list_data(raft::resources const& res, raft::resource::get_cuda_stream(res)); } -/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. +/** + * Reconstruct a vector from PQ codes. */ -struct reconstruct_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - uint32_t pq_len; - raft::device_mdspan, raft::row_major> pq_centers; - raft::device_mdspan, raft::row_major> centers_rot; - raft::device_mdspan, raft::row_major> out_vectors; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_vectors the destination for the decoded vectors. - * @param[in] pq_centers the codebook - * @param[in] centers_rot - * @param[in] codebook_kind - * @param[in] cluster_ix label/id of the cluster. - */ - __device__ inline reconstruct_vectors( - raft::device_matrix_view out_vectors, - raft::device_mdspan, raft::row_major> pq_centers, - raft::device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_len{pq_centers.extent(1)}, - pq_centers{pq_centers}, - centers_rot{reinterpret_vectors(centers_rot, pq_centers)}, - out_vectors{reinterpret_vectors(out_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). - */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - for (uint32_t k = 0; k < pq_len; k++) { - out_vectors(i, j, k) = pq_centers(partition_ix, k, code) + centers_rot(cluster_ix, j, k); +template +__device__ inline void reconstruct_vector_impl( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t in_ix, + float* out_vector, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_mdspan, raft::row_major> centers_rot, + codebook_gen codebook_kind, + uint32_t cluster_ix, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(in_ix); + const uint32_t ingroup_ix = group_align::mod(in_ix); + const uint32_t pq_len = pq_centers.extent(1); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // read the chunk + code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); + // reconstruct the codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + uint8_t code = code_view[k]; + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + for (uint32_t l = 0; l < pq_len; l++) { + out_vector[j * pq_len + l] = pq_centers(partition_ix, l, code) + centers_rot(cluster_ix, j, l); + } } } -}; +} template __launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel( @@ -641,11 +643,19 @@ __launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel uint32_t cluster_ix, std::variant offset_or_indices) { + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= out_vectors.extent(0)) return; + + const uint32_t src_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + const uint32_t pq_dim = out_vectors.extent(1) / pq_centers.extent(1); - auto reconstruct_action = - reconstruct_vectors{out_vectors, pq_centers, centers_rot, codebook_kind, cluster_ix}; - run_on_list( - in_list_data, offset_or_indices, out_vectors.extent(0), pq_dim, reconstruct_action); + float* out_vector_ptr = &out_vectors(ix, 0); + + auto centers_rot_3d = reinterpret_vectors(centers_rot, pq_centers); + reconstruct_vector_impl(in_list_data, src_ix, out_vector_ptr, pq_centers, + centers_rot_3d, codebook_kind, cluster_ix, pq_dim); } /** Decode the list data; see the public interface for the api and usage. */ @@ -733,27 +743,35 @@ void reconstruct_list_data(raft::resources const& res, } /** - * A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is, - * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses - * pq_dim bytes. + * Pack codes into the list data format. */ -struct pass_codes { - raft::device_matrix_view codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[in] codes the source codes. - */ - __device__ inline pass_codes( - raft::device_matrix_view codes) - : codes{codes} - { +template +__device__ inline void pack_codes_impl( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* in_codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // clear the chunk + code_chunk = pq_vec_t{}; + // pack the codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + code_view[k] = in_codes[j]; + } + // write the chunk to the list + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; } - - /** Read j-th component (code) of the i-th vector from the source. */ - __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } -}; +} template __launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( @@ -762,8 +780,15 @@ __launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( raft::device_matrix_view codes, std::variant offset_or_indices) { - write_list( - list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= codes.extent(0)) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint8_t* codes_ptr = &codes(ix, 0); + pack_codes_impl(list_data, dst_ix, codes_ptr, codes.extent(1)); } /** @@ -819,6 +844,85 @@ void pack_list_data(raft::resources const& res, raft::resource::get_cuda_stream(res)); } +/** + * Encode a vector on-the-fly by finding the closest PQ centers. + */ +template +__device__ inline void encode_vector_impl( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + uint32_t in_ix, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view new_vectors, + codebook_gen codebook_kind, + uint32_t cluster_ix, + uint32_t pq_dim) +{ + const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); + + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + const uint32_t pq_len = pq_centers.extent(1); + const uint32_t pq_book_size = pq_centers.extent(2); + + auto new_vectors_3d = reinterpret_vectors(new_vectors, pq_centers); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // clear the chunk + if (lane_id == 0) { code_chunk = pq_vec_t{}; } + // encode codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + float d = 0.0f; + for (uint32_t m = 0; m < pq_len; m++) { + auto t = new_vectors_3d(in_ix, j, m) - pq_centers(partition_ix, m, l); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads + #pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = raft::shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = raft::shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + + if (lane_id == 0) { code_view[k] = code; } + } + // write the chunk to the list + if (lane_id == 0) { + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; + } + } +} + template __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel( raft::device_mdspan::list_extents, raft::row_major> @@ -830,11 +934,18 @@ __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel( std::variant offset_or_indices) { constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto encode_action = - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}; - write_list( - list_data, offset_or_indices, new_vectors.extent(0), pq_dim, encode_action); + const uint32_t warp_ix = (threadIdx.x + blockDim.x * blockIdx.x) / kSubWarpSize; + + if (warp_ix >= new_vectors.extent(0)) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + warp_ix + : std::get(offset_or_indices)[warp_ix]; + + const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); + + encode_vector_impl(list_data, dst_ix, warp_ix, pq_centers, + new_vectors, codebook_kind, cluster_ix, pq_dim); } template diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 227cdb3558..32da500d85 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include "ivf_pq_build.cuh" #include "ivf_pq_codepacking.cuh" #include #include @@ -24,35 +25,72 @@ #include namespace cuvs::neighbors::ivf_pq::detail { + +using cuvs::neighbors::ivf_pq::kIndexGroupSize; +using cuvs::neighbors::ivf_pq::kIndexGroupVecLen; + +template +__device__ inline void unpack_vector_chunks( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t out_ix, + uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + constexpr uint32_t codes_per_chunk = (chunk_bytes * 8u) / PqBits; + uint32_t n_chunks = raft::ceildiv(pq_dim, codes_per_chunk); + + for (uint32_t i = 0; i < n_chunks; i++) { + pq_vec_t chunk; + if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { + chunk = *reinterpret_cast(in_list_data(group_ix, i, ingroup_ix, 0)); + } else { + chunk = pq_vec_t{}; + uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); + auto* chunk_bytes_ptr = reinterpret_cast(&chunk); + for (uint32_t j = 0; j < occupied_bytes; j++) { + chunk_bytes_ptr[j] = + reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0))[j]; + } + } + *reinterpret_cast(codes + i * chunk_bytes) = chunk; + } +} + /** - * A consumer for the `run_on_vector` that just flattens PQ codes - * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. + * Pack a vector by extracting each code (for unaligned cases: pq_bits = 5, 6, 7) */ template -struct unpack_contiguous { - uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `run_on_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } +__device__ inline void unpack_vector( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t out_ix, + uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); - /** Write j-th component (code) of the i-th vector into the output array. */ - __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - bitfield_view_t code_view{codes + i * code_size}; - code_view[j] = code; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + bitfield_view_t dst_view{codes}; + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + bitfield_view_t src_view{const_cast( + reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)))}; + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + dst_view[j] = src_view[k]; + } } -}; +} template -__launch_bounds__(BlockSize) static __global__ void unpack_contiguous_list_data_kernel( +__launch_bounds__(BlockSize) static __global__ void unpack_list_chunks_kernel( uint8_t* out_codes, raft::device_mdspan::list_extents, raft::row_major> in_list_data, @@ -60,8 +98,23 @@ __launch_bounds__(BlockSize) static __global__ void unpack_contiguous_list_data_ uint32_t pq_dim, std::variant offset_or_indices) { - run_on_list( - in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= n_rows) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); + uint8_t* dst_codes = out_codes + ix * code_size; + + if constexpr (PqBits == 4 || PqBits == 8) { + // aligned case: direct chunk copies + unpack_vector_chunks(in_list_data, dst_ix, dst_codes, pq_dim); + } else { + // unaligned case: extract each code + unpack_vector(in_list_data, dst_ix, dst_codes, pq_dim); + } } /** @@ -90,11 +143,11 @@ inline void unpack_contiguous_list_data_impl( dim3 threads(kBlockSize, 1, 1); auto kernel = [pq_bits]() { switch (pq_bits) { - case 4: return unpack_contiguous_list_data_kernel; - case 5: return unpack_contiguous_list_data_kernel; - case 6: return unpack_contiguous_list_data_kernel; - case 7: return unpack_contiguous_list_data_kernel; - case 8: return unpack_contiguous_list_data_kernel; + case 4: return unpack_list_chunks_kernel; + case 5: return unpack_list_chunks_kernel; + case 6: return unpack_list_chunks_kernel; + case 7: return unpack_list_chunks_kernel; + case 8: return unpack_list_chunks_kernel; default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(); @@ -103,34 +156,70 @@ inline void unpack_contiguous_list_data_impl( } /** - * A producer for the `write_vector` reads tightly packed flat codes. That is, - * the codes are not expanded to one code-per-byte. */ template -struct pack_contiguous { - const uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `write_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { +__device__ inline void pack_vector_chunks( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + constexpr uint32_t codes_per_chunk = (chunk_bytes * 8u) / PqBits; + uint32_t n_chunks = raft::ceildiv(pq_dim, codes_per_chunk); + + for (uint32_t i = 0; i < n_chunks; i++) { + pq_vec_t chunk; + if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { + chunk = *reinterpret_cast(codes + i * chunk_bytes); + } else { + chunk = pq_vec_t{}; + uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); + auto* chunk_bytes_ptr = reinterpret_cast(&chunk); + for (uint32_t j = 0; j < occupied_bytes; j++) { + chunk_bytes_ptr[j] = codes[i * chunk_bytes + j]; + } + } + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = chunk; } +} - /** Read j-th component (code) of the i-th vector from the source. */ - __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t - { - bitfield_view_t code_view{const_cast(codes + i * code_size)}; - return uint8_t(code_view[j]); +/** + * Pack a vector by extracting each code (for unaligned cases: pq_bits = 5, 6, 7) + */ +template +__device__ inline void pack_vector( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + pq_vec_t code_chunk = pq_vec_t{}; + bitfield_view_t src_view{const_cast(codes)}; + bitfield_view_t dst_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + dst_view[k] = src_view[j]; + } + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; + if (j < pq_dim) code_chunk = pq_vec_t{}; } -}; +} template -__launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_kernel( +__launch_bounds__(BlockSize) static __global__ void pack_list_chunks_kernel( raft::device_mdspan::list_extents, raft::row_major> list_data, const uint8_t* codes, @@ -138,8 +227,23 @@ __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_ke uint32_t pq_dim, std::variant offset_or_indices) { - write_list( - list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= n_rows) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); + const uint8_t* src_codes = codes + ix * code_size; + + if constexpr (PqBits == 4 || PqBits == 8) { + // aligned case: direct chunk copies + pack_vector_chunks(list_data, dst_ix, src_codes, pq_dim); + } else { + // unaligned case: extract each code + pack_vector(list_data, dst_ix, src_codes, pq_dim); + } } /** @@ -170,11 +274,11 @@ inline void pack_contiguous_list_data_impl( dim3 threads(kBlockSize, 1, 1); auto kernel = [pq_bits]() { switch (pq_bits) { - case 4: return pack_contiguous_list_data_kernel; - case 5: return pack_contiguous_list_data_kernel; - case 6: return pack_contiguous_list_data_kernel; - case 7: return pack_contiguous_list_data_kernel; - case 8: return pack_contiguous_list_data_kernel; + case 4: return pack_list_chunks_kernel; + case 5: return pack_list_chunks_kernel; + case 6: return pack_list_chunks_kernel; + case 7: return pack_list_chunks_kernel; + case 8: return pack_list_chunks_kernel; default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }();