Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions be/src/exec/operator/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "core/data_type/primitive_type.h"
#include "exec/common/hash_table/hash.h"
#include "exec/operator/operator.h"
#include "exprs/aggregate/aggregate_function_count.h"
#include "exprs/aggregate/aggregate_function_simple_factory.h"
#include "exprs/vectorized_agg_fn.h"
#include "runtime/runtime_profile.h"
Expand Down Expand Up @@ -156,6 +157,30 @@ Status AggSinkLocalState::open(RuntimeState* state) {
RETURN_IF_ERROR(_create_agg_status(_agg_data->without_key));
_shared_state->agg_data_created_without_key = true;
}

// Determine whether to use simple count aggregation.
// For queries like: SELECT xxx, count(*) / count(not_null_column) FROM table GROUP BY xxx,
// count(*) / count(not_null_column) can store a uint64 counter directly in the hash table,
// instead of storing the full aggregate state, saving memory and computation overhead.
// Requirements:
// 0. The aggregation has a GROUP BY clause.
// 1. There is exactly one count aggregate function.
// 2. No limit optimization is applied.
// 3. Spill is not enabled (the spill path accesses aggregate_data_container, which is empty in inline count mode).
// Supports update / merge / finalize / serialize phases, since count's serialization format is UInt64 itself.

if (!Base::_shared_state->probe_expr_ctxs.empty() /* has GROUP BY */
&& (p._aggregate_evaluators.size() == 1 &&
p._aggregate_evaluators[0]->function()->is_simple_count()) /* only one count(*) */
&& !_should_limit_output /* no limit optimization */ &&
!Base::_shared_state->enable_spill /* spill not enabled */) {
_shared_state->use_simple_count = true;
#ifndef NDEBUG
// Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion.
_shared_state->use_simple_count = rand() % 2 == 0;
#endif
}

return Status::OK();
}

Expand Down Expand Up @@ -335,7 +360,18 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(Block* block) {
key_columns, (uint32_t)rows);
rows = block->rows();
} else {
_emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows);
if (_shared_state->use_simple_count) {
DCHECK(!for_spill);

auto col_id = AggSharedState::get_slot_column_id(
Base::_shared_state->aggregate_evaluators[0]);

auto column = block->get_by_position(col_id).column;
_merge_into_hash_table_inline_count(key_columns, column.get(), (uint32_t)rows);
need_do_agg = false;
} else {
_emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows);
}
}

if (need_do_agg) {
Expand Down Expand Up @@ -496,7 +532,9 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(Block* block) {
}
} else {
_emplace_into_hash_table(_places.data(), key_columns, rows);
RETURN_IF_ERROR(do_aggregate_evaluators());
if (!_shared_state->use_simple_count) {
RETURN_IF_ERROR(do_aggregate_evaluators());
}

if (_should_limit_output && !Base::_shared_state->enable_spill) {
const size_t hash_table_size = get_hash_table_size();
Expand Down Expand Up @@ -524,6 +562,11 @@ size_t AggSinkLocalState::get_hash_table_size() const {

void AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
ColumnRawPtrs& key_columns, uint32_t num_rows) {
if (_shared_state->use_simple_count) {
_emplace_into_hash_table_inline_count(key_columns, num_rows);
return;
}

std::visit(Overload {[&](std::monostate& arg) -> void {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
Expand Down Expand Up @@ -570,6 +613,83 @@ void AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
_agg_data->method_variant);
}

// Inline-count fast path: the hash table's mapped value is a char* (exactly 64 bits),
// so we reinterpret it as a UInt64 counter instead of pointing it at real aggregate
// state. Every occurrence of a key bumps its counter by one, which avoids allocating
// and updating a full aggregate state per group, saving memory and computation.
void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
                                                              uint32_t num_rows) {
    std::visit(Overload {[&](std::monostate& arg) -> void {
                             throw doris::Exception(ErrorCode::INTERNAL_ERROR,
                                                    "uninited hash table");
                         },
                         [&](auto& method) -> void {
                             SCOPED_TIMER(_hash_table_compute_timer);
                             using MethodType = std::decay_t<decltype(method)>;
                             using KeyState = typename MethodType::State;
                             KeyState key_state(key_columns);
                             method.init_serialized_keys(key_columns, num_rows);

                             // A newly inserted key starts counting from zero: persist the
                             // key bytes into the arena, then store nullptr (bit pattern 0)
                             // in the mapped slot.
                             auto key_creator = [&](const auto& ctor, auto& key, auto& origin) {
                                 MethodType::try_presis_key_and_origin(
                                         key, origin, Base::_shared_state->agg_arena_pool);
                                 AggregateDataPtr slot = nullptr;
                                 ctor(key, slot);
                             };

                             auto null_key_creator = [&](auto& slot) { slot = nullptr; };

                             SCOPED_TIMER(_hash_table_emplace_timer);
                             lazy_emplace_batch(method, key_state, num_rows, key_creator,
                                                null_key_creator, [&](uint32_t, auto& slot) {
                                                    // Bump the inline counter for this key.
                                                    ++reinterpret_cast<UInt64&>(slot);
                                                });

                             COUNTER_UPDATE(_hash_table_input_counter, num_rows);
                         }},
               _agg_data->method_variant);
}

// Merge phase of the inline-count fast path: instead of deserializing count states and
// calling the aggregate function's merge, add each row's serialized count directly into
// the UInt64 counter stored inline in the hash table's mapped slot.
// `merge_column` holds serialized count states; its fixed-length payload is reinterpreted
// as AggregateFunctionCountData (assumed layout: a single UInt64 `count` — see the
// "serialization format is UInt64 itself" note where use_simple_count is enabled).
void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
                                                            const IColumn* merge_column,
                                                            uint32_t num_rows) {
    std::visit(
            Overload {[&](std::monostate& arg) -> void {
                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
                      },
                      [&](auto& agg_method) -> void {
                          SCOPED_TIMER(_hash_table_compute_timer);
                          using HashMethodType = std::decay_t<decltype(agg_method)>;
                          using AggState = typename HashMethodType::State;
                          AggState state(key_columns);
                          agg_method.init_serialized_keys(key_columns, num_rows);

                          // View the serialized count column as a flat array of count
                          // states, one per input row.
                          const auto& col =
                                  assert_cast<const ColumnFixedLengthObject&>(*merge_column);
                          const auto* col_data =
                                  reinterpret_cast<const AggregateFunctionCountData*>(
                                          col.get_data().data());

                          // New keys start with a zero counter (nullptr has bit pattern 0);
                          // key bytes are persisted into the arena first so they outlive
                          // the input block.
                          auto creator = [&](const auto& ctor, auto& key, auto& origin) {
                              HashMethodType::try_presis_key_and_origin(
                                      key, origin, Base::_shared_state->agg_arena_pool);
                              AggregateDataPtr mapped = nullptr;
                              ctor(key, mapped);
                          };

                          auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; };

                          SCOPED_TIMER(_hash_table_emplace_timer);
                          // For each row, accumulate the incoming partial count into the
                          // counter stored inline in the mapped slot.
                          lazy_emplace_batch(agg_method, state, num_rows, creator,
                                             creator_for_null_key, [&](uint32_t i, auto& mapped) {
                                                 reinterpret_cast<UInt64&>(mapped) +=
                                                         col_data[i].count;
                                             });

                          COUNTER_UPDATE(_hash_table_input_counter, num_rows);
                      }},
            _agg_data->method_variant);
}

bool AggSinkLocalState::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
const std::vector<int>& key_locs,
ColumnRawPtrs& key_columns,
Expand Down
4 changes: 4 additions & 0 deletions be/src/exec/operator/aggregation_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class AggSinkLocalState : public PipelineXSinkLocalState<AggSharedState> {
uint32_t num_rows);
void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
uint32_t num_rows);

void _emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, uint32_t num_rows);
void _merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
const IColumn* merge_column, uint32_t num_rows);
bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
const std::vector<int>& key_locs,
ColumnRawPtrs& key_columns, uint32_t num_rows);
Expand Down
117 changes: 117 additions & 0 deletions be/src/exec/operator/aggregation_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <string>

#include "common/exception.h"
#include "core/column/column_fixed_length_object.h"
#include "exec/operator/operator.h"
#include "exprs/vectorized_agg_fn.h"
#include "exprs/vexpr_fwd.h"
Expand Down Expand Up @@ -131,6 +132,76 @@ Status AggLocalState::_get_results_with_serialized_key(RuntimeState* state, Bloc
const auto size = std::min(data.size(), size_t(state->batch_size()));
using KeyType = std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);

