diff --git a/src/util/HashSet.h b/src/util/HashSet.h index 54f12a8f69..f638df2ad7 100644 --- a/src/util/HashSet.h +++ b/src/util/HashSet.h @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -20,6 +21,21 @@ using std::string; namespace ad_utility { + +// `slotMemoryCost` represents the per-slot memory cost of a node hash set. +// It accounts for the memory used by a slot in the hash table, which typically +// consists of a pointer (used for node storage) plus any additional control +// bytes required for maintaining the hash set's structure and state. +// This value helps estimate and manage memory consumption for operations that +// involve slots, such as insertion and rehashing. +// +// The value is defined as `sizeof(void*) + 1` bytes, where: +// - `sizeof(void*)` represents the size of a pointer on the platform (usually 4 +// bytes for 32-bit and 8 bytes for 64-bit systems). +// - `+ 1` accounts for an extra control byte used for state management in the +// hash set. +constexpr size_t slotMemoryCost = sizeof(void*) + 1; + // Wrapper for HashSets (with elements of type T) to be used everywhere // throughout code for the semantic search. This wrapper interface is not // designed to be complete from the beginning. Feel free to extend it at need. @@ -51,6 +67,7 @@ class CustomHashSetWithMemoryLimit { detail::AllocationMemoryLeftThreadsafe memoryLeft_; MemorySize memoryUsed_; SizeGetter sizeGetter_; + size_t currentSlotSize_; public: CustomHashSetWithMemoryLimit( @@ -58,29 +75,65 @@ class CustomHashSetWithMemoryLimit { SizeGetter sizeGetter = {}) : memoryLeft_{memoryLeft}, memoryUsed_{MemorySize::bytes(0)}, - sizeGetter_{sizeGetter} {} + sizeGetter_{sizeGetter}, + currentSlotSize_{0} { + // Once the hash set is initialized, calculate the initial memory + // used by the slots of the hash set + updateSlotArrayMemoryUsage(); + } + + ~CustomHashSetWithMemoryLimit() { decreaseMemoryUsed(memoryUsed_); } - ~CustomHashSetWithMemoryLimit() { - memoryLeft_.ptr()->wlock()->increase(memoryUsed_); + // Try to allocate the amount of memory requested + void increaseMemoryUsed(ad_utility::MemorySize amount) { + memoryLeft_.ptr()->wlock()->decrease_if_enough_left_or_throw(amount); + memoryUsed_ += amount; + } + + // Decrease the amount of memory used + void decreaseMemoryUsed(ad_utility::MemorySize amount) { + memoryLeft_.ptr()->wlock()->increase(amount); + memoryUsed_ -= amount; + } + + // Update the memory usage for the slot array if the bucket count changes. + // This function should be called after any operation that could cause + // rehashing. When the slot count increases, it reserves additional memory, + // and if the slot count decreases, it releases the unused memory back to the + // memory tracker. + void updateSlotArrayMemoryUsage() { + size_t newSlotSize = hashSet_.bucket_count(); + if (newSlotSize != currentSlotSize_) { + if (newSlotSize > currentSlotSize_) { + ad_utility::MemorySize sizeIncrease = + ad_utility::MemorySize::bytes(slotMemoryCost) * + (newSlotSize - currentSlotSize_); + increaseMemoryUsed(sizeIncrease); + } else { + ad_utility::MemorySize sizeDecrease = + ad_utility::MemorySize::bytes(slotMemoryCost) * + (currentSlotSize_ - newSlotSize); + + decreaseMemoryUsed(sizeDecrease); + } + } + currentSlotSize_ = newSlotSize; } // Insert an element into the hash set. If the memory limit is exceeded, the // insert operation fails with a runtime error. std::pair insert(const T& value) { - MemorySize size = sizeGetter_(value); - if (!memoryLeft_.ptr()->wlock()->decrease_if_enough_left_or_return_false( - size)) { - throw std::runtime_error( - "The element to be inserted is too large for the hash set."); - } + MemorySize size = + sizeGetter_(value) + ad_utility::MemorySize::bytes(sizeof(T)); + increaseMemoryUsed(size); const auto& [it, wasInserted] = hashSet_.insert(value); - if (wasInserted) { - memoryUsed_ += size; - } else { - memoryLeft_.ptr()->wlock()->increase(size); + if (!wasInserted) { + decreaseMemoryUsed(size); } + + updateSlotArrayMemoryUsage(); return std::pair{it, wasInserted}; } @@ -88,18 +141,28 @@ class CustomHashSetWithMemoryLimit { void erase(const T& value) { auto it = hashSet_.find(value); if (it != hashSet_.end()) { - MemorySize size = sizeGetter_(*it); + MemorySize size = + sizeGetter_(*it) + ad_utility::MemorySize::bytes(sizeof(T)); hashSet_.erase(it); - memoryLeft_.ptr()->wlock()->increase(size); - memoryUsed_ -= size; + decreaseMemoryUsed(size); + updateSlotArrayMemoryUsage(); } } // _____________________________________________________________________________ void clear() { hashSet_.clear(); - memoryLeft_.ptr()->wlock()->increase(memoryUsed_); - memoryUsed_ = MemorySize::bytes(0); + // Release all node memory + decreaseMemoryUsed(memoryUsed_); + + // Update slot memory usage based on the new bucket count after clearing + size_t newSlotSize = hashSet_.bucket_count(); + ad_utility::MemorySize slotMemoryAfterClear = + MemorySize::bytes(slotMemoryCost * newSlotSize); + // After clearing it only tracks the slot memory as nodes are gone + increaseMemoryUsed(slotMemoryAfterClear); + + currentSlotSize_ = newSlotSize; } // _____________________________________________________________________________ diff --git a/test/LocalVocabTest.cpp b/test/LocalVocabTest.cpp index feb81d5f42..977a99019c 100644 --- a/test/LocalVocabTest.cpp +++ b/test/LocalVocabTest.cpp @@ -399,13 +399,12 @@ TEST(LocalVocab, memoryLimit) { } } catch (const std::exception& e) { std::string errorMessage = e.what(); - EXPECT_THAT(errorMessage, ::testing::StartsWith( - "The element to be inserted is too large")); + EXPECT_THAT(errorMessage, ::testing::StartsWith("Tried to allocate")); } auto extraWord = ad_utility::triple_component::LiteralOrIri::literalWithoutQuotes( "ExtraWord"); EXPECT_THROW(localVocab.getIndexAndAddIfNotContained(extraWord), - std::runtime_error); + ad_utility::detail::AllocationExceedsLimitException); }