Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 217 additions & 106 deletions cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -485,30 +485,33 @@ void train_per_cluster(raft::resources const& handle,
}

/**
* A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes
* one-per-byte. That is, independent of the code width (pq_bits), one code uses
* the whole byte, hence one vectors uses pq_dim bytes.
* Unpack codes from the packed list data for aligned cases (pq_bits = 4, 8)
*/
struct unpack_codes {
raft::device_matrix_view<uint8_t, uint32_t, raft::row_major> out_codes;

/**
* Create a callable to be passed to `run_on_list`.
*
* @param[out] out_codes the destination for the read codes.
*/
__device__ inline unpack_codes(
raft::device_matrix_view<uint8_t, uint32_t, raft::row_major> out_codes)
: out_codes{out_codes}
{
}

/** Write j-th component (code) of the i-th vector into the output array. */
__device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
{
out_codes(i, j) = code;
template <uint32_t PqBits>
__device__ inline void unpack_codes_impl(
raft::device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, raft::row_major>
in_list_data,
uint32_t in_ix,
uint8_t* out_codes,
uint32_t pq_dim)
{
using group_align = raft::Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(in_ix);
const uint32_t ingroup_ix = group_align::mod(in_ix);

pq_vec_t code_chunk;
bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;

for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// read the chunk
code_chunk = *reinterpret_cast<const pq_vec_t*>(&in_list_data(group_ix, i, ingroup_ix, 0));
// extract the codes, one at a time
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
out_codes[j] = code_view[k];
}
}
};
}

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel(
Expand All @@ -517,9 +520,17 @@ __launch_bounds__(BlockSize) static __global__ void unpack_list_data_kernel(
in_list_data,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x;
if (ix >= out_codes.extent(0)) return;

const uint32_t src_ix = std::holds_alternative<uint32_t>(offset_or_indices)
? std::get<uint32_t>(offset_or_indices) + ix
: std::get<const uint32_t*>(offset_or_indices)[ix];

const uint32_t pq_dim = out_codes.extent(1);
auto unpack_action = unpack_codes{out_codes};
run_on_list<PqBits>(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action);
uint8_t* out_codes_ptr = &out_codes(ix, 0);

unpack_codes_impl<PqBits>(in_list_data, src_ix, out_codes_ptr, pq_dim);
}

/**
Expand Down Expand Up @@ -574,61 +585,52 @@ void unpack_list_data(raft::resources const& res,
raft::resource::get_cuda_stream(res));
}

/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data.
/**
* Reconstruct a vector from PQ codes.
*/
struct reconstruct_vectors {
codebook_gen codebook_kind;
uint32_t cluster_ix;
uint32_t pq_len;
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers;
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> centers_rot;
raft::device_mdspan<float, raft::extent_3d<uint32_t>, raft::row_major> out_vectors;

/**
* Create a callable to be passed to `run_on_list`.
*
* @param[out] out_vectors the destination for the decoded vectors.
* @param[in] pq_centers the codebook
* @param[in] centers_rot
* @param[in] codebook_kind
* @param[in] cluster_ix label/id of the cluster.
*/
__device__ inline reconstruct_vectors(
raft::device_matrix_view<float, uint32_t, raft::row_major> out_vectors,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers_rot,
codebook_gen codebook_kind,
uint32_t cluster_ix)
: codebook_kind{codebook_kind},
cluster_ix{cluster_ix},
pq_len{pq_centers.extent(1)},
pq_centers{pq_centers},
centers_rot{reinterpret_vectors(centers_rot, pq_centers)},
out_vectors{reinterpret_vectors(out_vectors, pq_centers)}
{
}

/**
* Decode j-th component of the i-th vector by its code and write it into a chunk of the output
* vectors (pq_len elements).
*/
__device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
{
uint32_t partition_ix;
switch (codebook_kind) {
case codebook_gen::PER_CLUSTER: {
partition_ix = cluster_ix;
} break;
case codebook_gen::PER_SUBSPACE: {
partition_ix = j;
} break;
default: __builtin_unreachable();
}
for (uint32_t k = 0; k < pq_len; k++) {
out_vectors(i, j, k) = pq_centers(partition_ix, k, code) + centers_rot(cluster_ix, j, k);
template <uint32_t PqBits>
__device__ inline void reconstruct_vector_impl(
raft::device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, raft::row_major>
in_list_data,
uint32_t in_ix,
float* out_vector,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> centers_rot,
codebook_gen codebook_kind,
uint32_t cluster_ix,
uint32_t pq_dim)
{
using group_align = raft::Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(in_ix);
const uint32_t ingroup_ix = group_align::mod(in_ix);
const uint32_t pq_len = pq_centers.extent(1);

pq_vec_t code_chunk;
bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;

for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// read the chunk
code_chunk = *reinterpret_cast<const pq_vec_t*>(&in_list_data(group_ix, i, ingroup_ix, 0));
// reconstruct the codes
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
uint8_t code = code_view[k];
uint32_t partition_ix;
switch (codebook_kind) {
case codebook_gen::PER_CLUSTER: {
partition_ix = cluster_ix;
} break;
case codebook_gen::PER_SUBSPACE: {
partition_ix = j;
} break;
default: __builtin_unreachable();
}
for (uint32_t l = 0; l < pq_len; l++) {
out_vector[j * pq_len + l] = pq_centers(partition_ix, l, code) + centers_rot(cluster_ix, j, l);
}
}
}
};
}

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel(
Expand All @@ -641,11 +643,19 @@ __launch_bounds__(BlockSize) static __global__ void reconstruct_list_data_kernel
uint32_t cluster_ix,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x;
if (ix >= out_vectors.extent(0)) return;

const uint32_t src_ix = std::holds_alternative<uint32_t>(offset_or_indices)
? std::get<uint32_t>(offset_or_indices) + ix
: std::get<const uint32_t*>(offset_or_indices)[ix];

const uint32_t pq_dim = out_vectors.extent(1) / pq_centers.extent(1);
auto reconstruct_action =
reconstruct_vectors{out_vectors, pq_centers, centers_rot, codebook_kind, cluster_ix};
run_on_list<PqBits>(
in_list_data, offset_or_indices, out_vectors.extent(0), pq_dim, reconstruct_action);
float* out_vector_ptr = &out_vectors(ix, 0);

auto centers_rot_3d = reinterpret_vectors(centers_rot, pq_centers);
reconstruct_vector_impl<PqBits>(in_list_data, src_ix, out_vector_ptr, pq_centers,
centers_rot_3d, codebook_kind, cluster_ix, pq_dim);
}

/** Decode the list data; see the public interface for the api and usage. */
Expand Down Expand Up @@ -733,27 +743,35 @@ void reconstruct_list_data(raft::resources const& res,
}

/**
* A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is,
* independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses
* pq_dim bytes.
* Pack codes into the list data format.
*/
struct pass_codes {
raft::device_matrix_view<const uint8_t, uint32_t, raft::row_major> codes;

/**
* Create a callable to be passed to `run_on_list`.
*
* @param[in] codes the source codes.
*/
__device__ inline pass_codes(
raft::device_matrix_view<const uint8_t, uint32_t, raft::row_major> codes)
: codes{codes}
{
template <uint32_t PqBits>
__device__ inline void pack_codes_impl(
raft::device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, raft::row_major>
out_list_data,
uint32_t out_ix,
const uint8_t* in_codes,
uint32_t pq_dim)
{
using group_align = raft::Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(out_ix);
const uint32_t ingroup_ix = group_align::mod(out_ix);

pq_vec_t code_chunk;
bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;

for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// clear the chunk
code_chunk = pq_vec_t{};
// pack the codes
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
code_view[k] = in_codes[j];
}
// write the chunk to the list
*reinterpret_cast<pq_vec_t*>(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk;
}

/** Read j-th component (code) of the i-th vector from the source. */
__device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); }
};
}

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel(
Expand All @@ -762,8 +780,15 @@ __launch_bounds__(BlockSize) static __global__ void pack_list_data_kernel(
raft::device_matrix_view<const uint8_t, uint32_t, raft::row_major> codes,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
write_list<PqBits, 1>(
list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes});
uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x;
if (ix >= codes.extent(0)) return;

const uint32_t dst_ix = std::holds_alternative<uint32_t>(offset_or_indices)
? std::get<uint32_t>(offset_or_indices) + ix
: std::get<const uint32_t*>(offset_or_indices)[ix];

const uint8_t* codes_ptr = &codes(ix, 0);
pack_codes_impl<PqBits>(list_data, dst_ix, codes_ptr, codes.extent(1));
}

/**
Expand Down Expand Up @@ -819,6 +844,85 @@ void pack_list_data(raft::resources const& res,
raft::resource::get_cuda_stream(res));
}

/**
* Encode a vector on-the-fly by finding the closest PQ centers.
*/
template <uint32_t PqBits, uint32_t SubWarpSize>
__device__ inline void encode_vector_impl(
raft::device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, raft::row_major>
out_list_data,
uint32_t out_ix,
uint32_t in_ix,
raft::device_mdspan<const float, raft::extent_3d<uint32_t>, raft::row_major> pq_centers,
raft::device_matrix_view<const float, uint32_t, raft::row_major> new_vectors,
codebook_gen codebook_kind,
uint32_t cluster_ix,
uint32_t pq_dim)
{
const uint32_t lane_id = raft::Pow2<SubWarpSize>::mod(threadIdx.x);

using group_align = raft::Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(out_ix);
const uint32_t ingroup_ix = group_align::mod(out_ix);
const uint32_t pq_len = pq_centers.extent(1);
const uint32_t pq_book_size = pq_centers.extent(2);

auto new_vectors_3d = reinterpret_vectors(new_vectors, pq_centers);

pq_vec_t code_chunk;
bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;

for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// clear the chunk
if (lane_id == 0) { code_chunk = pq_vec_t{}; }
// encode codes
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
uint32_t partition_ix;
switch (codebook_kind) {
case codebook_gen::PER_CLUSTER: {
partition_ix = cluster_ix;
} break;
case codebook_gen::PER_SUBSPACE: {
partition_ix = j;
} break;
default: __builtin_unreachable();
}

float min_dist = std::numeric_limits<float>::infinity();
uint8_t code = 0;
// calculate the distance for each PQ cluster, find the minimum for each thread
for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) {
float d = 0.0f;
for (uint32_t m = 0; m < pq_len; m++) {
auto t = new_vectors_3d(in_ix, j, m) - pq_centers(partition_ix, m, l);
d += t * t;
}
if (d < min_dist) {
min_dist = d;
code = uint8_t(l);
}
}
// reduce among threads
#pragma unroll
for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) {
const auto other_dist = raft::shfl_xor(min_dist, stride, SubWarpSize);
const auto other_code = raft::shfl_xor(code, stride, SubWarpSize);
if (other_dist < min_dist) {
min_dist = other_dist;
code = other_code;
}
}