if (shared_state.use_simple_count) {
DCHECK_EQ(shared_state.aggregate_evaluators.size(), 1);

value_data_types[0] = shared_state.aggregate_evaluators[0]
->function()
->get_serialized_type();
if (mem_reuse) {
value_columns[0] =
std::move(*block->get_by_position(key_size).column)
.mutate();
} else {
value_columns[0] = shared_state.aggregate_evaluators[0]
->function()
->create_serialize_column();
}

std::vector<UInt64> inline_counts(size);
uint32_t num_rows = 0;
{
SCOPED_TIMER(_hash_table_iterate_timer);
auto& it = agg_method.begin;
while (it != agg_method.end && num_rows < state->batch_size()) {
keys[num_rows] = it.get_first();
inline_counts[num_rows] =
reinterpret_cast<const UInt64&>(it.get_second());
++it;
++num_rows;
}
}

{
SCOPED_TIMER(_insert_keys_to_column_timer);
agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
}

// Write inline counts to serialized column
// AggregateFunctionCountData = { UInt64 count }, same layout as inline
auto& count_col =
assert_cast<ColumnFixedLengthObject&>(*value_columns[0]);
count_col.resize(num_rows);
auto* col_data = count_col.get_data().data();
for (uint32_t i = 0; i < num_rows; ++i) {
*reinterpret_cast<UInt64*>(col_data + i * sizeof(UInt64)) =
inline_counts[i];
}

// Handle null key if present
if (agg_method.begin == agg_method.end) {
if (agg_method.hash_table->has_null_key_data()) {
DCHECK(key_columns.size() == 1);
DCHECK(key_columns[0]->is_nullable());
if (num_rows < state->batch_size()) {
key_columns[0]->insert_data(nullptr, 0);
auto mapped =
agg_method.hash_table->template get_null_key_data<
AggregateDataPtr>();
count_col.resize(num_rows + 1);
*reinterpret_cast<UInt64*>(count_col.get_data().data() +
num_rows * sizeof(UInt64)) =
std::bit_cast<UInt64>(mapped);
*eos = true;
}
} else {
*eos = true;
}
}
return;
}

if (shared_state.values.size() < size + 1) {
shared_state.values.resize(size + 1);
}
Expand Down Expand Up @@ -255,6 +326,52 @@ Status AggLocalState::_get_with_serialized_key_result(RuntimeState* state, Block
const auto size = std::min(data.size(), size_t(state->batch_size()));
using KeyType = std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);

if (shared_state.use_simple_count) {
// Inline count: mapped slot stores UInt64 count directly
// (not a real AggregateDataPtr). Iterate hash table directly.
DCHECK_EQ(value_columns.size(), 1);
auto& count_column = assert_cast<ColumnInt64&>(*value_columns[0]);
uint32_t num_rows = 0;
{
SCOPED_TIMER(_hash_table_iterate_timer);
auto& it = agg_method.begin;
while (it != agg_method.end && num_rows < state->batch_size()) {
keys[num_rows] = it.get_first();
auto& mapped = it.get_second();
count_column.insert_value(static_cast<Int64>(
reinterpret_cast<const UInt64&>(mapped)));
++it;
++num_rows;
}
}
{
SCOPED_TIMER(_insert_keys_to_column_timer);
agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
}

// Handle null key if present
if (agg_method.begin == agg_method.end) {
if (agg_method.hash_table->has_null_key_data()) {
DCHECK(key_columns.size() == 1);
DCHECK(key_columns[0]->is_nullable());
if (key_columns[0]->size() < state->batch_size()) {
key_columns[0]->insert_data(nullptr, 0);
auto mapped =
agg_method.hash_table->template get_null_key_data<
AggregateDataPtr>();
count_column.insert_value(
static_cast<Int64>(std::bit_cast<UInt64>(mapped)));
*eos = true;
}
} else {
*eos = true;
}
}
return;
}

// Normal (non-simple-count) path
if (shared_state.values.size() < size) {
shared_state.values.resize(size);
}
Expand Down
Loading
Loading