Skip to content

Commit

Permalink
add test with forced id expectations
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Jan 14, 2025
1 parent 0abf20a commit 21f94cd
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions cpp/test/random/rmat_rectangular_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,101 @@ const std::vector<RmatInputs> inputs = {
{18, 16, 200000, false, 456789ULL, TOLERANCE},
{18, 16, 200000, true, 456789ULL, TOLERANCE}};

struct RmatForcedOutputs {
size_t r_scale;
size_t c_scale;
size_t r_node_id;
size_t c_node_id;
};

class RmatGenForceTest : public ::testing::TestWithParam<RmatForcedOutputs> {
public:
RmatGenForceTest()
: handle{},
stream{resource::get_cuda_stream(handle)},
params{::testing::TestWithParam<RmatForcedOutputs>::GetParam()},
out{2, stream},
out_src{1, stream},
out_dst{1, stream},
theta{0, stream},
h_theta{},
state{0, GeneratorType::GenPC},
max_scale(std::max(params.r_scale, params.c_scale))
{
theta.resize(4 * max_scale, stream);
h_theta.resize(theta.size(), 0.f);
for (size_t bit_pos = 0; bit_pos < max_scale; ++bit_pos) {
size_t row_bit = ((params.r_node_id & (1 << bit_pos)) != 0);
size_t col_bit = ((params.c_node_id & (1 << bit_pos)) != 0);

// now force theta for bit -- 2x2 matrix row major
h_theta[4 * bit_pos + row_bit * 2 + col_bit] = 1.f;
}

raft::update_device(theta.data(), h_theta.data(), max_scale * 4, stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}

protected:
void SetUp() override
{
rmat_rectangular_gen(out.data(),
out_src.data(),
out_dst.data(),
theta.data(),
params.r_scale,
params.c_scale,
size_t(1),
stream,
state);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}

void validate()
{
std::vector<size_t> h_out(2, size_t(0));
raft::update_host(h_out.data(), out.data(), 2, stream);
RAFT_CUDA_TRY(cudaGetLastError());
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

std::vector<size_t> h_out_expect;
h_out_expect.push_back(params.r_node_id);
h_out_expect.push_back(params.c_node_id);

ASSERT_TRUE(hostVecMatch(h_out_expect, h_out, raft::Compare<size_t>()));
}

protected:
raft::resources handle;
cudaStream_t stream;

RmatForcedOutputs params;
size_t max_scale;
std::vector<float> h_theta;
rmm::device_uvector<size_t> out, out_src, out_dst;
rmm::device_uvector<float> theta;
RngState state;
};

const std::vector<RmatForcedOutputs> forcedInputs = {{16, 16, 12425, 1233},
{16, 16, 12, 424},
{5, 5, 15, 15},
{5, 6, 15, 15},
{5, 15, 15, 15},
{6, 5, 15, 15},
{15, 5, 15, 15},
{32, 16, 1253163, 60000},
{16, 16, 12, 0},
{16, 16, 0, 1255}};

TEST_P(RmatGenTest, Result) { validate(); }
INSTANTIATE_TEST_SUITE_P(RmatGenTests, RmatGenTest, ::testing::ValuesIn(inputs));

TEST_P(RmatGenMdspanTest, Result) { validate(); }
INSTANTIATE_TEST_SUITE_P(RmatGenMdspanTests, RmatGenMdspanTest, ::testing::ValuesIn(inputs));

TEST_P(RmatGenForceTest, Result) { validate(); }
INSTANTIATE_TEST_SUITE_P(RmatGenForceTests, RmatGenForceTest, ::testing::ValuesIn(forcedInputs));

} // namespace random
} // namespace raft

0 comments on commit 21f94cd

Please sign in to comment.