Skip to content

Commit

Permalink
Separate hash aggregate finalization into its own operator (#4913)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger authored and ray6080 committed Feb 15, 2025
1 parent 92e5eae commit e6e4343
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 35 deletions.
53 changes: 42 additions & 11 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <cstdint>
#include <memory>
#include <vector>

#include "aggregate_hash_table.h"
#include "common/copy_constructors.h"
Expand All @@ -14,6 +15,7 @@
#include "main/client_context.h"
#include "processor/operator/aggregate/aggregate_input.h"
#include "processor/operator/aggregate/base_aggregate.h"
#include "processor/operator/physical_operator.h"
#include "processor/result/factorized_table.h"
#include "processor/result/factorized_table_schema.h"

Expand All @@ -40,7 +42,8 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
public:
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions,
std::span<AggregateInfo> aggregateInfos);
std::span<AggregateInfo> aggregateInfos, std::vector<common::LogicalType> keyTypes,
std::vector<common::LogicalType> payloadTypes);

void appendTuple(std::span<uint8_t> tuple, common::hash_t hash) {
auto& partition =
Expand All @@ -59,7 +62,7 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
overflow.push(std::make_unique<common::InMemOverflowBuffer>(std::move(overflowBuffer)));
}

void finalizeAggregateHashTable(const AggregateHashTable& localHashTable);
void finalizeAggregateHashTable();

std::pair<uint64_t, uint64_t> getNextRangeToRead() override;

Expand All @@ -78,15 +81,10 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
return globalPartitions[0].hashTable->getTableSchema();
}

void setThreadFinishedProducing() { numThreadsFinishedProducing++; }
bool allThreadsFinishedProducing() const { return numThreadsFinishedProducing >= numThreads; }

void registerThread() { numThreads++; }
const HashAggregateInfo& getAggregateInfo() const { return aggInfo; }

void assertFinalized() const;

const HashAggregateInfo& getAggregateInfo() const { return aggInfo; }

protected:
std::tuple<const FactorizedTable*, common::offset_t> getPartitionForOffset(
common::offset_t offset) const;
Expand Down Expand Up @@ -152,10 +150,9 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
};
std::vector<Partition> globalPartitions;
uint64_t limitNumber;
std::atomic<size_t> numThreadsFinishedProducing;
std::atomic<size_t> numThreads;
storage::MemoryManager* memoryManager;
uint8_t shiftForPartitioning;
bool readyForFinalization = false;
};

struct HashAggregateLocalState {
Expand Down Expand Up @@ -207,7 +204,10 @@ class HashAggregate final : public BaseAggregate {

void executeInternal(ExecutionContext* context) override;

void finalizeInternal(ExecutionContext* context) override;
// Delegated to HashAggregateFinalize so it can be parallelized
void finalizeInternal(ExecutionContext* /*context*/) override {
sharedState->readyForFinalization = true;
}

std::unique_ptr<PhysicalOperator> copy() override {
return make_unique<HashAggregate>(resultSetDescriptor->copy(), sharedState,
Expand All @@ -222,5 +222,36 @@ class HashAggregate final : public BaseAggregate {
std::shared_ptr<HashAggregateSharedState> sharedState;
};

class HashAggregateFinalize final : public Sink {
public:
HashAggregateFinalize(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<HashAggregateSharedState> sharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id,
std::unique_ptr<OPPrintInfo> printInfo)
: Sink{std::move(resultSetDescriptor), PhysicalOperatorType::AGGREGATE_FINALIZE,
std::move(child), id, std::move(printInfo)},
sharedState{std::move(sharedState)} {}

// Otherwise the runtime metrics for this operator are negative
// since it doesn't call children[0]->getNextTuple
bool isSource() const override { return true; }

void executeInternal(ExecutionContext* /*context*/) override {
KU_ASSERT(sharedState->readyForFinalization);
sharedState->finalizeAggregateHashTable();
}
void finalizeInternal(ExecutionContext* /*context*/) override {
sharedState->assertFinalized();
}

std::unique_ptr<PhysicalOperator> copy() override {
return make_unique<HashAggregateFinalize>(resultSetDescriptor->copy(), sharedState,
children[0]->copy(), id, printInfo->copy());
}

private:
std::shared_ptr<HashAggregateSharedState> sharedState;
};

} // namespace processor
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using physical_op_id = uint32_t;
enum class PhysicalOperatorType : uint8_t {
ALTER,
AGGREGATE,
AGGREGATE_FINALIZE,
AGGREGATE_SCAN,
ATTACH_DATABASE,
BATCH_INSERT,
Expand Down
22 changes: 19 additions & 3 deletions src/processor/map/map_aggregate.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "binder/expression/aggregate_function_expression.h"
#include "common/types/types.h"
#include "planner/operator/logical_aggregate.h"
#include "processor/operator/aggregate/hash_aggregate.h"
#include "processor/operator/aggregate/hash_aggregate_scan.h"
#include "processor/operator/aggregate/simple_aggregate.h"
#include "processor/operator/aggregate/simple_aggregate_scan.h"
#include "processor/plan_mapper.h"
#include "processor/result/result_set_descriptor.h"

using namespace kuzu::binder;
using namespace kuzu::common;
Expand Down Expand Up @@ -143,12 +145,23 @@ std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expressi
auto aggregateInputInfos = getAggregateInputInfos(allKeys, aggregates, *inSchema);
auto flatKeys = getKeyExpressions(keys, *inSchema, true /* isFlat */);
auto unFlatKeys = getKeyExpressions(keys, *inSchema, false /* isFlat */);
std::vector<LogicalType> keyTypes, payloadTypes;
for (auto& key : flatKeys) {
keyTypes.push_back(key->getDataType().copy());
}
for (auto& key : unFlatKeys) {
keyTypes.push_back(key->getDataType().copy());
}
for (auto& payload : payloads) {
payloadTypes.push_back(payload->getDataType().copy());
}
auto tableSchema = getFactorizedTableSchema(flatKeys, unFlatKeys, payloads, aggFunctions);
HashAggregateInfo aggregateInfo{getDataPos(flatKeys, *inSchema),
getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)};

auto sharedState = std::make_shared<HashAggregateSharedState>(clientContext,
std::move(aggregateInfo), aggFunctions, aggregateInputInfos);
auto sharedState =
std::make_shared<HashAggregateSharedState>(clientContext, std::move(aggregateInfo),
aggFunctions, aggregateInputInfos, std::move(keyTypes), std::move(payloadTypes));
auto printInfo = std::make_unique<HashAggregatePrintInfo>(allKeys, aggregates);
auto aggregate = make_unique<HashAggregate>(std::make_unique<ResultSetDescriptor>(inSchema),
sharedState, std::move(aggFunctions), std::move(aggregateInputInfos),
Expand All @@ -159,8 +172,11 @@ std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expressi
outputExpressions.insert(outputExpressions.end(), unFlatKeys.begin(), unFlatKeys.end());
outputExpressions.insert(outputExpressions.end(), payloads.begin(), payloads.end());
auto aggOutputPos = getDataPos(aggregates, *outSchema);
auto finalizer =
std::make_unique<HashAggregateFinalize>(std::make_unique<ResultSetDescriptor>(inSchema),
sharedState, std::move(aggregate), getOperatorID(), printInfo->copy());
return std::make_unique<HashAggregateScan>(sharedState,
getDataPos(outputExpressions, *outSchema), std::move(aggOutputPos), std::move(aggregate),
getDataPos(outputExpressions, *outSchema), std::move(aggOutputPos), std::move(finalizer),
getOperatorID(), printInfo->copy());
}

Expand Down
2 changes: 1 addition & 1 deletion src/processor/map/map_distinct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapDistinct(const LogicalOperator*
}
auto op = createDistinctHashAggregate(distinct->getKeys(), distinct->getPayloads(), inSchema,
outSchema, std::move(prevOperator));
auto hashAggregate = op->getChild(0)->ptrCast<HashAggregate>();
auto hashAggregate = op->getChild(0)->getChild(0)->ptrCast<HashAggregate>();
hashAggregate->getSharedState()->setLimitNumber(limitNum);
auto printInfo = static_cast<const HashAggregatePrintInfo*>(hashAggregate->getPrintInfo());
const_cast<HashAggregatePrintInfo*>(printInfo)->limitNum = limitNum;
Expand Down
40 changes: 20 additions & 20 deletions src/processor/operator/aggregate/hash_aggregate.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "processor/operator/aggregate/hash_aggregate.h"

#include <chrono>
#include <thread>
#include <memory>

#include "binder/expression/expression_util.h"
#include "common/assert.h"
Expand Down Expand Up @@ -35,14 +34,19 @@ std::string HashAggregatePrintInfo::toString() const {
HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context,
HashAggregateInfo hashAggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions,
std::span<AggregateInfo> aggregateInfos)
std::span<AggregateInfo> aggregateInfos, std::vector<LogicalType> keyTypes,
std::vector<LogicalType> payloadTypes)
: BaseAggregateSharedState{aggregateFunctions}, aggInfo{std::move(hashAggInfo)},
globalPartitions{static_cast<size_t>(context->getMaxNumThreadForExec())},
limitNumber{common::INVALID_LIMIT}, numThreads{0}, memoryManager{context->getMemoryManager()},
limitNumber{common::INVALID_LIMIT}, memoryManager{context->getMemoryManager()},
// .size() - 1 since we want the bit width of the largest value that could be used to index
// the partitions
shiftForPartitioning{
static_cast<uint8_t>(sizeof(hash_t) * 8 - std::bit_width(globalPartitions.size() - 1))} {
std::vector<LogicalType> distinctAggregateKeyTypes;
for (auto& aggInfo : aggregateInfos) {
distinctAggregateKeyTypes.push_back(aggInfo.distinctAggKeyType.copy());
}

// When copying directly into factorizedTables the table's schema's internal mayContainNulls
// won't be updated and it's probably less work to just always check nulls
Expand All @@ -54,6 +58,12 @@ HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context,
auto& partition = globalPartitions[0];
partition.queue = std::make_unique<HashTableQueue>(context->getMemoryManager(),
this->aggInfo.tableSchema.copy());

// Always create a hash table for the first partition. Any other partitions which are non-empty
// when finalizing will create an empty copy of this table
partition.hashTable = std::make_unique<AggregateHashTable>(*context->getMemoryManager(),
std::move(keyTypes), std::move(payloadTypes), aggregateFunctions, distinctAggregateKeyTypes,
0, this->aggInfo.tableSchema.copy());
for (size_t functionIdx = 0; functionIdx < aggregateFunctions.size(); functionIdx++) {
auto& function = aggregateFunctions[functionIdx];
if (function.isFunctionDistinct()) {
Expand Down Expand Up @@ -139,8 +149,7 @@ HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other)

void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, ResultSet& resultSet,
main::ClientContext* context, std::vector<function::AggregateFunction>& aggregateFunctions,
std::vector<common::LogicalType> types) {
sharedState->registerThread();
std::vector<common::LogicalType> distinctKeyTypes) {
auto& info = sharedState->getAggregateInfo();
std::vector<LogicalType> keyDataTypes;
for (auto& pos : info.flatKeysPos) {
Expand All @@ -164,7 +173,7 @@ void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, Result

aggregateHashTable = std::make_unique<PartitioningAggregateHashTable>(sharedState,
*context->getMemoryManager(), std::move(keyDataTypes), std::move(payloadDataTypes),
aggregateFunctions, std::move(types), info.tableSchema.copy());
aggregateFunctions, std::move(distinctKeyTypes), info.tableSchema.copy());
}

uint64_t HashAggregateLocalState::append(const std::vector<AggregateInput>& aggregateInputs,
Expand Down Expand Up @@ -193,15 +202,6 @@ void HashAggregate::executeInternal(ExecutionContext* context) {
}
}
localState.aggregateHashTable->mergeAll();
sharedState->setThreadFinishedProducing();
while (!sharedState->allThreadsFinishedProducing()) {
std::this_thread::sleep_for(std::chrono::microseconds(500));
}
sharedState->finalizeAggregateHashTable(*localState.aggregateHashTable);
}

void HashAggregate::finalizeInternal(ExecutionContext*) {
sharedState->assertFinalized();
}

uint64_t HashAggregateSharedState::getNumTuples() const {
Expand Down Expand Up @@ -268,8 +268,7 @@ void HashAggregateSharedState::HashTableQueue::mergeInto(AggregateHashTable& has
this->headBlock = nullptr;
}

void HashAggregateSharedState::finalizeAggregateHashTable(
const AggregateHashTable& localHashTable) {
void HashAggregateSharedState::finalizeAggregateHashTable() {
for (auto& partition : globalPartitions) {
if (!partition.finalized && partition.mtx.try_lock()) {
if (partition.finalized) {
Expand All @@ -281,8 +280,9 @@ void HashAggregateSharedState::finalizeAggregateHashTable(
continue;
}
if (!partition.hashTable) {
partition.hashTable =
std::make_unique<AggregateHashTable>(localHashTable.createEmptyCopy());
// We always initialize the hash table in the first partition
partition.hashTable = std::make_unique<AggregateHashTable>(
globalPartitions[0].hashTable->createEmptyCopy());
}
// TODO(bmwinger): ideally these can be merged into a single function.
// The distinct tables need to be merged first so that they exist when the other table
Expand Down
2 changes: 2 additions & 0 deletions src/processor/operator/physical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope
return "ALTER";
case PhysicalOperatorType::AGGREGATE:
return "AGGREGATE";
case PhysicalOperatorType::AGGREGATE_FINALIZE:
return "AGGREGATE_FINALIZE";
case PhysicalOperatorType::AGGREGATE_SCAN:
return "AGGREGATE_SCAN";
case PhysicalOperatorType::ATTACH_DATABASE:
Expand Down

0 comments on commit e6e4343

Please sign in to comment.