Skip to content

Commit

Permalink
increase the code coverage of hnswlib
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
  • Loading branch information
jinjiabao.jjb committed Jan 14, 2025
1 parent e23c784 commit 785c681
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 250 deletions.
3 changes: 0 additions & 3 deletions src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ class AlgorithmInterface {
size_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void
saveIndex(const std::string& location) = 0;

virtual void
saveIndex(void* d) = 0;

Expand Down
8 changes: 0 additions & 8 deletions src/algorithm/hnswlib/block_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,6 @@ BlockManager::Serialize(std::ostream& ofs, size_t cur_element_count) {
return this->SerializeImpl(writer, cur_element_count);
}

bool
BlockManager::Deserialize(std::function<void(uint64_t, uint64_t, void*)> read_func,
uint64_t cursor,
size_t cur_element_count) {
ReadFuncStreamReader reader(read_func, cursor);
return this->DeserializeImpl(reader, cur_element_count);
}

bool
BlockManager::Deserialize(std::istream& ifs, size_t cur_element_count) {
IOStreamReader reader(ifs);
Expand Down
5 changes: 0 additions & 5 deletions src/algorithm/hnswlib/block_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class BlockManager {
bool
Serialize(std::ostream& ofs, size_t cur_element_count);

bool
Deserialize(std::function<void(uint64_t, uint64_t, void*)> read_func,
uint64_t cursor,
size_t cur_element_count);

bool
Deserialize(std::istream& ifs, size_t cur_element_count);

Expand Down
99 changes: 14 additions & 85 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,14 +789,6 @@ HierarchicalNSW::saveIndex(std::ostream& out_stream) {
SerializeImpl(writer);
}

void
HierarchicalNSW::saveIndex(const std::string& location) {
std::ofstream output(location, std::ios::binary);
IOStreamWriter writer(output);
SerializeImpl(writer);
output.close();
}

template <typename T>
static void
WriteOne(StreamWriter& writer, T& value) {
Expand Down Expand Up @@ -979,44 +971,6 @@ HierarchicalNSW::markDeletedInternal(InnerIdType internalId) {
}
}

/*
* Removes the deleted mark of the node, does NOT really change the current graph.
*
* Note: the method is not safe to use when replacement of deleted elements is enabled,
* because elements marked as deleted can be completely removed by addPoint
*/
void
HierarchicalNSW::unmarkDelete(LabelType label) {
// lock all operations with element by label
std::unique_lock lock_table(label_lookup_lock_);
auto search = label_lookup_.find(label);
if (search == label_lookup_.end()) {
throw std::runtime_error("Label not found");
}
InnerIdType internalId = search->second;
unmarkDeletedInternal(internalId);
}

/*
* Remove the deleted mark of the node.
*/
void
HierarchicalNSW::unmarkDeletedInternal(InnerIdType internalId) {
assert(internalId < cur_element_count_);
if (isMarkedDeleted(internalId)) {
unsigned char* ll_cur =
(unsigned char*)data_level0_memory_->GetElementPtr(internalId, offsetLevel0_) + 2;
*ll_cur &= ~DELETE_MARK;
num_deleted_ -= 1;
if (allow_replace_deleted_) {
std::unique_lock<std::mutex> lock_deleted_elements(deleted_elements_lock_);
deleted_elements_.erase(internalId);
}
} else {
throw std::runtime_error("The requested to undelete element is not deleted");
}
}

/*
* Adds point.
*/
Expand Down Expand Up @@ -1441,8 +1395,13 @@ HierarchicalNSW::searchKnn(const void* query_data,

MaxHeap top_candidates(allocator_);

top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
if (num_deleted_ == 0) {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
} else {
top_candidates =
searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
}

while (top_candidates.size() > k) {
top_candidates.pop();
Expand Down Expand Up @@ -1501,8 +1460,13 @@ HierarchicalNSW::searchRange(const void* query_data,
}

MaxHeap top_candidates(allocator_);

top_candidates = searchBaseLayerST<false, true>(currObj, query_data, radius, ef, isIdAllowed);
if (num_deleted_ == 0) {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, radius, ef, isIdAllowed);
} else {
top_candidates =
searchBaseLayerST<true, true>(currObj, query_data, radius, ef, isIdAllowed);
}

while (not top_candidates.empty()) {
std::pair<float, InnerIdType> rez = top_candidates.top();
Expand All @@ -1513,39 +1477,4 @@ HierarchicalNSW::searchRange(const void* query_data,
// std::cout << "hnswalg::result.size(): " << result.size() << std::endl;
return result;
}

void
HierarchicalNSW::checkIntegrity() {
int connections_checked = 0;
vsag::Vector<int> inbound_connections_num(cur_element_count_, 0, allocator_);
for (int i = 0; i < cur_element_count_; i++) {
for (int l = 0; l <= element_levels_[i]; l++) {
auto data_ll_cur = getLinklistAtLevelWithLock(i, l);
linklistsizeint* ll_cur = (linklistsizeint*)data_ll_cur.get();
int size = getListCount(ll_cur);
auto* data = (InnerIdType*)(ll_cur + 1);
vsag::UnorderedSet<InnerIdType> s(allocator_);
for (int j = 0; j < size; j++) {
assert(data[j] > 0);
assert(data[j] < cur_element_count_);
assert(data[j] != i);
inbound_connections_num[data[j]]++;
s.insert(data[j]);
connections_checked++;
}
assert(s.size() == size);
}
}
if (cur_element_count_ > 1) {
int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
for (int i = 0; i < cur_element_count_; i++) {
assert(inbound_connections_num[i] > 0);
min1 = std::min(inbound_connections_num[i], min1);
max1 = std::max(inbound_connections_num[i], max1);
}
std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
}
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}

} // namespace hnswlib
20 changes: 0 additions & 20 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
saveIndex(std::ostream& out_stream) override;

void
saveIndex(const std::string& location) override;

void
SerializeImpl(StreamWriter& writer);

Expand All @@ -329,20 +326,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
markDeletedInternal(InnerIdType internalId);

/*
* Removes the deleted mark of the node, does NOT really change the current graph.
*
* Note: the method is not safe to use when replacement of deleted elements is enabled,
* because elements marked as deleted can be completely removed by addPoint
*/
void
unmarkDelete(LabelType label);
/*
* Remove the deleted mark of the node.
*/
void
unmarkDeletedInternal(InnerIdType internalId);

/*
* Checks the first 16 bits of the memory to see if the element is marked deleted.
*/
Expand Down Expand Up @@ -405,9 +388,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
uint64_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const override;

void
checkIntegrity();

void
reset();

Expand Down
33 changes: 0 additions & 33 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -1144,39 +1144,6 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
out_stream.write((char*)node_cluster_dist_, max_elements_ * sizeof(float));
}

void
saveIndex(const std::string& location) override {
throw std::runtime_error("static hnsw does not support save index");
// std::ofstream output(location, std::ios::binary);
// std::streampos position;
//
// writeBinaryPOD(output, offsetLevel0_);
// writeBinaryPOD(output, max_elements_);
// writeBinaryPOD(output, cur_element_count_);
// writeBinaryPOD(output, size_data_per_element_);
// writeBinaryPOD(output, label_offset_);
// writeBinaryPOD(output, offsetData_);
// writeBinaryPOD(output, maxlevel_);
// writeBinaryPOD(output, enterpoint_node_);
// writeBinaryPOD(output, maxM_);
//
// writeBinaryPOD(output, maxM0_);
// writeBinaryPOD(output, M_);
// writeBinaryPOD(output, mult_);
// writeBinaryPOD(output, ef_construction_);
//
// output.write(data_level0_memory_, cur_element_count_ * size_data_per_element_);
//
// for (size_t i = 0; i < cur_element_count_; i++) {
// unsigned int linkListSize =
// element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
// writeBinaryPOD(output, linkListSize);
// if (linkListSize)
// output.write(linkLists_[i], linkListSize);
// }
// output.close();
}

// load index from a file stream
void
loadIndex(StreamReader& in_stream, SpaceInterface* s, size_t max_elements_i = 0) override {
Expand Down
94 changes: 0 additions & 94 deletions src/algorithm/hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,6 @@
#endif
#endif

#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>

#include <stdexcept>
void
cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
__cpuidex(out, eax, ecx);
}
static __int64
xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <cpuid.h>
#include <x86intrin.h>

#include <cstdint>
#include <future>
static void
cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
static uint64_t
xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif

#if defined(USE_AVX512)
#include <immintrin.h>
#endif
Expand All @@ -69,69 +38,6 @@ xgetbv(unsigned int index) {
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif

// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0

static bool
AVXCapable() {
int cpuInfo[4];

// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];

bool HW_AVX = false;
if (nIds >= 0x00000001) {
cpuid(cpuInfo, 0x00000001, 0);
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
}

// OS support
cpuid(cpuInfo, 1, 0);

bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

bool avxSupported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
}
return HW_AVX && avxSupported;
}

static bool
AVX512Capable() {
if (!AVXCapable())
return false;

int cpuInfo[4];

// CPU support
cpuid(cpuInfo, 0, 0);
int nIds = cpuInfo[0];

bool HW_AVX512F = false;
if (nIds >= 0x00000007) { // AVX512 Foundation
cpuid(cpuInfo, 0x00000007, 0);
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
}

// OS support
cpuid(cpuInfo, 1, 0);

bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;

bool avx512Supported = false;
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
}
return HW_AVX512F && avx512Supported;
}
#endif

#include "hnswalg.h"
#include "hnswalg_static.h"
#include "space_ip.h"
Expand Down
3 changes: 2 additions & 1 deletion src/default_thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ DefaultThreadPool::Enqueue(std::function<void(void)> func) {

void
DefaultThreadPool::WaitUntilEmpty() {
pool_->wait_until_empty();
// In progschj::ThreadPool, wait_until_nothing_in_flight indicates that all tasks have been completed, while wait_until_empty means that there are no tasks waiting. Therefore, what we actually need here is the semantics of wait_until_nothing_in_flight.
pool_->wait_until_nothing_in_flight();
}

void
Expand Down
Loading

0 comments on commit 785c681

Please sign in to comment.