if (lane_id == 0) { code_view[k] = code; }
}
// write the chunk to the list
if (lane_id == 0) {
*reinterpret_cast<pq_vec_t*>(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk;
}
}
}

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel(
raft::device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, raft::row_major>
Expand All @@ -830,11 +934,18 @@ __launch_bounds__(BlockSize) static __global__ void encode_list_data_kernel(
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(raft::WarpSize, 1u << PqBits);
const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1);
auto encode_action =
encode_vectors<kSubWarpSize, uint32_t>{pq_centers, new_vectors, codebook_kind, cluster_ix};
write_list<PqBits, kSubWarpSize>(
list_data, offset_or_indices, new_vectors.extent(0), pq_dim, encode_action);
const uint32_t warp_ix = (threadIdx.x + blockDim.x * blockIdx.x) / kSubWarpSize;

if (warp_ix >= new_vectors.extent(0)) return;

const uint32_t dst_ix = std::holds_alternative<uint32_t>(offset_or_indices)
? std::get<uint32_t>(offset_or_indices) + warp_ix
: std::get<const uint32_t*>(offset_or_indices)[warp_ix];

const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1);

encode_vector_impl<PqBits, kSubWarpSize>(list_data, dst_ix, warp_ix, pq_centers,
new_vectors, codebook_kind, cluster_ix, pq_dim);
}

template <typename T, typename IdxT>
Expand Down
Loading