From c5dcbe42515ce7f8eb8d817a8abde86f3d0dc525 Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Mon, 30 Jun 2025 23:42:08 +0530 Subject: [PATCH] Adding overloaded ios::app to serialize() methods to CAGRA, Brute Force & HNSW --- cpp/cmake/modules/ConfigureCUDA.cmake | 1 + cpp/include/cuvs/neighbors/brute_force.h | 28 +++ cpp/include/cuvs/neighbors/brute_force.hpp | 8 +- cpp/include/cuvs/neighbors/cagra.h | 61 ++++++- cpp/include/cuvs/neighbors/cagra.hpp | 16 +- cpp/src/neighbors/brute_force_c.cpp | 22 ++- cpp/src/neighbors/brute_force_serialize.cu | 39 ++++- cpp/src/neighbors/cagra_c.cpp | 77 +++++++-- cpp/src/neighbors/cagra_serialize.cuh | 11 +- .../detail/cagra/cagra_serialize.cuh | 12 +- cpp/tests/neighbors/ann_cagra_c.cu | 160 ++++++++++++++++++ cpp/tests/neighbors/brute_force_c.cu | 79 +++++++++ 12 files changed, 476 insertions(+), 38 deletions(-) diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index 0b6ebbaad2..93e0eb3505 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -45,6 +45,7 @@ if(CUDA_LOG_COMPILE_TIME) endif() list(APPEND CUVS_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr) +list(APPEND CUVS_CUDA_FLAGS -allow-unsupported-compiler) list(APPEND CUVS_CXX_FLAGS "-DCUDA_API_PER_THREAD_DEFAULT_STREAM") list(APPEND CUVS_CUDA_FLAGS "-DCUDA_API_PER_THREAD_DEFAULT_STREAM") # make sure we produce smallest binary size diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h index 91893e7d93..743e90b3f9 100644 --- a/cpp/include/cuvs/neighbors/brute_force.h +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -184,6 +184,34 @@ cuvsError_t cuvsBruteForceSearch(cuvsResources_t res, * cuvsError_t res_create_status = cuvsResourcesCreate(&res); * * // create an index with `cuvsBruteforceBuild` + * cuvsBruteForceSerializeWithMode(res, "/path/to/index", index, 'w'); + * @endcode + * + * @param[in] res cuvsResources_t opaque C 
handle + * @param[in] filename the file name for saving the index + * @param[in] index BRUTEFORCE index + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) + * + */ +cuvsError_t cuvsBruteForceSerializeWithMode(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index, + char file_mode); + +/** + * Save the index to file (backward compatibility version - writes to file). + * The serialization format can be subject to changes, therefore loading + * an index saved with a previous version of cuvs is not guaranteed + * to work. + * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // create an index with `cuvsBruteforceBuild` * cuvsBruteForceSerialize(res, "/path/to/index", index); * @endcode * diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index dade5d8c0e..9049d67b50 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -731,11 +731,13 @@ void search(raft::resources const& handle, * @param[in] index brute force index * @param[in] include_dataset whether to include the dataset in the serialized * output + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::brute_force::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Save the index to file. 
* The serialization format can be subject to changes, therefore loading @@ -761,12 +763,14 @@ void serialize(raft::resources const& handle, * @param[in] index brute force index * @param[in] include_dataset whether to include the dataset in the serialized * output + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) * */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::brute_force::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Write the index to an output stream diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index 5959124870..7c21abb039 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -560,6 +560,35 @@ cuvsError_t cuvsCagraSearch(cuvsResources_t res, * cuvsError_t res_create_status = cuvsResourcesCreate(&res); * * // create an index with `cuvsCagraBuild` + * cuvsCagraSerializeWithMode(res, "/path/to/index", index, true, 'w'); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) + * + */ +cuvsError_t cuvsCagraSerializeWithMode(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index, + bool include_dataset, + char file_mode); + +/** + * Save the index to file (backward compatibility version - writes to file). + * + * Experimental, both the API and the serialization format are subject to change. 
+ * + * @code{.c} + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // create an index with `cuvsCagraBuild` * cuvsCagraSerialize(res, "/path/to/index", index, true); * @endcode * @@ -590,7 +619,7 @@ cuvsError_t cuvsCagraSerialize(cuvsResources_t res, * cuvsError_t res_create_status = cuvsResourcesCreate(&res); * * // create an index with `cuvsCagraBuild` - * cuvsCagraSerializeHnswlib(res, "/path/to/index", index); + * cuvsCagraSerializeToHnswlib(res, "/path/to/index", index); * @endcode * * @param[in] res cuvsResources_t opaque C handle @@ -602,6 +631,36 @@ cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index); +/** + * Save the CAGRA index to file in hnswlib format with file mode control. + * NOTE: The saved index can only be read by the hnswlib wrapper in cuVS, + * as the serialization format is not compatible with the original hnswlib. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.c} + * #include + * #include + * + * // Create cuvsResources_t + * cuvsResources_t res; + * cuvsError_t res_create_status = cuvsResourcesCreate(&res); + * + * // create an index with `cuvsCagraBuild` + * cuvsCagraSerializeToHnswlibWithMode(res, "/path/to/index", index, 'w'); + * @endcode + * + * @param[in] res cuvsResources_t opaque C handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) + * + */ +cuvsError_t cuvsCagraSerializeToHnswlibWithMode(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index, + char file_mode); + /** * Load index from file. 
* diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 48ee128271..115851f60c 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -1294,12 +1294,14 @@ void search(raft::resources const& res, * @param[in] filename the file name for saving the index * @param[in] index CAGRA index * @param[in] include_dataset Whether or not to write out the dataset to the file. + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) * */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Load index from file. @@ -1399,12 +1401,14 @@ void deserialize(raft::resources const& handle, * @param[in] filename the file name for saving the index * @param[in] index CAGRA index * @param[in] include_dataset Whether or not to write out the dataset to the file. + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) * */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Load index from file. @@ -1505,11 +1509,13 @@ void deserialize(raft::resources const& handle, * @param[in] filename the file name for saving the index * @param[in] index CAGRA index * @param[in] include_dataset Whether or not to write out the dataset to the file. + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Load index from file. 
@@ -1610,11 +1616,13 @@ void deserialize(raft::resources const& handle, * @param[in] filename the file name for saving the index * @param[in] index CAGRA index * @param[in] include_dataset Whether or not to write out the dataset to the file. + * @param[in] file_mode File mode: 'w' for write (ios::out), 'a' for append (ios::app) */ void serialize(raft::resources const& handle, const std::string& filename, const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); + bool include_dataset = true, + char file_mode = 'w'); /** * Load index from file. diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index 85a46c5b86..e91e88650a 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -107,11 +107,11 @@ void _search(cuvsResources_t res, } template -void _serialize(cuvsResources_t res, const char* filename, cuvsBruteForceIndex index) +void _serialize(cuvsResources_t res, const char* filename, cuvsBruteForceIndex index, char file_mode) { auto res_ptr = reinterpret_cast(res); auto index_ptr = reinterpret_cast*>(index.addr); - cuvs::neighbors::brute_force::serialize(*res_ptr, std::string(filename), *index_ptr); + cuvs::neighbors::brute_force::serialize(*res_ptr, std::string(filename), *index_ptr, true, file_mode); } template @@ -263,17 +263,25 @@ extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, }); } -extern "C" cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, - const char* filename, - cuvsBruteForceIndex_t index) +extern "C" cuvsError_t cuvsBruteForceSerializeWithMode(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index, + char file_mode) { return cuvs::core::translate_exceptions([=] { if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { - _serialize(res, filename, *index); + _serialize(res, filename, *index, file_mode); } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { - _serialize(res, filename, 
*index); + _serialize(res, filename, *index, file_mode); } else { RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); } }); } + +extern "C" cuvsError_t cuvsBruteForceSerialize(cuvsResources_t res, + const char* filename, + cuvsBruteForceIndex_t index) +{ + return cuvsBruteForceSerializeWithMode(res, filename, index, 'w'); +} diff --git a/cpp/src/neighbors/brute_force_serialize.cu b/cpp/src/neighbors/brute_force_serialize.cu index 1b5b5111e9..2092833cd1 100644 --- a/cpp/src/neighbors/brute_force_serialize.cu +++ b/cpp/src/neighbors/brute_force_serialize.cu @@ -55,9 +55,16 @@ void serialize(raft::resources const& handle, void serialize(raft::resources const& handle, const std::string& filename, const index& index, - bool include_dataset) + bool include_dataset, + char file_mode) { - auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + std::ios::openmode mode = std::ios::binary; + if (file_mode == 'a') { + mode |= std::ios::app; + } else { + mode |= std::ios::out; + } + auto os = std::ofstream{filename, mode}; RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); serialize(handle, os, index, include_dataset); } @@ -65,9 +72,16 @@ void serialize(raft::resources const& handle, void serialize(raft::resources const& handle, const std::string& filename, const index& index, - bool include_dataset) + bool include_dataset, + char file_mode) { - auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; + std::ios::openmode mode = std::ios::binary; + if (file_mode == 'a') { + mode |= std::ios::app; + } else { + mode |= std::ios::out; + } + auto os = std::ofstream{filename, mode}; RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); serialize(handle, os, index, include_dataset); } @@ -88,6 +102,23 @@ void serialize(raft::resources const& handle, serialize(handle, os, index, include_dataset); } +// Backward compatibility functions - use default 'w' mode +void serialize(raft::resources const& handle, 
+ const std::string& filename, + const index& index, + bool include_dataset) +{ + serialize(handle, filename, index, include_dataset, 'w'); +} + +void serialize(raft::resources const& handle, + const std::string& filename, + const index& index, + bool include_dataset) +{ + serialize(handle, filename, index, include_dataset, 'w'); +} + template auto deserialize(raft::resources const& handle, std::istream& is) { diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index 656724826e..a792e00f98 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -226,11 +226,12 @@ template void _serialize(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index, - bool include_dataset) + bool include_dataset, + char file_mode) { auto res_ptr = reinterpret_cast(res); auto index_ptr = reinterpret_cast*>(index->addr); - cuvs::neighbors::cagra::serialize(*res_ptr, std::string(filename), *index_ptr, include_dataset); + cuvs::neighbors::cagra::serialize(*res_ptr, std::string(filename), *index_ptr, include_dataset, file_mode); } template @@ -632,41 +633,83 @@ extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res, }); } -extern "C" cuvsError_t cuvsCagraSerialize(cuvsResources_t res, - const char* filename, - cuvsCagraIndex_t index, - bool include_dataset) +extern "C" cuvsError_t cuvsCagraSerializeWithMode(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index, + bool include_dataset, + char file_mode) { return cuvs::core::translate_exceptions([=] { if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { - _serialize(res, filename, index, include_dataset); + _serialize(res, filename, index, include_dataset, file_mode); } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { - _serialize(res, filename, index, include_dataset); + _serialize(res, filename, index, include_dataset, file_mode); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { - _serialize(res, filename, 
index, include_dataset); + _serialize(res, filename, index, include_dataset, file_mode); } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { - _serialize(res, filename, index, include_dataset); + _serialize(res, filename, index, include_dataset, file_mode); } else { RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); } }); } -extern "C" cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res, - const char* filename, - cuvsCagraIndex_t index) +template +void _serialize_to_hnswlib_with_mode(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index, char file_mode) +{ + auto res_ptr = reinterpret_cast(res); + auto index_ptr = reinterpret_cast*>(index->addr); + + // Convert file mode to std::ios flags + std::ios_base::openmode mode = std::ios::binary; + if (file_mode == 'w') { + mode |= std::ios::out; + } else if (file_mode == 'a') { + mode |= std::ios::app; + } else { + RAFT_FAIL("Invalid file mode '%c'. Only 'w' (write) and 'a' (append) are supported.", file_mode); + } + + std::ofstream of(filename, mode); + if (!of) { RAFT_FAIL("Cannot open file %s", filename); } + + cuvs::neighbors::cagra::serialize_to_hnswlib(*res_ptr, of, *index_ptr); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename); } +} + +extern "C" cuvsError_t cuvsCagraSerializeToHnswlibWithMode(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index, + char file_mode) { return cuvs::core::translate_exceptions([=] { if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { - _serialize_to_hnswlib(res, filename, index); + _serialize_to_hnswlib_with_mode(res, filename, index, file_mode); } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { - _serialize_to_hnswlib(res, filename, index); + _serialize_to_hnswlib_with_mode(res, filename, index, file_mode); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { - _serialize_to_hnswlib(res, filename, index); + 
_serialize_to_hnswlib_with_mode(res, filename, index, file_mode); } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { - _serialize_to_hnswlib(res, filename, index); + _serialize_to_hnswlib_with_mode(res, filename, index, file_mode); } else { RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits); } }); } + +extern "C" cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index) +{ + return cuvsCagraSerializeToHnswlibWithMode(res, filename, index, 'w'); +} + +extern "C" cuvsError_t cuvsCagraSerialize(cuvsResources_t res, + const char* filename, + cuvsCagraIndex_t index, + bool include_dataset) +{ + return cuvsCagraSerializeWithMode(res, filename, index, include_dataset, 'w'); +} diff --git a/cpp/src/neighbors/cagra_serialize.cuh b/cpp/src/neighbors/cagra_serialize.cuh index 1b153b2ce3..1a8a2068f9 100644 --- a/cpp/src/neighbors/cagra_serialize.cuh +++ b/cpp/src/neighbors/cagra_serialize.cuh @@ -21,13 +21,22 @@ namespace cuvs::neighbors::cagra { #define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::cagra::index& index, \ + bool include_dataset, \ + char file_mode) \ + { \ + cuvs::neighbors::cagra::detail::serialize( \ + handle, filename, index, include_dataset, file_mode); \ + }; \ void serialize(raft::resources const& handle, \ const std::string& filename, \ const cuvs::neighbors::cagra::index& index, \ bool include_dataset) \ { \ cuvs::neighbors::cagra::detail::serialize( \ - handle, filename, index, include_dataset); \ + handle, filename, index, include_dataset, 'w'); \ }; \ \ void deserialize(raft::resources const& handle, \ diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 4bd761dc60..7c89919c15 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ 
b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -85,9 +85,17 @@ template void serialize(raft::resources const& res, const std::string& filename, const index& index_, - bool include_dataset) + bool include_dataset, + char file_mode = 'w') { - std::ofstream of(filename, std::ios::out | std::ios::binary); + std::ios_base::openmode mode = std::ios::binary; + if (file_mode == 'a') { + mode |= std::ios::app; + } else { + mode |= std::ios::out; + } + + std::ofstream of(filename, mode); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } detail::serialize(res, of, index_, include_dataset); diff --git a/cpp/tests/neighbors/ann_cagra_c.cu b/cpp/tests/neighbors/ann_cagra_c.cu index ae80cc8986..2e942f0656 100644 --- a/cpp/tests/neighbors/ann_cagra_c.cu +++ b/cpp/tests/neighbors/ann_cagra_c.cu @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include #include @@ -600,3 +602,161 @@ TEST(CagraC, BuildMergeSearch) cuvsCagraIndexDestroy(index_main); cuvsResourcesDestroy(res); } + +TEST(CagraC, SerializeWithModes) +{ + int64_t n_rows = 100; + int64_t n_dim = 10; + + // Create resources + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + // Create GPU data using RAFT handle + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector data(n_rows * n_dim, stream); + raft::random::RngState r(1234ULL); + raft::random::uniform(handle, r, data.data(), n_rows * n_dim, 0.1f, 2.0f); + + // Create DLManagedTensor for dataset + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = data.data(); + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.device.device_id = 0; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + dataset_tensor.dl_tensor.byte_offset = 0; + int64_t shape[2] = {n_rows, n_dim}; + 
dataset_tensor.dl_tensor.shape = shape; + dataset_tensor.dl_tensor.strides = nullptr; + + // Create index parameters + cuvsCagraIndexParams_t params; + ASSERT_EQ(cuvsCagraIndexParamsCreate(&params), CUVS_SUCCESS); + + // Build index + cuvsCagraIndex_t index; + ASSERT_EQ(cuvsCagraIndexCreate(&index), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraBuild(res, params, &dataset_tensor, index), CUVS_SUCCESS); + + // Test serialization with 'w' mode + const char* filename_w = "/tmp/test_cagra_w.index"; + ASSERT_EQ(cuvsCagraSerializeWithMode(res, filename_w, index, true, 'w'), CUVS_SUCCESS); + + // Test serialization with 'a' mode + const char* filename_a = "/tmp/test_cagra_a.index"; + ASSERT_EQ(cuvsCagraSerializeWithMode(res, filename_a, index, true, 'a'), CUVS_SUCCESS); + + // Test backward compatibility (should use 'w' mode by default) + const char* filename_compat = "/tmp/test_cagra_compat.index"; + ASSERT_EQ(cuvsCagraSerialize(res, filename_compat, index, true), CUVS_SUCCESS); + + // Verify files exist + std::ifstream file_w(filename_w, std::ios::binary); + ASSERT_TRUE(file_w.good()); + file_w.close(); + + std::ifstream file_a(filename_a, std::ios::binary); + ASSERT_TRUE(file_a.good()); + file_a.close(); + + std::ifstream file_compat(filename_compat, std::ios::binary); + ASSERT_TRUE(file_compat.good()); + file_compat.close(); + + // Test deserialization works with both files + cuvsCagraIndex_t index_loaded_w, index_loaded_compat; + ASSERT_EQ(cuvsCagraIndexCreate(&index_loaded_w), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraIndexCreate(&index_loaded_compat), CUVS_SUCCESS); + + ASSERT_EQ(cuvsCagraDeserialize(res, filename_w, index_loaded_w), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraDeserialize(res, filename_compat, index_loaded_compat), CUVS_SUCCESS); + + // Clean up + std::remove(filename_w); + std::remove(filename_a); + std::remove(filename_compat); + + ASSERT_EQ(cuvsCagraIndexDestroy(index), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraIndexDestroy(index_loaded_w), CUVS_SUCCESS); + 
ASSERT_EQ(cuvsCagraIndexDestroy(index_loaded_compat), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraIndexParamsDestroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +TEST(CagraC, SerializeToHnswlibWithModes) +{ + int64_t n_rows = 100; + int64_t n_dim = 10; + + // Create resources + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + // Create GPU data using RAFT handle + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector data(n_rows * n_dim, stream); + raft::random::RngState r(1234ULL); + raft::random::uniform(handle, r, data.data(), n_rows * n_dim, 0.1f, 2.0f); + + // Create DLManagedTensor for dataset + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = data.data(); + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.device.device_id = 0; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + dataset_tensor.dl_tensor.byte_offset = 0; + int64_t shape[2] = {n_rows, n_dim}; + dataset_tensor.dl_tensor.shape = shape; + dataset_tensor.dl_tensor.strides = nullptr; + + // Create index parameters + cuvsCagraIndexParams_t params; + ASSERT_EQ(cuvsCagraIndexParamsCreate(&params), CUVS_SUCCESS); + + // Build index + cuvsCagraIndex_t index; + ASSERT_EQ(cuvsCagraIndexCreate(&index), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraBuild(res, params, &dataset_tensor, index), CUVS_SUCCESS); + + // Test serialization to HNSWLIB with 'w' mode + const char* filename_w = "/tmp/test_cagra_hnswlib_w.index"; + ASSERT_EQ(cuvsCagraSerializeToHnswlibWithMode(res, filename_w, index, 'w'), CUVS_SUCCESS); + + // Test serialization to HNSWLIB with 'a' mode + const char* filename_a = "/tmp/test_cagra_hnswlib_a.index"; + ASSERT_EQ(cuvsCagraSerializeToHnswlibWithMode(res, filename_a, index, 'a'), CUVS_SUCCESS); + + // Test backward 
compatibility (should use 'w' mode by default) + const char* filename_compat = "/tmp/test_cagra_hnswlib_compat.index"; + ASSERT_EQ(cuvsCagraSerializeToHnswlib(res, filename_compat, index), CUVS_SUCCESS); + + // Verify files exist + std::ifstream file_w(filename_w, std::ios::binary); + ASSERT_TRUE(file_w.good()); + file_w.close(); + + std::ifstream file_a(filename_a, std::ios::binary); + ASSERT_TRUE(file_a.good()); + file_a.close(); + + std::ifstream file_compat(filename_compat, std::ios::binary); + ASSERT_TRUE(file_compat.good()); + file_compat.close(); + + // Clean up + std::remove(filename_w); + std::remove(filename_a); + std::remove(filename_compat); + + ASSERT_EQ(cuvsCagraIndexDestroy(index), CUVS_SUCCESS); + ASSERT_EQ(cuvsCagraIndexParamsDestroy(params), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} diff --git a/cpp/tests/neighbors/brute_force_c.cu b/cpp/tests/neighbors/brute_force_c.cu index d41efb24a2..c015a4e215 100644 --- a/cpp/tests/neighbors/brute_force_c.cu +++ b/cpp/tests/neighbors/brute_force_c.cu @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -531,3 +532,81 @@ TEST(BruteForceC, BuildSearchWithBitsetFilter) run_test_with_filter(n_rows, n_queries, n_dim, n_neighbors, BITSET); } + +TEST(BruteForceC, SerializeWithModes) +{ + int64_t n_rows = 1000; + int64_t n_dim = 32; + + // Create resources + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + // Create GPU data using RAFT handle + raft::handle_t handle; + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector data(n_rows * n_dim, stream); + generate_random_data(data.data(), n_rows * n_dim); + + // Create DLManagedTensor for dataset + DLManagedTensor dataset_tensor; + dataset_tensor.dl_tensor.data = data.data(); + dataset_tensor.dl_tensor.device.device_type = kDLCUDA; + dataset_tensor.dl_tensor.device.device_id = 0; + dataset_tensor.dl_tensor.ndim = 2; + dataset_tensor.dl_tensor.dtype.code = 
kDLFloat; + dataset_tensor.dl_tensor.dtype.bits = 32; + dataset_tensor.dl_tensor.dtype.lanes = 1; + dataset_tensor.dl_tensor.byte_offset = 0; + int64_t shape[2] = {n_rows, n_dim}; + dataset_tensor.dl_tensor.shape = shape; + dataset_tensor.dl_tensor.strides = nullptr; + + // Build index + cuvsBruteForceIndex_t index; + ASSERT_EQ(cuvsBruteForceIndexCreate(&index), CUVS_SUCCESS); + ASSERT_EQ(cuvsBruteForceBuild(res, &dataset_tensor, L2Expanded, 0.0f, index), CUVS_SUCCESS); + + // Test serialization with 'w' mode + const char* filename_w = "/tmp/test_brute_force_w.index"; + ASSERT_EQ(cuvsBruteForceSerializeWithMode(res, filename_w, index, 'w'), CUVS_SUCCESS); + + // Test serialization with 'a' mode + const char* filename_a = "/tmp/test_brute_force_a.index"; + ASSERT_EQ(cuvsBruteForceSerializeWithMode(res, filename_a, index, 'a'), CUVS_SUCCESS); + + // Test backward compatibility (should use 'w' mode by default) + const char* filename_compat = "/tmp/test_brute_force_compat.index"; + ASSERT_EQ(cuvsBruteForceSerialize(res, filename_compat, index), CUVS_SUCCESS); + + // Verify files exist + std::ifstream file_w(filename_w, std::ios::binary); + ASSERT_TRUE(file_w.good()); + file_w.close(); + + std::ifstream file_a(filename_a, std::ios::binary); + ASSERT_TRUE(file_a.good()); + file_a.close(); + + std::ifstream file_compat(filename_compat, std::ios::binary); + ASSERT_TRUE(file_compat.good()); + file_compat.close(); + + // Test deserialization works with both files + cuvsBruteForceIndex_t index_loaded_w, index_loaded_compat; + ASSERT_EQ(cuvsBruteForceIndexCreate(&index_loaded_w), CUVS_SUCCESS); + ASSERT_EQ(cuvsBruteForceIndexCreate(&index_loaded_compat), CUVS_SUCCESS); + + ASSERT_EQ(cuvsBruteForceDeserialize(res, filename_w, index_loaded_w), CUVS_SUCCESS); + ASSERT_EQ(cuvsBruteForceDeserialize(res, filename_compat, index_loaded_compat), CUVS_SUCCESS); + + // Clean up + std::remove(filename_w); + std::remove(filename_a); + std::remove(filename_compat); + + 
ASSERT_EQ(cuvsBruteForceIndexDestroy(index), CUVS_SUCCESS); + ASSERT_EQ(cuvsBruteForceIndexDestroy(index_loaded_w), CUVS_SUCCESS); + ASSERT_EQ(cuvsBruteForceIndexDestroy(index_loaded_compat), CUVS_SUCCESS); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +}