Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

increase the code coverage of hnswlib #328

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ class AlgorithmInterface {
size_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void
saveIndex(const std::string& location) = 0;

virtual void
saveIndex(void* d) = 0;

Expand Down
8 changes: 0 additions & 8 deletions src/algorithm/hnswlib/block_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,6 @@ BlockManager::Serialize(std::ostream& ofs, size_t cur_element_count) {
return this->SerializeImpl(writer, cur_element_count);
}

bool
BlockManager::Deserialize(std::function<void(uint64_t, uint64_t, void*)> read_func,
uint64_t cursor,
size_t cur_element_count) {
ReadFuncStreamReader reader(read_func, cursor);
return this->DeserializeImpl(reader, cur_element_count);
}

bool
BlockManager::Deserialize(std::istream& ifs, size_t cur_element_count) {
IOStreamReader reader(ifs);
Expand Down
5 changes: 0 additions & 5 deletions src/algorithm/hnswlib/block_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class BlockManager {
bool
Serialize(std::ostream& ofs, size_t cur_element_count);

bool
Deserialize(std::function<void(uint64_t, uint64_t, void*)> read_func,
uint64_t cursor,
size_t cur_element_count);

bool
Deserialize(std::istream& ifs, size_t cur_element_count);

Expand Down
99 changes: 14 additions & 85 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,14 +789,6 @@ HierarchicalNSW::saveIndex(std::ostream& out_stream) {
SerializeImpl(writer);
}

/*
 * Serializes the index into the file at `location` (opened in binary mode).
 *
 * Throws std::runtime_error when the file cannot be opened, or when the stream reports a
 * failure after the data has been written and the file closed. Without these checks a failed
 * save would go unnoticed by the caller.
 */
void
HierarchicalNSW::saveIndex(const std::string& location) {
    std::ofstream output(location, std::ios::binary);
    if (not output.is_open()) {
        throw std::runtime_error("failed to open index file for writing: " + location);
    }

    IOStreamWriter writer(output);
    SerializeImpl(writer);

    // close() flushes buffered data; a flush error sets failbit, which is checked below.
    output.close();
    if (output.fail()) {
        throw std::runtime_error("failed to write index file: " + location);
    }
}

template <typename T>
static void
WriteOne(StreamWriter& writer, T& value) {
Expand Down Expand Up @@ -979,44 +971,6 @@ HierarchicalNSW::markDeletedInternal(InnerIdType internalId) {
}
}

/*
 * Clears the deleted mark of the element registered under `label`; the graph itself is
 * NOT modified.
 *
 * Throws std::runtime_error("Label not found") when the label is unknown.
 *
 * Note: not safe to use when replacement of deleted elements is enabled, because elements
 * marked as deleted can be completely removed by addPoint.
 */
void
HierarchicalNSW::unmarkDelete(LabelType label) {
    // Keep the label table locked for the whole lookup + unmark sequence.
    std::unique_lock lock_table(label_lookup_lock_);

    const auto entry = label_lookup_.find(label);
    if (entry != label_lookup_.end()) {
        unmarkDeletedInternal(entry->second);
        return;
    }
    throw std::runtime_error("Label not found");
}

/*
 * Clears the deleted mark of the element with the given internal id.
 *
 * Throws std::runtime_error when the element is not currently marked as deleted.
 */
void
HierarchicalNSW::unmarkDeletedInternal(InnerIdType internalId) {
    assert(internalId < cur_element_count_);

    if (not isMarkedDeleted(internalId)) {
        throw std::runtime_error("The requested to undelete element is not deleted");
    }

    // The flag byte lives 2 bytes past the pointer returned for offsetLevel0_.
    unsigned char* flag_byte =
        (unsigned char*)data_level0_memory_->GetElementPtr(internalId, offsetLevel0_) + 2;
    *flag_byte &= ~DELETE_MARK;
    num_deleted_ -= 1;

    if (allow_replace_deleted_) {
        // deleted_elements_ is only touched while holding its own mutex.
        std::lock_guard<std::mutex> guard(deleted_elements_lock_);
        deleted_elements_.erase(internalId);
    }
}

/*
* Adds point.
*/
Expand Down Expand Up @@ -1441,8 +1395,13 @@ HierarchicalNSW::searchKnn(const void* query_data,

MaxHeap top_candidates(allocator_);

top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
if (num_deleted_ == 0) {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
} else {
top_candidates =
searchBaseLayerST<true, true>(currObj, query_data, std::max(ef, k), isIdAllowed);
}

while (top_candidates.size() > k) {
top_candidates.pop();
Expand Down Expand Up @@ -1501,8 +1460,13 @@ HierarchicalNSW::searchRange(const void* query_data,
}

MaxHeap top_candidates(allocator_);

top_candidates = searchBaseLayerST<false, true>(currObj, query_data, radius, ef, isIdAllowed);
if (num_deleted_ == 0) {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, radius, ef, isIdAllowed);
} else {
top_candidates =
searchBaseLayerST<true, true>(currObj, query_data, radius, ef, isIdAllowed);
}

while (not top_candidates.empty()) {
std::pair<float, InnerIdType> rez = top_candidates.top();
Expand All @@ -1513,39 +1477,4 @@ HierarchicalNSW::searchRange(const void* query_data,
// std::cout << "hnswalg::result.size(): " << result.size() << std::endl;
return result;
}

/*
 * Debug helper that walks every link list of every element and checks it with assert().
 *
 * For each element i and each level l in [0, element_levels_[i]] it asserts that every
 * neighbour id is > 0, < cur_element_count_, different from i, and not repeated within
 * the list. When there is more than one element it additionally asserts that every element
 * has at least one inbound connection and prints the min/max inbound counts to std::cout.
 * Finally prints the number of connections visited.
 *
 * All checks are assert()-based, so they are compiled out when NDEBUG is defined.
 */
void
HierarchicalNSW::checkIntegrity() {
int connections_checked = 0;
// inbound_connections_num[x] counts how many link-list entries point at element x.
vsag::Vector<int> inbound_connections_num(cur_element_count_, 0, allocator_);
for (int i = 0; i < cur_element_count_; i++) {
for (int l = 0; l <= element_levels_[i]; l++) {
// ll_cur is derived from data_ll_cur.get(), so data_ll_cur must stay alive while it is used.
auto data_ll_cur = getLinklistAtLevelWithLock(i, l);
linklistsizeint* ll_cur = (linklistsizeint*)data_ll_cur.get();
int size = getListCount(ll_cur);
// Neighbour ids are read starting one linklistsizeint past the list header.
auto* data = (InnerIdType*)(ll_cur + 1);
// Collects the ids seen in this list so duplicates can be detected below.
vsag::UnorderedSet<InnerIdType> s(allocator_);
for (int j = 0; j < size; j++) {
// NOTE(review): this rejects id 0 as a neighbour — confirm that is intentional.
assert(data[j] > 0);
assert(data[j] < cur_element_count_);
assert(data[j] != i);
inbound_connections_num[data[j]]++;
s.insert(data[j]);
connections_checked++;
}
// Equal sizes mean the list contained no duplicate ids.
assert(s.size() == size);
}
}
if (cur_element_count_ > 1) {
int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
for (int i = 0; i < cur_element_count_; i++) {
assert(inbound_connections_num[i] > 0);
min1 = std::min(inbound_connections_num[i], min1);
max1 = std::max(inbound_connections_num[i], max1);
}
std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
}
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}

} // namespace hnswlib
20 changes: 0 additions & 20 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
saveIndex(std::ostream& out_stream) override;

void
saveIndex(const std::string& location) override;

void
SerializeImpl(StreamWriter& writer);

Expand All @@ -329,20 +326,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
markDeletedInternal(InnerIdType internalId);

/*
* Removes the deleted mark of the node, does NOT really change the current graph.
*
* Note: the method is not safe to use when replacement of deleted elements is enabled,
* because elements marked as deleted can be completely removed by addPoint
*/
void
unmarkDelete(LabelType label);
/*
* Remove the deleted mark of the node.
*/
void
unmarkDeletedInternal(InnerIdType internalId);

/*
* Checks the first 16 bits of the memory to see if the element is marked deleted.
*/
Expand Down Expand Up @@ -405,9 +388,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
uint64_t ef,
vsag::BaseFilterFunctor* isIdAllowed = nullptr) const override;

void
checkIntegrity();

void
reset();

Expand Down
33 changes: 0 additions & 33 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -1144,39 +1144,6 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
out_stream.write((char*)node_cluster_dist_, max_elements_ * sizeof(float));
}

// Saving to a file path is not supported by the static HNSW index: this override always
// throws std::runtime_error and never touches `location`.
void
saveIndex(const std::string& location) override {
throw std::runtime_error("static hnsw does not support save index");
// Commented-out file-based writer (not compiled), kept below as-is:
// std::ofstream output(location, std::ios::binary);
// std::streampos position;
//
// writeBinaryPOD(output, offsetLevel0_);
// writeBinaryPOD(output, max_elements_);
// writeBinaryPOD(output, cur_element_count_);
// writeBinaryPOD(output, size_data_per_element_);
// writeBinaryPOD(output, label_offset_);
// writeBinaryPOD(output, offsetData_);
// writeBinaryPOD(output, maxlevel_);
// writeBinaryPOD(output, enterpoint_node_);
// writeBinaryPOD(output, maxM_);
//
// writeBinaryPOD(output, maxM0_);
// writeBinaryPOD(output, M_);
// writeBinaryPOD(output, mult_);
// writeBinaryPOD(output, ef_construction_);
//
// output.write(data_level0_memory_, cur_element_count_ * size_data_per_element_);
//
// for (size_t i = 0; i < cur_element_count_; i++) {
// unsigned int linkListSize =
// element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
// writeBinaryPOD(output, linkListSize);
// if (linkListSize)
// output.write(linkLists_[i], linkListSize);
// }
// output.close();
}

// load index from a file stream
void
loadIndex(StreamReader& in_stream, SpaceInterface* s, size_t max_elements_i = 0) override {
Expand Down
94 changes: 0 additions & 94 deletions src/algorithm/hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,6 @@
#endif
#endif

#if defined(USE_AVX) || defined(USE_SSE)
#ifdef _MSC_VER
#include <intrin.h>

#include <stdexcept>
// MSVC wrapper around the __cpuidex intrinsic: queries CPUID with the given eax / ecx
// inputs and stores the four result values in out[0..3].
//
// Declared static because this definition lives in a header: with external linkage, every
// translation unit including the header would emit its own definition and the link would
// fail with duplicate symbols. This also matches the static non-MSVC cpuid() below.
static void
cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
    __cpuidex(out, eax, ecx);
}
// MSVC wrapper around the _xgetbv intrinsic; returns its result for register index x.
static __int64
xgetbv(unsigned int x) {
return _xgetbv(x);
}
#else
#include <cpuid.h>
#include <x86intrin.h>

#include <cstdint>
#include <future>
// Runs the __cpuid_count macro from <cpuid.h> with inputs eax / ecx and stores its four
// outputs in cpuInfo[0..3], in the order the macro takes them.
static void
cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
// Executes the xgetbv instruction with `index` in ECX and returns the EDX:EAX result
// combined into one 64-bit value (EDX in the upper 32 bits, EAX in the lower 32 bits).
static uint64_t
xgetbv(unsigned int index) {
uint32_t eax, edx;
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
return ((uint64_t)edx << 32) | eax;
}
#endif

#if defined(USE_AVX512)
#include <immintrin.h>
#endif
Expand All @@ -69,69 +38,6 @@ xgetbv(unsigned int index) {
#define PORTABLE_ALIGN64 __declspec(align(64))
#endif

// Adapted from https://github.com/Mysticial/FeatureDetector
#define _XCR_XFEATURE_ENABLED_MASK 0

// Returns true only when both the CPU feature bit and the OS-enabled state bits checked
// below indicate AVX support.
static bool
AVXCapable() {
    int regs[4];

    // CPU support: leaf 0 gives the highest leaf id, leaf 1 carries the AVX bit (ECX bit 28).
    cpuid(regs, 0, 0);
    const int max_leaf = regs[0];

    bool hw_avx = false;
    if (max_leaf >= 0x00000001) {
        cpuid(regs, 0x00000001, 0);
        hw_avx = (regs[2] & ((int)1 << 28)) != 0;
    }

    // OS support: xgetbv may only be consulted when the XSAVE/XRSTORE bit (ECX bit 27) is set.
    cpuid(regs, 1, 0);
    const bool os_xsave_enabled = (regs[2] & (1 << 27)) != 0;
    const bool cpu_avx_bit = (regs[2] & (1 << 28)) != 0;
    if (!(os_xsave_enabled && cpu_avx_bit)) {
        return false;
    }

    const uint64_t xcr_mask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
    const bool os_avx_state = (xcr_mask & 0x6) == 0x6;
    return hw_avx && os_avx_state;
}

// Returns true only when AVXCapable() holds and both the AVX512 Foundation CPU bit and the
// OS-enabled state bits checked below are set.
static bool
AVX512Capable() {
    if (!AVXCapable()) {
        return false;
    }

    int regs[4];

    // CPU support: leaf 0 gives the highest leaf id, leaf 7 carries the AVX512F bit (EBX bit 16).
    cpuid(regs, 0, 0);
    const int max_leaf = regs[0];

    bool hw_avx512f = false;
    if (max_leaf >= 0x00000007) {  // AVX512 Foundation
        cpuid(regs, 0x00000007, 0);
        hw_avx512f = (regs[1] & ((int)1 << 16)) != 0;
    }

    // OS support: xgetbv may only be consulted when the XSAVE/XRSTORE bit (ECX bit 27) is set.
    cpuid(regs, 1, 0);
    const bool os_xsave_enabled = (regs[2] & (1 << 27)) != 0;
    const bool cpu_avx_bit = (regs[2] & (1 << 28)) != 0;
    if (!(os_xsave_enabled && cpu_avx_bit)) {
        return false;
    }

    const uint64_t xcr_mask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
    const bool os_avx512_state = (xcr_mask & 0xe6) == 0xe6;
    return hw_avx512f && os_avx512_state;
}
#endif

#include "hnswalg.h"
#include "hnswalg_static.h"
#include "space_ip.h"
Expand Down
3 changes: 2 additions & 1 deletion src/default_thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ DefaultThreadPool::Enqueue(std::function<void(void)> func) {

// Blocks until the tasks handed to the underlying pool have completed, not merely until
// the pending queue is empty.
void
DefaultThreadPool::WaitUntilEmpty() {
pool_->wait_until_empty();
// In progschj::ThreadPool, wait_until_empty only means that no tasks are waiting in the
// queue, whereas wait_until_nothing_in_flight means that all tasks have completed.
// The latter is the semantics actually required here.
pool_->wait_until_nothing_in_flight();
}

void
Expand Down
Loading
Loading