Skip to content

Commit

Permalink
[Feat] Support bitset_to_csr (#2523)
Browse files Browse the repository at this point in the history
This API, `bitset_to_csr,` will be utilized to implement the `bitset'- based filter for prefiltered Brute Force.

Authors:
  - rhdong (https://github.com/rhdong)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Micka (https://github.com/lowener)

URL: #2523
  • Loading branch information
rhdong authored Jan 15, 2025
1 parent 32e918b commit a7f1916
Show file tree
Hide file tree
Showing 14 changed files with 1,199 additions and 126 deletions.
111 changes: 77 additions & 34 deletions cpp/bench/prims/linalg/masked_matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/linalg/masked_matmul.hpp>
#include <raft/sparse/linalg/masked_matmul.cuh>
#include <raft/util/itertools.hpp>

#include <cusparse_v2.h>
Expand All @@ -49,11 +49,14 @@ inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams<value_t>&
{
os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
<< "\tsparsity=" << params.sparsity;
if (params.sparsity == 1.0) { os << "<-inner product for comparison"; }
if (params.sparsity == 0.0) { os << "<-inner product for comparison"; }
return os;
}

template <typename value_t, typename index_t = int64_t, typename bitmap_t = uint32_t>
template <typename value_t,
bool bitmap_or_bitset = true,
typename index_t = int64_t,
typename bits_t = uint32_t>
struct MaskedMatmulBench : public fixture {
MaskedMatmulBench(const MaskedMatmulBenchParams<value_t>& p)
: fixture(true),
Expand All @@ -64,15 +67,15 @@ struct MaskedMatmulBench : public fixture {
c_indptr_d(0, stream),
c_indices_d(0, stream),
c_data_d(0, stream),
bitmap_d(0, stream),
bits_d(0, stream),
c_dense_data_d(0, stream)
{
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8));
std::vector<bitmap_t> bitmap_h(element);
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bits_t) * 8));
std::vector<bits_t> bits_h(element);

a_data_d.resize(params.m * params.k, stream);
b_data_d.resize(params.k * params.n, stream);
bitmap_d.resize(element, stream);
bits_d.resize(element, stream);

raft::random::RngState rng(2024ULL);
raft::random::uniform(
Expand All @@ -82,7 +85,13 @@ struct MaskedMatmulBench : public fixture {

std::vector<bool> c_dense_data_h(params.m * params.n);

c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h);
if constexpr (bitmap_or_bitset) {
c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h);
} else {
c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h);
repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1);
c_true_nnz *= params.m;
}

std::vector<value_t> values(c_true_nnz);
std::vector<index_t> indices(c_true_nnz);
Expand All @@ -93,24 +102,49 @@ struct MaskedMatmulBench : public fixture {
c_indices_d.resize(c_true_nnz, stream);
c_dense_data_d.resize(params.m * params.n, stream);

cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr);
cpu_convert_to_csr(bits_h, params.m, params.n, indices, indptr);
RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
"Something wrong. The c_true_nnz != c_indices_d.size()!");

update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
update_device(bitmap_d.data(), bitmap_h.data(), element, stream);
update_device(bits_d.data(), bits_h.data(), element, stream);
}

void repeat_cpu_bitset_inplace(std::vector<bits_t>& inout, size_t input_bits, size_t repeat)
{
size_t output_bit_index = input_bits;

for (size_t r = 0; r < repeat; ++r) {
for (size_t i = 0; i < input_bits; ++i) {
size_t input_unit_index = i / (sizeof(bits_t) * 8);
size_t input_bit_offset = i % (sizeof(bits_t) * 8);
bool bit = (inout[input_unit_index] >> input_bit_offset) & 1;

size_t output_unit_index = output_bit_index / (sizeof(bits_t) * 8);
size_t output_bit_offset = output_bit_index % (sizeof(bits_t) * 8);

inout[output_unit_index] |= (static_cast<bits_t>(bit) << output_bit_offset);

++output_bit_index;
}
}
}

index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitmap_t>& bitmap)
index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bits_t>& bits)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * sparsity);
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity));
index_t res = num_ones;

for (auto& item : bitmap) {
item = static_cast<bitmap_t>(0);
if (sparsity == 0.0f) {
std::fill(bits.begin(), bits.end(), 0xffffffff);
return num_ones;
}

for (auto& item : bits) {
item = static_cast<bits_t>(0);
}

std::random_device rd;
Expand All @@ -120,8 +154,8 @@ struct MaskedMatmulBench : public fixture {
while (num_ones > 0) {
index_t index = dis(gen);

bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))];
index_t bit_position = index % (8 * sizeof(bitmap_t));
bits_t& element = bits[index / (8 * sizeof(bits_t))];
index_t bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<index_t>(1) << bit_position);
Expand All @@ -131,7 +165,7 @@ struct MaskedMatmulBench : public fixture {
return res;
}

void cpu_convert_to_csr(std::vector<bitmap_t>& bitmap,
void cpu_convert_to_csr(std::vector<bits_t>& bits,
index_t rows,
index_t cols,
std::vector<index_t>& indices,
Expand All @@ -142,14 +176,14 @@ struct MaskedMatmulBench : public fixture {
indptr[offset_indptr++] = 0;

index_t index = 0;
bitmap_t element = 0;
bits_t element = 0;
index_t bit_position = 0;

for (index_t i = 0; i < rows; ++i) {
for (index_t j = 0; j < cols; ++j) {
index = i * cols + j;
element = bitmap[index / (8 * sizeof(bitmap_t))];
bit_position = index % (8 * sizeof(bitmap_t));
element = bits[index / (8 * sizeof(bits_t))];
bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1)) {
indices[offset_values] = static_cast<index_t>(j);
Expand Down Expand Up @@ -181,13 +215,17 @@ struct MaskedMatmulBench : public fixture {
params.n,
static_cast<index_t>(c_indices_d.size()));

auto mask =
raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);

auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);

if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
Expand All @@ -201,12 +239,16 @@ struct MaskedMatmulBench : public fixture {
}
resource::sync_stream(handle);

raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
resource::sync_stream(handle);

loop_on_state(state, [this, &a, &b, &mask, &c]() {
if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
loop_on_state(state, [this, &a, &b, &c]() {
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
Expand All @@ -228,7 +270,7 @@ struct MaskedMatmulBench : public fixture {

rmm::device_uvector<value_t> a_data_d;
rmm::device_uvector<value_t> b_data_d;
rmm::device_uvector<bitmap_t> bitmap_d;
rmm::device_uvector<bits_t> bits_d;

rmm::device_uvector<value_t> c_dense_data_d;

Expand All @@ -253,7 +295,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
raft::util::itertools::product<TestParams>({size_t(10), size_t(1024)},
{size_t(128), size_t(1024)},
{size_t(1024 * 1024)},
{0.01f, 0.1f, 0.2f, 0.5f, 1.0f});
{0.99f, 0.9f, 0.8f, 0.5f, 0.0f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
Expand All @@ -263,6 +305,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
return param_vec;
}

RAFT_BENCH_REGISTER((MaskedMatmulBench<float>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, true>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, false>), "", getInputs<float>());

} // namespace raft::bench::linalg
Loading

0 comments on commit a7f1916

Please sign in to comment.