Skip to content

Commit

Permalink
add test to increase coverage rate (#327)
Browse files Browse the repository at this point in the history
- remove *_test.cpp for coverage
- improve hgraph

Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
  • Loading branch information
LHT129 authored Jan 15, 2025
1 parent e23c784 commit b663e67
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 26 deletions.
1 change: 1 addition & 0 deletions scripts/collect_cpp_coverage.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ lcov --remove ${COVERAGE_DIR}/coverage.info \
'build/*' \
'tests/*' \
'*/expected.hpp' \
'*_test.cpp' \
--ignore-errors inconsistent,inconsistent \
--output-file ${COVERAGE_DIR}/coverage.info
lcov --list ${COVERAGE_DIR}/coverage.info \
Expand Down
57 changes: 35 additions & 22 deletions tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,25 +364,37 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Serialize File",
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::string base_quantization_str = GENERATE("sq8", "fp32");
const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);

for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str);
auto index = TestFactory(name, param, true);
for (auto& [base_quantization_str, recall] : test_cases) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str);
auto index = TestFactory(name, param, true);

if (index->CheckFeature(vsag::SUPPORT_BUILD)) {
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and
index->CheckFeature(vsag::SUPPORT_DESERIALIZE_FILE)) {
auto index2 = TestFactory(name, param, true);
TestSerializeFile(index, index2, dataset, search_param, true);
if (index->CheckFeature(vsag::SUPPORT_BUILD)) {
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and
index->CheckFeature(vsag::SUPPORT_DESERIALIZE_FILE)) {
auto index2 = TestFactory(name, param, true);
TestSerializeFile(index, index2, dataset, search_param, true);
}
if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_BINARY_SET) and
index->CheckFeature(vsag::SUPPORT_DESERIALIZE_BINARY_SET)) {
auto index2 = TestFactory(name, param, true);
TestSerializeBinarySet(index, index2, dataset, search_param, true);
}
if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and
index->CheckFeature(vsag::SUPPORT_DESERIALIZE_READER_SET)) {
auto index2 = TestFactory(name, param, true);
TestSerializeReaderSet(index, index2, dataset, search_param, name, true);
}
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

Expand All @@ -393,20 +405,21 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex,
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::string base_quantization_str = GENERATE("sq8", "fp32");
const std::string name = "hgraph";
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str, 1);
auto index = vsag::Factory::CreateIndex(name, param, allocator.get());
if (not index.has_value()) {
continue;
for (auto& [base_quantization_str, recall] : test_cases) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str, 1);
auto index = vsag::Factory::CreateIndex(name, param, allocator.get());
if (not index.has_value()) {
continue;
}
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestContinueAddIgnoreRequire(index.value(), dataset);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestContinueAddIgnoreRequire(index.value(), dataset);
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build", "[ft][hgraph]") {
Expand Down
75 changes: 75 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dat
}
}
}

void
TestIndex::TestSerializeFile(const IndexPtr& index_from,
const IndexPtr& index_to,
Expand Down Expand Up @@ -390,6 +391,80 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from,
}
}

void
TestIndex::TestSerializeBinarySet(const IndexPtr& index_from,
const IndexPtr& index_to,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success) {
auto serialize_binary = index_from->Serialize();
REQUIRE(serialize_binary.has_value() == expected_success);

auto deserialize_index = index_to->Deserialize(serialize_binary.value());
REQUIRE(deserialize_index.has_value() == expected_success);

const auto& queries = dataset->query_;
auto query_count = queries->GetNumElements();
auto dim = queries->GetDim();
auto topk = 10;
for (auto i = 0; i < query_count; ++i) {
auto query = vsag::Dataset::Make();
query->NumElements(1)
->Dim(dim)
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
->Owner(false);
auto res_from = index_from->KnnSearch(query, topk, search_param);
auto res_to = index_to->KnnSearch(query, topk, search_param);
REQUIRE(res_from.has_value());
REQUIRE(res_to.has_value());
REQUIRE(res_from.value()->GetDim() == res_to.value()->GetDim());
for (auto j = 0; j < topk; ++j) {
REQUIRE(res_to.value()->GetIds()[j] == res_from.value()->GetIds()[j]);
}
}
}

void
TestIndex::TestSerializeReaderSet(const IndexPtr& index_from,
const IndexPtr& index_to,
const TestDatasetPtr& dataset,
const std::string& search_param,
const std::string& index_name,
bool expected_success) {
auto dir = fixtures::TempDir("serialize");
auto path = dir.GenerateRandomFile();
std::ofstream outfile(path, std::ios::out | std::ios::binary);
auto serialize_index = index_from->Serialize(outfile);
REQUIRE(serialize_index.has_value() == expected_success);
outfile.close();

vsag::ReaderSet rs;
auto reader = vsag::Factory::CreateLocalFileReader(path, 0, 0);
rs.Set(index_name, reader);
auto deserialize_index = index_to->Deserialize(rs);
REQUIRE(deserialize_index.has_value() == expected_success);

const auto& queries = dataset->query_;
auto query_count = queries->GetNumElements();
auto dim = queries->GetDim();
auto topk = 10;
for (auto i = 0; i < query_count; ++i) {
auto query = vsag::Dataset::Make();
query->NumElements(1)
->Dim(dim)
->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
->Owner(false);
auto res_from = index_from->KnnSearch(query, topk, search_param);
auto res_to = index_to->KnnSearch(query, topk, search_param);
REQUIRE(res_from.has_value());
REQUIRE(res_to.has_value());
REQUIRE(res_from.value()->GetDim() == res_to.value()->GetDim());
for (auto j = 0; j < topk; ++j) {
REQUIRE(res_to.value()->GetIds()[j] == res_from.value()->GetIds()[j]);
}
}
}

void
TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
17 changes: 13 additions & 4 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,19 @@ class TestIndex {
bool expected_success = true);

static void
TestSerializeBinary(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& path,
bool expected_success = true){};
TestSerializeBinarySet(const IndexPtr& index_from,
const IndexPtr& index_to,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

static void
TestSerializeReaderSet(const IndexPtr& index_from,
const IndexPtr& index_to,
const TestDatasetPtr& dataset,
const std::string& search_param,
const std::string& index_name,
bool expected_success = true);

static void
TestConcurrentKnnSearch(const IndexPtr& index,
Expand Down

0 comments on commit b663e67

Please sign in to comment.