diff --git a/scripts/collect_cpp_coverage.sh b/scripts/collect_cpp_coverage.sh index 7c56f5f2f..4f4c8a5d1 100644 --- a/scripts/collect_cpp_coverage.sh +++ b/scripts/collect_cpp_coverage.sh @@ -27,6 +27,7 @@ lcov --remove ${COVERAGE_DIR}/coverage.info \ 'build/*' \ 'tests/*' \ '*/expected.hpp' \ + '*_test.cpp' \ --ignore-errors inconsistent,inconsistent \ --output-file ${COVERAGE_DIR}/coverage.info lcov --list ${COVERAGE_DIR}/coverage.info \ diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index b2e72cee2..e3aea77f6 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -364,25 +364,37 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Serialize File", auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::string base_quantization_str = GENERATE("sq8", "fp32"); const std::string name = "hgraph"; auto search_param = fmt::format(search_param_tmp, 200); for (auto& dim : dims) { - vsag::Options::Instance().set_block_size_limit(size); - auto param = GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str); - auto index = TestFactory(name, param, true); + for (auto& [base_quantization_str, recall] : test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = + GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str); + auto index = TestFactory(name, param, true); - if (index->CheckFeature(vsag::SUPPORT_BUILD)) { - auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); - TestBuildIndex(index, dataset, true); - if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and - index->CheckFeature(vsag::SUPPORT_DESERIALIZE_FILE)) { - auto index2 = TestFactory(name, param, true); - TestSerializeFile(index, index2, dataset, search_param, true); + if (index->CheckFeature(vsag::SUPPORT_BUILD)) { + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestBuildIndex(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and + index->CheckFeature(vsag::SUPPORT_DESERIALIZE_FILE)) { + auto index2 = TestFactory(name, param, true); + TestSerializeFile(index, index2, dataset, search_param, true); + } + if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_BINARY_SET) and + index->CheckFeature(vsag::SUPPORT_DESERIALIZE_BINARY_SET)) { + auto index2 = TestFactory(name, param, true); + TestSerializeBinarySet(index, index2, dataset, search_param, true); + } + if (index->CheckFeature(vsag::SUPPORT_SERIALIZE_FILE) and + index->CheckFeature(vsag::SUPPORT_DESERIALIZE_READER_SET)) { + auto index2 = TestFactory(name, param, true); + TestSerializeReaderSet(index, index2, dataset, search_param, name, true); + } } + vsag::Options::Instance().set_block_size_limit(origin_size); } - vsag::Options::Instance().set_block_size_limit(origin_size); } } @@ -393,20 +405,21 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); auto metric_type = GENERATE("l2", "ip", "cosine"); - std::string base_quantization_str = GENERATE("sq8", "fp32"); const std::string name = "hgraph"; for (auto& dim : dims) { - vsag::Options::Instance().set_block_size_limit(size); - auto param = - GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str, 1); - auto index = vsag::Factory::CreateIndex(name, param, allocator.get()); - if (not index.has_value()) { - continue; + for (auto& [base_quantization_str, recall] : test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = + GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str, 1); + auto index = vsag::Factory::CreateIndex(name, param, allocator.get()); + if (not index.has_value()) { + continue; + } + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestContinueAddIgnoreRequire(index.value(), dataset); + vsag::Options::Instance().set_block_size_limit(origin_size); } - auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); - TestContinueAddIgnoreRequire(index.value(), dataset); } - vsag::Options::Instance().set_block_size_limit(origin_size); } TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Duplicate Build", "[ft][hgraph]") { diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 4e2f769dc..832b8a0b0 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -351,6 +351,7 @@ TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dat } } } + void TestIndex::TestSerializeFile(const IndexPtr& index_from, const IndexPtr& index_to, @@ -390,6 +391,80 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from, } } +void +TestIndex::TestSerializeBinarySet(const IndexPtr& index_from, + const IndexPtr& index_to, + const TestDatasetPtr& dataset, + const std::string& search_param, + bool expected_success) { + auto serialize_binary = index_from->Serialize(); + REQUIRE(serialize_binary.has_value() == expected_success); + + auto deserialize_index = index_to->Deserialize(serialize_binary.value()); + REQUIRE(deserialize_index.has_value() == expected_success); + + const auto& queries = dataset->query_; + auto query_count = queries->GetNumElements(); + auto dim = queries->GetDim(); + auto topk = 10; + for (auto i = 0; i < query_count; ++i) { + auto query = vsag::Dataset::Make(); + query->NumElements(1) + ->Dim(dim) + ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->Owner(false); + auto res_from = index_from->KnnSearch(query, topk, search_param); + auto res_to = index_to->KnnSearch(query, topk, search_param); + REQUIRE(res_from.has_value()); + REQUIRE(res_to.has_value()); + REQUIRE(res_from.value()->GetDim() == res_to.value()->GetDim()); + for (auto j = 0; j < topk; ++j) { + REQUIRE(res_to.value()->GetIds()[j] == res_from.value()->GetIds()[j]); + } + } +} + +void +TestIndex::TestSerializeReaderSet(const IndexPtr& index_from, + const IndexPtr& index_to, + const TestDatasetPtr& dataset, + const std::string& search_param, + const std::string& index_name, + bool expected_success) { + auto dir = fixtures::TempDir("serialize"); + auto path = dir.GenerateRandomFile(); + std::ofstream outfile(path, std::ios::out | std::ios::binary); + auto serialize_index = index_from->Serialize(outfile); + REQUIRE(serialize_index.has_value() == expected_success); + outfile.close(); + + vsag::ReaderSet rs; + auto reader = vsag::Factory::CreateLocalFileReader(path, 0, 0); + rs.Set(index_name, reader); + auto deserialize_index = index_to->Deserialize(rs); + REQUIRE(deserialize_index.has_value() == expected_success); + + const auto& queries = dataset->query_; + auto query_count = queries->GetNumElements(); + auto dim = queries->GetDim(); + auto topk = 10; + for (auto i = 0; i < query_count; ++i) { + auto query = vsag::Dataset::Make(); + query->NumElements(1) + ->Dim(dim) + ->Float32Vectors(queries->GetFloat32Vectors() + i * dim) + ->Owner(false); + auto res_from = index_from->KnnSearch(query, topk, search_param); + auto res_to = index_to->KnnSearch(query, topk, search_param); + REQUIRE(res_from.has_value()); + REQUIRE(res_to.has_value()); + REQUIRE(res_from.value()->GetDim() == res_to.value()->GetDim()); + for (auto j = 0; j < topk; ++j) { + REQUIRE(res_to.value()->GetIds()[j] == res_from.value()->GetIds()[j]); + } + } +} + void TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset, diff --git a/tests/test_index.h b/tests/test_index.h index ba6c017ec..33bf0e834 100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -118,10 +118,19 @@ class TestIndex { bool expected_success = true); static void - TestSerializeBinary(const IndexPtr& index, - const TestDatasetPtr& dataset, - const std::string& path, - bool expected_success = true){}; + TestSerializeBinarySet(const IndexPtr& index_from, + const IndexPtr& index_to, + const TestDatasetPtr& dataset, + const std::string& search_param, + bool expected_success = true); + + static void + TestSerializeReaderSet(const IndexPtr& index_from, + const IndexPtr& index_to, + const TestDatasetPtr& dataset, + const std::string& search_param, + const std::string& index_name, + bool expected_success = true); static void TestConcurrentKnnSearch(const IndexPtr& index,