From b55a423ad77b57a0f9711d755a7de7130f6aa37c Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 29 Sep 2025 19:18:35 -0700 Subject: [PATCH 01/12] initial --- .../neighbors/ivf_pq/ivf_pq_codepacking.cuh | 30 +++++++++++++++++++ .../ivf_pq_contiguous_list_data_impl.cuh | 11 +++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh index 40b9f74677..2e65e69a0a 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh @@ -176,6 +176,36 @@ __device__ void write_vector( } } + +template +__device__ void write_vector_chunks( + uint8_t* dest, + const uint8_t* source, + uint32_t size_bytes, + uint32_t dest_offset = 0, + uint32_t source_offset = 0) +{ + const uint32_t lane_id = threadIdx.x % SubWarpSize; + + // Each thread handles 128-bit (16-byte) chunks + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + + // Calculate starting byte position for this thread + uint32_t byte_pos = lane_id * chunk_bytes; + + for (; byte_pos < size_bytes; byte_pos += SubWarpSize * chunk_bytes) { + if (byte_pos + chunk_bytes <= size_bytes) { + *reinterpret_cast(dest + dest_offset + group_ix * chunk_bytes + ingroup_ix * chunk_bytes + byte_pos) = + *reinterpret_cast(source + source_offset + byte_pos); + } else { + uint32_t remaining = size_bytes - byte_pos; + for (uint32_t i = 0; i < remaining; i++) { + dest[dest_offset + byte_pos + i] = source[source_offset + byte_pos + i]; + } + } + } +} + /** Process the given indices or a block of a single list (cluster). */ template __device__ void run_on_list( diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 227cdb3558..5bb183e4ac 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -138,8 +138,15 @@ __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_ke uint32_t pq_dim, std::variant offset_or_indices) { - write_list( - list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); + using subwarp_align = raft::Pow2; + uint32_t stride = subwarp_align::div(blockDim.x); + uint32_t ix = subwarp_align::div(threadIdx.x + blockDim.x * blockIdx.x); + for (; ix < len; ix += stride) { + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + write_vector(out_list_data, dst_ix, ix, pq_dim, action); + } } /** From 20f61923754a8536f6c2fee6fbcef24cfea1bdfd Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 16:05:21 -0700 Subject: [PATCH 02/12] chunked copy kernel --- .../neighbors/ivf_pq/ivf_pq_codepacking.cuh | 51 +++++++++++-------- .../ivf_pq_contiguous_list_data_impl.cuh | 19 +++++-- 2 files changed, 46 insertions(+), 24 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh index 2e65e69a0a..c13af7b7ca 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh @@ -177,32 +177,41 @@ __device__ void write_vector( } -template +template __device__ void write_vector_chunks( - uint8_t* dest, - const uint8_t* source, - uint32_t size_bytes, - uint32_t dest_offset = 0, - uint32_t source_offset = 0) + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* codes, + uint32_t pq_dim) { - const uint32_t lane_id = threadIdx.x % SubWarpSize; - - // Each thread handles 128-bit (16-byte) chunks + const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); - - // Calculate starting byte position for this thread - uint32_t byte_pos = lane_id * chunk_bytes; - - for (; byte_pos < size_bytes; byte_pos += SubWarpSize * chunk_bytes) { - if (byte_pos + chunk_bytes <= size_bytes) { - *reinterpret_cast(dest + dest_offset + group_ix * chunk_bytes + ingroup_ix * chunk_bytes + byte_pos) = - *reinterpret_cast(source + source_offset + byte_pos); - } else { - uint32_t remaining = size_bytes - byte_pos; - for (uint32_t i = 0; i < remaining; i++) { - dest[dest_offset + byte_pos + i] = source[source_offset + byte_pos + i]; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + constexpr uint32_t chunk_bits = chunk_bytes * 8u; + uint32_t n_chunks = raft::ceildiv(pq_dim * PqBits, chunk_bits); + const bool compute_last_chunk = raft::Pow2::mod(pq_dim * PqBits) == 0; + uint32_t n_copies = compute_last_chunk ? n_chunks - 1 : n_chunks; + for (uint32_t i = 0; i < n_copies; i++) { + uint32_t chunk_ix = i * kChunkSize; + *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = + *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); + } + if (compute_last_chunk) { + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + uint32_t chunk_ix = n_copies * kChunkSize; + for (uint32_t j = chunk_ix, i = 0; j < pq_dim; i++) { + code_chunk = *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + code_view[k] = codes[chunk_ix * chunk_bytes + k]; } } + *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = + *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); } } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 5bb183e4ac..74e0f4b1bb 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -137,6 +137,19 @@ __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_ke uint32_t n_rows, uint32_t pq_dim, std::variant offset_or_indices) +{ + write_list( + list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); +} + +template +__launch_bounds__(BlockSize) static __global__ void copy_list_chunks_kernel( + raft::device_mdspan::list_extents, raft::row_major> + list_data, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices) { using subwarp_align = raft::Pow2; uint32_t stride = subwarp_align::div(blockDim.x); @@ -145,7 +158,7 @@ __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_ke const uint32_t dst_ix = std::holds_alternative(offset_or_indices) ? std::get(offset_or_indices) + ix : std::get(offset_or_indices)[ix]; - write_vector(out_list_data, dst_ix, ix, pq_dim, action); + copy_list_chunks(out_list_data, dst_ix, ix, pq_dim); } } @@ -177,11 +190,11 @@ inline void pack_contiguous_list_data_impl( dim3 threads(kBlockSize, 1, 1); auto kernel = [pq_bits]() { switch (pq_bits) { - case 4: return pack_contiguous_list_data_kernel; + case 4: return copy_list_chunks_kernel; case 5: return pack_contiguous_list_data_kernel; case 6: return pack_contiguous_list_data_kernel; case 7: return pack_contiguous_list_data_kernel; - case 8: return pack_contiguous_list_data_kernel; + case 8: return copy_list_chunks_kernel; default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(); From 9ab79a3ed5b212d5e5f9f8ffad1d73c36f555d65 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 16:59:38 -0700 Subject: [PATCH 03/12] rewrite contiguous codepacker --- .../neighbors/ivf_pq/ivf_pq_codepacking.cuh | 23 ++-- .../ivf_pq_contiguous_list_data_impl.cuh | 106 +++++++++++++++--- 2 files changed, 102 insertions(+), 27 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh index c13af7b7ca..b211c8f392 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -176,7 +176,6 @@ __device__ void write_vector( } } - template __device__ void write_vector_chunks( raft::device_mdspan::list_extents, raft::row_major> @@ -185,16 +184,16 @@ __device__ void write_vector_chunks( const uint8_t* codes, uint32_t pq_dim) { - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - using group_align = raft::Pow2; - const uint32_t group_ix = group_align::div(out_ix); - const uint32_t ingroup_ix = group_align::mod(out_ix); + const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - constexpr uint32_t chunk_bits = chunk_bytes * 8u; - uint32_t n_chunks = raft::ceildiv(pq_dim * PqBits, chunk_bits); - const bool compute_last_chunk = raft::Pow2::mod(pq_dim * PqBits) == 0; - uint32_t n_copies = compute_last_chunk ? n_chunks - 1 : n_chunks; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + constexpr uint32_t chunk_bits = chunk_bytes * 8u; + uint32_t n_chunks = raft::ceildiv(pq_dim * PqBits, chunk_bits); + const bool compute_last_chunk = raft::Pow2::mod(pq_dim * PqBits) == 0; + uint32_t n_copies = compute_last_chunk ? n_chunks - 1 : n_chunks; for (uint32_t i = 0; i < n_copies; i++) { uint32_t chunk_ix = i * kChunkSize; *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = @@ -210,7 +209,7 @@ __device__ void write_vector_chunks( code_view[k] = codes[chunk_ix * chunk_bytes + k]; } } - *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = + *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); } } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 74e0f4b1bb..153618e66b 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -24,6 +24,10 @@ #include namespace cuvs::neighbors::ivf_pq::detail { + +using cuvs::neighbors::ivf_pq::kIndexGroupSize; +using cuvs::neighbors::ivf_pq::kIndexGroupVecLen; + /** * A consumer for the `run_on_vector` that just flattens PQ codes * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. @@ -129,6 +133,70 @@ struct pack_contiguous { } }; +/** + * Pack a vector by copying chunks directly. + */ +template +__device__ inline void pack_vector_chunks( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + constexpr uint32_t codes_per_chunk = (chunk_bytes * 8u) / PqBits; + uint32_t n_chunks = raft::ceildiv(pq_dim, codes_per_chunk); + + for (uint32_t i = 0; i < n_chunks; i++) { + pq_vec_t chunk; + if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { + chunk = *reinterpret_cast(codes + i * chunk_bytes); + } else { + chunk = pq_vec_t{}; + uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); + auto* chunk_bytes_ptr = reinterpret_cast(&chunk); + for (uint32_t j = 0; j < occupied_bytes; j++) { + chunk_bytes_ptr[j] = codes[i * chunk_bytes + j]; + } + } + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = chunk; + } +} + +/** + * Pack a vector by extracting each code (for unaligned cases: pq_bits = 5, 6, 7) + */ +template +__device__ inline void pack_vector( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + pq_vec_t code_chunk = pq_vec_t{}; + bitfield_view_t src_view{const_cast(codes)}; + bitfield_view_t dst_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + dst_view[k] = src_view[j]; + } + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; + if (j < pq_dim) code_chunk = pq_vec_t{}; + } +} + template __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_kernel( raft::device_mdspan::list_extents, raft::row_major> @@ -143,7 +211,7 @@ __launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_ke } template -__launch_bounds__(BlockSize) static __global__ void copy_list_chunks_kernel( +__launch_bounds__(BlockSize) static __global__ void pack_list_chunks_kernel( raft::device_mdspan::list_extents, raft::row_major> list_data, const uint8_t* codes, @@ -151,15 +219,23 @@ __launch_bounds__(BlockSize) static __global__ void copy_list_chunks_kernel( uint32_t pq_dim, std::variant offset_or_indices) { - using subwarp_align = raft::Pow2; - uint32_t stride = subwarp_align::div(blockDim.x); - uint32_t ix = subwarp_align::div(threadIdx.x + blockDim.x * blockIdx.x); - for (; ix < len; ix += stride) { - const uint32_t dst_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + ix - : std::get(offset_or_indices)[ix]; - copy_list_chunks(out_list_data, dst_ix, ix, pq_dim); - } + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= n_rows) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); + const uint8_t* src_codes = codes + ix * code_size; + + if constexpr (PqBits == 4 || PqBits == 8) { + // aligned case: direct chunk copies + pack_vector_chunks(list_data, dst_ix, src_codes, pq_dim); + } else { + // unaligned case: extract each code + pack_vector(list_data, dst_ix, src_codes, pq_dim); + } } /** @@ -190,11 +266,11 @@ inline void pack_contiguous_list_data_impl( dim3 threads(kBlockSize, 1, 1); auto kernel = [pq_bits]() { switch (pq_bits) { - case 4: return copy_list_chunks_kernel; - case 5: return pack_contiguous_list_data_kernel; - case 6: return pack_contiguous_list_data_kernel; - case 7: return pack_contiguous_list_data_kernel; - case 8: return copy_list_chunks_kernel; + case 4: return pack_list_chunks_kernel; + case 5: return pack_list_chunks_kernel; + case 6: return pack_list_chunks_kernel; + case 7: return pack_list_chunks_kernel; + case 8: return pack_list_chunks_kernel; default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(); From 0e4cc9c7b930cc9d01f2d48144f8903d781915b7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 17:15:59 -0700 Subject: [PATCH 04/12] rm pack_contiguous struct --- .../ivf_pq_contiguous_list_data_impl.cuh | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 153618e66b..5dbb4dedb8 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -107,34 +107,6 @@ inline void unpack_contiguous_list_data_impl( } /** - * A producer for the `write_vector` reads tightly packed flat codes. That is, - * the codes are not expanded to one code-per-byte. - */ -template -struct pack_contiguous { - const uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `write_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Read j-th component (code) of the i-th vector from the source. */ - __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t - { - bitfield_view_t code_view{const_cast(codes + i * code_size)}; - return uint8_t(code_view[j]); - } -}; - -/** - * Pack a vector by copying chunks directly. */ template __device__ inline void pack_vector_chunks( From e53469269aea37e91431653c29966c8a9471de57 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 17:18:14 -0700 Subject: [PATCH 05/12] rm unpack_contiguous struct --- .../ivf_pq_contiguous_list_data_impl.cuh | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 5dbb4dedb8..2907b8b2af 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -28,33 +28,6 @@ namespace cuvs::neighbors::ivf_pq::detail { using cuvs::neighbors::ivf_pq::kIndexGroupSize; using cuvs::neighbors::ivf_pq::kIndexGroupVecLen; -/** - * A consumer for the `run_on_vector` that just flattens PQ codes - * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. - */ -template -struct unpack_contiguous { - uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `run_on_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - bitfield_view_t code_view{codes + i * code_size}; - code_view[j] = code; - } -}; - template __launch_bounds__(BlockSize) static __global__ void unpack_contiguous_list_data_kernel( uint8_t* out_codes, From c6366a1adb876ba0fc6749c35e36b07768c050d5 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 17:24:45 -0700 Subject: [PATCH 06/12] unpack impl initial --- .../ivf_pq_contiguous_list_data_impl.cuh | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 2907b8b2af..c410999021 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -29,7 +29,7 @@ using cuvs::neighbors::ivf_pq::kIndexGroupSize; using cuvs::neighbors::ivf_pq::kIndexGroupVecLen; template -__launch_bounds__(BlockSize) static __global__ void unpack_contiguous_list_data_kernel( +__launch_bounds__(BlockSize) static __global__ void unpack_list_chunks_kernel( uint8_t* out_codes, raft::device_mdspan::list_extents, raft::row_major> in_list_data, @@ -37,8 +37,23 @@ __launch_bounds__(BlockSize) static __global__ void unpack_contiguous_list_data_ uint32_t pq_dim, std::variant offset_or_indices) { - run_on_list( - in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= n_rows) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); + const uint8_t* src_codes = out_codes + ix * code_size; + + if constexpr (PqBits == 4 || PqBits == 8) { + // aligned case: direct chunk copies + unpack_vector_chunks(in_list_data, dst_ix, src_codes, pq_dim); + } else { + // unaligned case: extract each code + unpack_vector(in_list_data, dst_ix, src_codes, pq_dim); + } } /** @@ -142,19 +157,6 @@ __device__ inline void pack_vector( } } -template -__launch_bounds__(BlockSize) static __global__ void pack_contiguous_list_data_kernel( - raft::device_mdspan::list_extents, raft::row_major> - list_data, - const uint8_t* codes, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); -} - template __launch_bounds__(BlockSize) static __global__ void pack_list_chunks_kernel( raft::device_mdspan::list_extents, raft::row_major> From d54a1ebb9cf67983794dfa7c7aff7ddd8775bc1f Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 17:26:05 -0700 Subject: [PATCH 07/12] rm write_vector_chunks --- .../neighbors/ivf_pq/ivf_pq_codepacking.cuh | 38 ------------------- 1 file changed, 38 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh index b211c8f392..0c81fe9e3a 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh @@ -176,44 +176,6 @@ __device__ void write_vector( } } -template -__device__ void write_vector_chunks( - raft::device_mdspan::list_extents, raft::row_major> - out_list_data, - uint32_t out_ix, - const uint8_t* codes, - uint32_t pq_dim) -{ - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - using group_align = raft::Pow2; - const uint32_t group_ix = group_align::div(out_ix); - const uint32_t ingroup_ix = group_align::mod(out_ix); - constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - constexpr uint32_t chunk_bits = chunk_bytes * 8u; - uint32_t n_chunks = raft::ceildiv(pq_dim * PqBits, chunk_bits); - const bool compute_last_chunk = raft::Pow2::mod(pq_dim * PqBits) == 0; - uint32_t n_copies = compute_last_chunk ? n_chunks - 1 : n_chunks; - for (uint32_t i = 0; i < n_copies; i++) { - uint32_t chunk_ix = i * kChunkSize; - *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = - *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); - } - if (compute_last_chunk) { - pq_vec_t code_chunk; - bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; - uint32_t chunk_ix = n_copies * kChunkSize; - for (uint32_t j = chunk_ix, i = 0; j < pq_dim; i++) { - code_chunk = *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - code_view[k] = codes[chunk_ix * chunk_bytes + k]; - } - } - *reinterpret_cast(&out_list_data(group_ix, chunk_ix, ingroup_ix, 0)) = - *reinterpret_cast(&codes[chunk_ix * chunk_bytes]); - } -} - /** Process the given indices or a block of a single list (cluster). */ template __device__ void run_on_list( From 8956df51546ba1889e487c3fd5e3d86be017ba85 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 30 Sep 2025 17:26:47 -0700 Subject: [PATCH 08/12] update codepacking file --- cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh index 0c81fe9e3a..40b9f74677 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_codepacking.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2025, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From f7c3a3915f4ba5d0089b8961162148a9603b8571 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 1 Oct 2025 12:19:09 -0700 Subject: [PATCH 09/12] unpack function: first impl --- .../ivf_pq_contiguous_list_data_impl.cuh | 81 ++++++++++++++++--- 1 file changed, 72 insertions(+), 9 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index c410999021..50420c39ad 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include "ivf_pq_build.cuh" #include "ivf_pq_codepacking.cuh" #include #include @@ -28,6 +29,68 @@ namespace cuvs::neighbors::ivf_pq::detail { using cuvs::neighbors::ivf_pq::kIndexGroupSize; using cuvs::neighbors::ivf_pq::kIndexGroupVecLen; +template +__device__ inline void unpack_vector_chunks( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t out_ix, + uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + constexpr uint32_t codes_per_chunk = (chunk_bytes * 8u) / PqBits; + uint32_t n_chunks = raft::ceildiv(pq_dim, codes_per_chunk); + + for (uint32_t i = 0; i < n_chunks; i++) { + pq_vec_t chunk; + if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { + chunk = *reinterpret_cast(in_list_data(group_ix, i, ingroup_ix, 0)); + } else { + chunk = pq_vec_t{}; + uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); + auto* chunk_bytes_ptr = reinterpret_cast(&chunk); + for (uint32_t j = 0; j < occupied_bytes; j++) { + chunk_bytes_ptr[j] =reinterpret_cast(in_list_data(group_ix, i, ingroup_ix, 0))[j]; + } + } + *reinterpret_cast(&codes[i * chunk_bytes + out_ix]) = chunk; + } +} + +/** + * Pack a vector by extracting each code (for unaligned cases: pq_bits = 5, 6, 7) + */ +template +__device__ inline void unpack_vector( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t out_ix, + uint8_t* codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + pq_vec_t code_chunk = pq_vec_t{}; + bitfield_view_t dst_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + bitfield_view_t src_view{reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0))}; + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + dst_view[k] = src_view[j]; + } + *reinterpret_cast(&codes[i * chunk_bytes + out_ix]) = code_chunk; + if (j < pq_dim) code_chunk = pq_vec_t{}; + } +} + template __launch_bounds__(BlockSize) static __global__ void unpack_list_chunks_kernel( uint8_t* out_codes, @@ -45,14 +108,14 @@ __launch_bounds__(BlockSize) static __global__ void unpack_list_chunks_kernel( : std::get(offset_or_indices)[ix]; const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); - const uint8_t* src_codes = out_codes + ix * code_size; + const uint8_t* dst_codes = out_codes + ix * code_size; if constexpr (PqBits == 4 || PqBits == 8) { // aligned case: direct chunk copies - unpack_vector_chunks(in_list_data, dst_ix, src_codes, pq_dim); + unpack_vector_chunks(in_list_data, dst_ix, dst_codes, pq_dim); } else { // unaligned case: extract each code - unpack_vector(in_list_data, dst_ix, src_codes, pq_dim); + unpack_vector(in_list_data, dst_ix, dst_codes, pq_dim); } } @@ -82,11 +145,11 @@ inline void unpack_contiguous_list_data_impl( dim3 threads(kBlockSize, 1, 1); auto kernel = [pq_bits]() { switch (pq_bits) { - case 4: return unpack_contiguous_list_data_kernel; - case 5: return unpack_contiguous_list_data_kernel; - case 6: return unpack_contiguous_list_data_kernel; - case 7: return unpack_contiguous_list_data_kernel; - case 8: return unpack_contiguous_list_data_kernel; + case 4: return unpack_list_chunks_kernel; + case 5: return unpack_list_chunks_kernel; + case 6: return unpack_list_chunks_kernel; + case 7: return unpack_list_chunks_kernel; + case 8: return unpack_list_chunks_kernel; default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(); @@ -115,7 +178,7 @@ __device__ inline void pack_vector_chunks( for (uint32_t i = 0; i < n_chunks; i++) { pq_vec_t chunk; if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { - chunk = *reinterpret_cast(codes + i * chunk_bytes); + chunk = *reinterpret_cast(codes + i * pq_dim); } else { chunk = pq_vec_t{}; uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); From cb187c66ea3b201bd4d3c401cea21e7d7a7ce38a Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 1 Oct 2025 12:32:14 -0700 Subject: [PATCH 10/12] fix unpack compilation --- .../ivf_pq/ivf_pq_contiguous_list_data_impl.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index 50420c39ad..ccd0723689 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -54,10 +54,11 @@ __device__ inline void unpack_vector_chunks( uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); auto* chunk_bytes_ptr = reinterpret_cast(&chunk); for (uint32_t j = 0; j < occupied_bytes; j++) { - chunk_bytes_ptr[j] =reinterpret_cast(in_list_data(group_ix, i, ingroup_ix, 0))[j]; + chunk_bytes_ptr[j] = + reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0))[j]; } } - *reinterpret_cast(&codes[i * chunk_bytes + out_ix]) = chunk; + *reinterpret_cast(codes + i * chunk_bytes) = chunk; } } @@ -78,15 +79,16 @@ __device__ inline void unpack_vector( pq_vec_t code_chunk = pq_vec_t{}; bitfield_view_t dst_view{reinterpret_cast(&code_chunk)}; - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); for (uint32_t j = 0, i = 0; j < pq_dim; i++) { - bitfield_view_t src_view{reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0))}; + bitfield_view_t src_view{const_cast( + reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)))}; for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { dst_view[k] = src_view[j]; } - *reinterpret_cast(&codes[i * chunk_bytes + out_ix]) = code_chunk; + *reinterpret_cast(codes + i * chunk_bytes) = code_chunk; if (j < pq_dim) code_chunk = pq_vec_t{}; } } @@ -108,7 +110,7 @@ __launch_bounds__(BlockSize) static __global__ void unpack_list_chunks_kernel( : std::get(offset_or_indices)[ix]; const uint32_t code_size = raft::ceildiv(pq_dim * PqBits, 8); - const uint8_t* dst_codes = out_codes + ix * code_size; + uint8_t* dst_codes = out_codes + ix * code_size; if constexpr (PqBits == 4 || PqBits == 8) { // aligned case: direct chunk copies @@ -178,7 +180,7 @@ __device__ inline void pack_vector_chunks( for (uint32_t i = 0; i < n_chunks; i++) { pq_vec_t chunk; if (i < n_chunks - 1 || (pq_dim % codes_per_chunk) == 0) { - chunk = *reinterpret_cast(codes + i * pq_dim); + chunk = *reinterpret_cast(codes + i * chunk_bytes); } else { chunk = pq_vec_t{}; uint32_t occupied_bytes = raft::ceildiv((pq_dim % codes_per_chunk) * PqBits, 8); From 48ec9b5e13273a1836e1428a66dd71718ac4d5a7 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Wed, 1 Oct 2025 13:18:03 -0700 Subject: [PATCH 11/12] first updates to ivf_pq_build.cuh --- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 323 +++++++++++++++------- 1 file changed, 217 insertions(+), 106 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index a2556dc651..e09af1d7b3 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -485,30 +485,33 @@ void train_per_cluster(raft::resources const& handle, } /** - * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes - * one-per-byte. That is, independent of the code width (pq_bits), one code uses - * the whole byte, hence one vectors uses pq_dim bytes. + * Unpack codes from the packed list data for aligned cases (pq_bits = 4, 8) */ -struct unpack_codes { - raft::device_matrix_view out_codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_codes the destination for the read codes. - */ - __device__ inline unpack_codes( - raft::device_matrix_view out_codes) - : out_codes{out_codes} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - out_codes(i, j) = code; +template +__device__ inline void unpack_codes_impl( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t in_ix, + uint8_t* out_codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(in_ix); + const uint32_t ingroup_ix = group_align::mod(in_ix); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // read the chunk + code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); + // extract the codes, one at a time + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + out_codes[j] = code_view[k]; + } } -}; +} template __launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( @@ -517,9 +520,17 @@ __launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel( in_list_data, std::variant offset_or_indices) { + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= out_codes.extent(0)) return; + + const uint32_t src_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + const uint32_t pq_dim = out_codes.extent(1); - auto unpack_action = unpack_codes{out_codes}; - run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); + uint8_t* out_codes_ptr = &out_codes(ix, 0); + + unpack_codes_impl(in_list_data, src_ix, out_codes_ptr, pq_dim); } /** @@ -574,61 +585,52 @@ void unpack_list_data(raft::resources const& res, raft::resource::get_cuda_stream(res)); } -/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. +/** + * Reconstruct a vector from PQ codes. */ -struct reconstruct_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - uint32_t pq_len; - raft::device_mdspan, raft::row_major> pq_centers; - raft::device_mdspan, raft::row_major> centers_rot; - raft::device_mdspan, raft::row_major> out_vectors; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_vectors the destination for the decoded vectors. - * @param[in] pq_centers the codebook - * @param[in] centers_rot - * @param[in] codebook_kind - * @param[in] cluster_ix label/id of the cluster. - */ - __device__ inline reconstruct_vectors( - raft::device_matrix_view out_vectors, - raft::device_mdspan, raft::row_major> pq_centers, - raft::device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_len{pq_centers.extent(1)}, - pq_centers{pq_centers}, - centers_rot{reinterpret_vectors(centers_rot, pq_centers)}, - out_vectors{reinterpret_vectors(out_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). - */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - for (uint32_t k = 0; k < pq_len; k++) { - out_vectors(i, j, k) = pq_centers(partition_ix, k, code) + centers_rot(cluster_ix, j, k); +template +__device__ inline void reconstruct_vector_impl( + raft::device_mdspan::list_extents, raft::row_major> + in_list_data, + uint32_t in_ix, + float* out_vector, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_mdspan, raft::row_major> centers_rot, + codebook_gen codebook_kind, + uint32_t cluster_ix, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(in_ix); + const uint32_t ingroup_ix = group_align::mod(in_ix); + const uint32_t pq_len = pq_centers.extent(1); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // read the chunk + code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); + // reconstruct the codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + uint8_t code = code_view[k]; + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + for (uint32_t l = 0; l < pq_len; l++) { + out_vector[j * pq_len + l] = pq_centers(partition_ix, l, code) + centers_rot(cluster_ix, j, l); + } } } -}; +} template __launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel( @@ -641,11 +643,19 @@ __launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel uint32_t cluster_ix, std::variant offset_or_indices) { + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= out_vectors.extent(0)) return; + + const uint32_t src_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + const uint32_t pq_dim = out_vectors.extent(1) / pq_centers.extent(1); - auto reconstruct_action = - reconstruct_vectors{out_vectors, pq_centers, centers_rot, codebook_kind, cluster_ix}; - run_on_list( - in_list_data, offset_or_indices, out_vectors.extent(0), pq_dim, reconstruct_action); + float* out_vector_ptr = &out_vectors(ix, 0); + + auto centers_rot_3d = reinterpret_vectors(centers_rot, pq_centers); + reconstruct_vector_impl(in_list_data, src_ix, out_vector_ptr, pq_centers, + centers_rot_3d, codebook_kind, cluster_ix, pq_dim); } /** Decode the list data; see the public interface for the api and usage. */ @@ -733,27 +743,35 @@ void reconstruct_list_data(raft::resources const& res, } /** - * A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is, - * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses - * pq_dim bytes. + * Pack codes into the list data format. */ -struct pass_codes { - raft::device_matrix_view codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[in] codes the source codes. - */ - __device__ inline pass_codes( - raft::device_matrix_view codes) - : codes{codes} - { +template +__device__ inline void pack_codes_impl( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + const uint8_t* in_codes, + uint32_t pq_dim) +{ + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // clear the chunk + code_chunk = pq_vec_t{}; + // pack the codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + code_view[k] = in_codes[j]; + } + // write the chunk to the list + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; } - - /** Read j-th component (code) of the i-th vector from the source. */ - __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } -}; +} template __launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( @@ -762,8 +780,15 @@ __launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel( raft::device_matrix_view codes, std::variant offset_or_indices) { - write_list( - list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); + uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; + if (ix >= codes.extent(0)) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + const uint8_t* codes_ptr = &codes(ix, 0); + pack_codes_impl(list_data, dst_ix, codes_ptr, codes.extent(1)); } /** @@ -819,6 +844,85 @@ void pack_list_data(raft::resources const& res, raft::resource::get_cuda_stream(res)); } +/** + * Encode a vector on-the-fly by finding the closest PQ centers. + */ +template +__device__ inline void encode_vector_impl( + raft::device_mdspan::list_extents, raft::row_major> + out_list_data, + uint32_t out_ix, + uint32_t in_ix, + raft::device_mdspan, raft::row_major> pq_centers, + raft::device_matrix_view new_vectors, + codebook_gen codebook_kind, + uint32_t cluster_ix, + uint32_t pq_dim) +{ + const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); + + using group_align = raft::Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + const uint32_t pq_len = pq_centers.extent(1); + const uint32_t pq_book_size = pq_centers.extent(2); + + auto new_vectors_3d = reinterpret_vectors(new_vectors, pq_centers); + + pq_vec_t code_chunk; + bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // clear the chunk + if (lane_id == 0) { code_chunk = pq_vec_t{}; } + // encode codes + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + float d = 0.0f; + for (uint32_t m = 0; m < pq_len; m++) { + auto t = new_vectors_3d(in_ix, j, m) - pq_centers(partition_ix, m, l); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads + #pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = raft::shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = raft::shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + + if (lane_id == 0) { code_view[k] = code; } + } + // write the chunk to the list + if (lane_id == 0) { + *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; + } + } +} + template __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel( raft::device_mdspan::list_extents, raft::row_major> @@ -830,11 +934,18 @@ __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel( std::variant offset_or_indices) { constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto encode_action = - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}; - write_list( - list_data, offset_or_indices, new_vectors.extent(0), pq_dim, encode_action); + const uint32_t warp_ix = (threadIdx.x + blockDim.x * blockIdx.x) / kSubWarpSize; + + if (warp_ix >= new_vectors.extent(0)) return; + + const uint32_t dst_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + warp_ix + : std::get(offset_or_indices)[warp_ix]; + + const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); + + encode_vector_impl(list_data, dst_ix, warp_ix, pq_centers, + new_vectors, codebook_kind, cluster_ix, pq_dim); } template From c7ba1a360f5c4225262c18e4d023488fccc5862d Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Thu, 2 Oct 2025 09:24:00 -0700 Subject: [PATCH 12/12] fix cuda memory error --- .../neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh index ccd0723689..32da500d85 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_contiguous_list_data_impl.cuh @@ -77,19 +77,15 @@ __device__ inline void unpack_vector( const uint32_t group_ix = group_align::div(out_ix); const uint32_t ingroup_ix = group_align::mod(out_ix); - pq_vec_t code_chunk = pq_vec_t{}; - bitfield_view_t dst_view{reinterpret_cast(&code_chunk)}; constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - constexpr uint32_t chunk_bytes = sizeof(pq_vec_t); + bitfield_view_t dst_view{codes}; for (uint32_t j = 0, i = 0; j < pq_dim; i++) { bitfield_view_t src_view{const_cast( reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)))}; for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - dst_view[k] = src_view[j]; + dst_view[j] = src_view[k]; } - *reinterpret_cast(codes + i * chunk_bytes) = code_chunk; - if (j < pq_dim) code_chunk = pq_vec_t{}; } }