From e6e43439075f095ea6f4f20ec5fdaaa1a00b48fa Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Sat, 15 Feb 2025 03:32:18 +0000 Subject: [PATCH] Separate hash aggregate finalization into its own operator (#4913) --- .../operator/aggregate/hash_aggregate.h | 53 +++++++++++++++---- .../processor/operator/physical_operator.h | 1 + src/processor/map/map_aggregate.cpp | 22 ++++++-- src/processor/map/map_distinct.cpp | 2 +- .../operator/aggregate/hash_aggregate.cpp | 40 +++++++------- src/processor/operator/physical_operator.cpp | 2 + 6 files changed, 85 insertions(+), 35 deletions(-) diff --git a/src/include/processor/operator/aggregate/hash_aggregate.h b/src/include/processor/operator/aggregate/hash_aggregate.h index 921431e48dd..b7f1b02ee8d 100644 --- a/src/include/processor/operator/aggregate/hash_aggregate.h +++ b/src/include/processor/operator/aggregate/hash_aggregate.h @@ -4,6 +4,7 @@ #include #include +#include #include "aggregate_hash_table.h" #include "common/copy_constructors.h" @@ -14,6 +15,7 @@ #include "main/client_context.h" #include "processor/operator/aggregate/aggregate_input.h" #include "processor/operator/aggregate/base_aggregate.h" +#include "processor/operator/physical_operator.h" #include "processor/result/factorized_table.h" #include "processor/result/factorized_table_schema.h" @@ -40,7 +42,8 @@ class HashAggregateSharedState final : public BaseAggregateSharedState { public: explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo, const std::vector& aggregateFunctions, - std::span aggregateInfos); + std::span aggregateInfos, std::vector keyTypes, + std::vector payloadTypes); void appendTuple(std::span tuple, common::hash_t hash) { auto& partition = @@ -59,7 +62,7 @@ class HashAggregateSharedState final : public BaseAggregateSharedState { overflow.push(std::make_unique(std::move(overflowBuffer))); } - void finalizeAggregateHashTable(const AggregateHashTable& localHashTable); + void finalizeAggregateHashTable(); std::pair getNextRangeToRead() override; @@ -78,15 +81,10 @@ class HashAggregateSharedState final : public BaseAggregateSharedState { return globalPartitions[0].hashTable->getTableSchema(); } - void setThreadFinishedProducing() { numThreadsFinishedProducing++; } - bool allThreadsFinishedProducing() const { return numThreadsFinishedProducing >= numThreads; } - - void registerThread() { numThreads++; } + const HashAggregateInfo& getAggregateInfo() const { return aggInfo; } void assertFinalized() const; - const HashAggregateInfo& getAggregateInfo() const { return aggInfo; } - protected: std::tuple getPartitionForOffset( common::offset_t offset) const; @@ -152,10 +150,9 @@ class HashAggregateSharedState final : public BaseAggregateSharedState { }; std::vector globalPartitions; uint64_t limitNumber; - std::atomic numThreadsFinishedProducing; - std::atomic numThreads; storage::MemoryManager* memoryManager; uint8_t shiftForPartitioning; + bool readyForFinalization = false; }; struct HashAggregateLocalState { @@ -207,7 +204,10 @@ class HashAggregate final : public BaseAggregate { void executeInternal(ExecutionContext* context) override; - void finalizeInternal(ExecutionContext* context) override; + // Delegated to HashAggregateFinalize so it can be parallelized + void finalizeInternal(ExecutionContext* /*context*/) override { + sharedState->readyForFinalization = true; + } std::unique_ptr copy() override { return make_unique(resultSetDescriptor->copy(), sharedState, @@ -222,5 +222,36 @@ class HashAggregate final : public BaseAggregate { std::shared_ptr sharedState; }; +class HashAggregateFinalize final : public Sink { +public: + HashAggregateFinalize(std::unique_ptr resultSetDescriptor, + std::shared_ptr sharedState, + std::unique_ptr child, uint32_t id, + std::unique_ptr printInfo) + : Sink{std::move(resultSetDescriptor), PhysicalOperatorType::AGGREGATE_FINALIZE, + std::move(child), id, std::move(printInfo)}, + sharedState{std::move(sharedState)} {} + + // Otherwise the runtime metrics for this operator are negative + // since it doesn't call children[0]->getNextTuple + bool isSource() const override { return true; } + + void executeInternal(ExecutionContext* /*context*/) override { + KU_ASSERT(sharedState->readyForFinalization); + sharedState->finalizeAggregateHashTable(); + } + void finalizeInternal(ExecutionContext* /*context*/) override { + sharedState->assertFinalized(); + } + + std::unique_ptr copy() override { + return make_unique(resultSetDescriptor->copy(), sharedState, + children[0]->copy(), id, printInfo->copy()); + } + +private: + std::shared_ptr sharedState; +}; + } // namespace processor } // namespace kuzu diff --git a/src/include/processor/operator/physical_operator.h b/src/include/processor/operator/physical_operator.h index 8b62d015d76..a247e501369 100644 --- a/src/include/processor/operator/physical_operator.h +++ b/src/include/processor/operator/physical_operator.h @@ -17,6 +17,7 @@ using physical_op_id = uint32_t; enum class PhysicalOperatorType : uint8_t { ALTER, AGGREGATE, + AGGREGATE_FINALIZE, AGGREGATE_SCAN, ATTACH_DATABASE, BATCH_INSERT, diff --git a/src/processor/map/map_aggregate.cpp b/src/processor/map/map_aggregate.cpp index 4649398fd47..92eb8efdc5f 100644 --- a/src/processor/map/map_aggregate.cpp +++ b/src/processor/map/map_aggregate.cpp @@ -1,10 +1,12 @@ #include "binder/expression/aggregate_function_expression.h" +#include "common/types/types.h" #include "planner/operator/logical_aggregate.h" #include "processor/operator/aggregate/hash_aggregate.h" #include "processor/operator/aggregate/hash_aggregate_scan.h" #include "processor/operator/aggregate/simple_aggregate.h" #include "processor/operator/aggregate/simple_aggregate_scan.h" #include "processor/plan_mapper.h" +#include "processor/result/result_set_descriptor.h" using namespace kuzu::binder; using namespace kuzu::common; @@ -143,12 +145,23 @@ std::unique_ptr PlanMapper::createHashAggregate(const expressi auto aggregateInputInfos = getAggregateInputInfos(allKeys, aggregates, *inSchema); auto flatKeys = getKeyExpressions(keys, *inSchema, true /* isFlat */); auto unFlatKeys = getKeyExpressions(keys, *inSchema, false /* isFlat */); + std::vector keyTypes, payloadTypes; + for (auto& key : flatKeys) { + keyTypes.push_back(key->getDataType().copy()); + } + for (auto& key : unFlatKeys) { + keyTypes.push_back(key->getDataType().copy()); + } + for (auto& payload : payloads) { + payloadTypes.push_back(payload->getDataType().copy()); + } auto tableSchema = getFactorizedTableSchema(flatKeys, unFlatKeys, payloads, aggFunctions); HashAggregateInfo aggregateInfo{getDataPos(flatKeys, *inSchema), getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)}; - auto sharedState = std::make_shared(clientContext, - std::move(aggregateInfo), aggFunctions, aggregateInputInfos); + auto sharedState = + std::make_shared(clientContext, std::move(aggregateInfo), + aggFunctions, aggregateInputInfos, std::move(keyTypes), std::move(payloadTypes)); auto printInfo = std::make_unique(allKeys, aggregates); auto aggregate = make_unique(std::make_unique(inSchema), sharedState, std::move(aggFunctions), std::move(aggregateInputInfos), @@ -159,8 +172,11 @@ std::unique_ptr PlanMapper::createHashAggregate(const expressi outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end()); outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end()); auto aggOutputPos = getDataPos(aggregates, *outSchema); + auto finalizer = + std::make_unique(std::make_unique(inSchema), + sharedState, std::move(aggregate), getOperatorID(), printInfo->copy()); return std::make_unique(sharedState, - getDataPos(outputExpressions, *outSchema), std::move(aggOutputPos), std::move(aggregate), + getDataPos(outputExpressions, *outSchema), std::move(aggOutputPos), std::move(finalizer), getOperatorID(), printInfo->copy()); } diff --git a/src/processor/map/map_distinct.cpp b/src/processor/map/map_distinct.cpp index 3330ef051f8..1908a23fdd6 100644 --- a/src/processor/map/map_distinct.cpp +++ b/src/processor/map/map_distinct.cpp @@ -26,7 +26,7 @@ std::unique_ptr PlanMapper::mapDistinct(const LogicalOperator* } auto op = createDistinctHashAggregate(distinct->getKeys(), distinct->getPayloads(), inSchema, outSchema, std::move(prevOperator)); - auto hashAggregate = op->getChild(0)->ptrCast(); + auto hashAggregate = op->getChild(0)->getChild(0)->ptrCast(); hashAggregate->getSharedState()->setLimitNumber(limitNum); auto printInfo = static_cast(hashAggregate->getPrintInfo()); const_cast(printInfo)->limitNum = limitNum; diff --git a/src/processor/operator/aggregate/hash_aggregate.cpp b/src/processor/operator/aggregate/hash_aggregate.cpp index c49b00e7864..a5f0922ecf7 100644 --- a/src/processor/operator/aggregate/hash_aggregate.cpp +++ b/src/processor/operator/aggregate/hash_aggregate.cpp @@ -1,7 +1,6 @@ #include "processor/operator/aggregate/hash_aggregate.h" -#include -#include +#include #include "binder/expression/expression_util.h" #include "common/assert.h" @@ -35,14 +34,19 @@ std::string HashAggregatePrintInfo::toString() const { HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo, const std::vector& aggregateFunctions, - std::span aggregateInfos) + std::span aggregateInfos, std::vector keyTypes, + std::vector payloadTypes) : BaseAggregateSharedState{aggregateFunctions}, aggInfo{std::move(hashAggInfo)}, globalPartitions{static_cast(context->getMaxNumThreadForExec())}, - limitNumber{common::INVALID_LIMIT}, numThreads{0}, memoryManager{context->getMemoryManager()}, + limitNumber{common::INVALID_LIMIT}, memoryManager{context->getMemoryManager()}, // .size() - 1 since we want the bit width of the largest value that could be used to index // the partitions shiftForPartitioning{ static_cast(sizeof(hash_t) * 8 - std::bit_width(globalPartitions.size() - 1))} { + std::vector distinctAggregateKeyTypes; + for (auto& aggInfo : aggregateInfos) { + distinctAggregateKeyTypes.push_back(aggInfo.distinctAggKeyType.copy()); + } // When copying directly into factorizedTables the table's schema's internal mayContainNulls // won't be updated and it's probably less work to just always check nulls @@ -54,6 +58,12 @@ HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context, auto& partition = globalPartitions[0]; partition.queue = std::make_unique(context->getMemoryManager(), this->aggInfo.tableSchema.copy()); + + // Always create a hash table for the first partition. Any other partitions which are non-empty + // when finalizing will create an empty copy of this table + partition.hashTable = std::make_unique(*context->getMemoryManager(), + std::move(keyTypes), std::move(payloadTypes), aggregateFunctions, distinctAggregateKeyTypes, + 0, this->aggInfo.tableSchema.copy()); for (size_t functionIdx = 0; functionIdx < aggregateFunctions.size(); functionIdx++) { auto& function = aggregateFunctions[functionIdx]; if (function.isFunctionDistinct()) { @@ -139,8 +149,7 @@ HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other) void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, ResultSet& resultSet, main::ClientContext* context, std::vector& aggregateFunctions, - std::vector types) { - sharedState->registerThread(); + std::vector distinctKeyTypes) { auto& info = sharedState->getAggregateInfo(); std::vector keyDataTypes; for (auto& pos : info.flatKeysPos) { @@ -164,7 +173,7 @@ void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, Result aggregateHashTable = std::make_unique(sharedState, *context->getMemoryManager(), std::move(keyDataTypes), std::move(payloadDataTypes), - aggregateFunctions, std::move(types), info.tableSchema.copy()); + aggregateFunctions, std::move(distinctKeyTypes), info.tableSchema.copy()); } uint64_t HashAggregateLocalState::append(const std::vector& aggregateInputs, @@ -193,15 +202,6 @@ void HashAggregate::executeInternal(ExecutionContext* context) { } } localState.aggregateHashTable->mergeAll(); - sharedState->setThreadFinishedProducing(); - while (!sharedState->allThreadsFinishedProducing()) { - std::this_thread::sleep_for(std::chrono::microseconds(500)); - } - sharedState->finalizeAggregateHashTable(*localState.aggregateHashTable); -} - -void HashAggregate::finalizeInternal(ExecutionContext*) { - sharedState->assertFinalized(); } uint64_t HashAggregateSharedState::getNumTuples() const { @@ -268,8 +268,7 @@ void HashAggregateSharedState::HashTableQueue::mergeInto(AggregateHashTable& has this->headBlock = nullptr; } -void HashAggregateSharedState::finalizeAggregateHashTable( - const AggregateHashTable& localHashTable) { +void HashAggregateSharedState::finalizeAggregateHashTable() { for (auto& partition : globalPartitions) { if (!partition.finalized && partition.mtx.try_lock()) { if (partition.finalized) { @@ -281,8 +280,9 @@ void HashAggregateSharedState::finalizeAggregateHashTable( continue; } if (!partition.hashTable) { - partition.hashTable = - std::make_unique(localHashTable.createEmptyCopy()); + // We always initialize the hash table in the first partition + partition.hashTable = std::make_unique( + globalPartitions[0].hashTable->createEmptyCopy()); } // TODO(bmwinger): ideally these can be merged into a single function. // The distinct tables need to be merged first so that they exist when the other table diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index 42f0330431f..1fa222aceba 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -16,6 +16,8 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope return "ALTER"; case PhysicalOperatorType::AGGREGATE: return "AGGREGATE"; + case PhysicalOperatorType::AGGREGATE_FINALIZE: + return "AGGREGATE_FINALIZE"; case PhysicalOperatorType::AGGREGATE_SCAN: return "AGGREGATE_SCAN"; case PhysicalOperatorType::ATTACH_DATABASE: