From 8a2448302d0a6c431c0b653c73aaa6cc39e9411d Mon Sep 17 00:00:00 2001 From: Mryange Date: Sun, 15 Mar 2026 21:38:02 +0800 Subject: [PATCH 1/4] upd --- .../operator/aggregation_sink_operator.cpp | 125 ++++++++++- .../exec/operator/aggregation_sink_operator.h | 4 + .../operator/aggregation_source_operator.cpp | 117 ++++++++++ .../streaming_aggregation_operator.cpp | 202 ++++++++++++++---- .../operator/streaming_aggregation_operator.h | 7 + be/src/exec/pipeline/dependency.h | 6 + be/src/exprs/aggregate/aggregate_function.h | 2 + .../aggregate/aggregate_function_count.h | 1 + 8 files changed, 426 insertions(+), 38 deletions(-) diff --git a/be/src/exec/operator/aggregation_sink_operator.cpp b/be/src/exec/operator/aggregation_sink_operator.cpp index 09fb3b011f7f8e..adb35759ffed8d 100644 --- a/be/src/exec/operator/aggregation_sink_operator.cpp +++ b/be/src/exec/operator/aggregation_sink_operator.cpp @@ -25,6 +25,7 @@ #include "core/data_type/primitive_type.h" #include "exec/common/hash_table/hash.h" #include "exec/operator/operator.h" +#include "exprs/aggregate/aggregate_function_count.h" #include "exprs/aggregate/aggregate_function_simple_factory.h" #include "exprs/vectorized_agg_fn.h" #include "runtime/runtime_profile.h" @@ -156,6 +157,30 @@ Status AggSinkLocalState::open(RuntimeState* state) { RETURN_IF_ERROR(_create_agg_status(_agg_data->without_key)); _shared_state->agg_data_created_without_key = true; } + + // Determine whether to use simple count aggregation. + // For queries like: SELECT xxx, count(*) / count(not_null_column) FROM table GROUP BY xxx, + // count(*) / count(not_null_column) can store a uint64 counter directly in the hash table, + // instead of storing the full aggregate state, saving memory and computation overhead. + // Requirements: + // 0. The aggregation has a GROUP BY clause. + // 1. There is exactly one count aggregate function. + // 2. No limit optimization is applied. + // 3. 
Spill is not enabled (the spill path accesses aggregate_data_container, which is empty in inline count mode). + // Supports update / merge / finalize / serialize phases, since count's serialization format is UInt64 itself. + + if (!Base::_shared_state->probe_expr_ctxs.empty() /* has GROUP BY */ + && (p._aggregate_evaluators.size() == 1 && + p._aggregate_evaluators[0]->function()->is_simple_count()) /* only one count(*) */ + && !_should_limit_output /* no limit optimization */ && + !Base::_shared_state->enable_spill /* spill not enabled */) { + _shared_state->use_simple_count = true; +#ifndef NDEBUG + // Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion. + _shared_state->use_simple_count = rand() % 2 == 0; +#endif + } + return Status::OK(); } @@ -335,7 +360,18 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(Block* block) { key_columns, (uint32_t)rows); rows = block->rows(); } else { - _emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows); + if (_shared_state->use_simple_count) { + DCHECK(!for_spill); + + auto col_id = AggSharedState::get_slot_column_id( + Base::_shared_state->aggregate_evaluators[0]); + + auto column = block->get_by_position(col_id).column; + _merge_into_hash_table_inline_count(key_columns, column.get(), (uint32_t)rows); + need_do_agg = false; + } else { + _emplace_into_hash_table(_places.data(), key_columns, (uint32_t)rows); + } } if (need_do_agg) { @@ -496,7 +532,9 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(Block* block) { } } else { _emplace_into_hash_table(_places.data(), key_columns, rows); - RETURN_IF_ERROR(do_aggregate_evaluators()); + if (!_shared_state->use_simple_count) { + RETURN_IF_ERROR(do_aggregate_evaluators()); + } if (_should_limit_output && !Base::_shared_state->enable_spill) { const size_t hash_table_size = get_hash_table_size(); @@ -524,6 +562,11 @@ size_t AggSinkLocalState::get_hash_table_size() const { void 
AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, uint32_t num_rows) { + if (_shared_state->use_simple_count) { + _emplace_into_hash_table_inline_count(key_columns, num_rows); + return; + } + std::visit(Overload {[&](std::monostate& arg) -> void { throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); @@ -570,6 +613,84 @@ void AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places, _agg_data->method_variant); } +// For the agg hashmap, the value is a char* type which is exactly 64 bits. +// Here we treat it as a uint64 counter: each time the same key is encountered, the counter +// is incremented by 1. This avoids storing the full aggregate state, saving memory and computation overhead. +void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, + uint32_t num_rows) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin( + key, origin, Base::_shared_state->agg_arena_pool); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + SCOPED_TIMER(_hash_table_emplace_timer); + for (size_t i = 0; i < num_rows; ++i) { + auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, + creator_for_null_key); + ++reinterpret_cast(*mapped_ptr); + } + + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + }}, + _agg_data->method_variant); +} + +void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns, + const 
IColumn* merge_column, + uint32_t num_rows) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + const auto& col = + assert_cast(*merge_column); + const auto* col_data = + reinterpret_cast( + col.get_data().data()); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin( + key, origin, Base::_shared_state->agg_arena_pool); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + SCOPED_TIMER(_hash_table_emplace_timer); + for (size_t i = 0; i < num_rows; ++i) { + auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, + creator_for_null_key); + reinterpret_cast(*mapped_ptr) += col_data[i].count; + } + + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + }}, + _agg_data->method_variant); +} + bool AggSinkLocalState::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, const std::vector& key_locs, ColumnRawPtrs& key_columns, diff --git a/be/src/exec/operator/aggregation_sink_operator.h b/be/src/exec/operator/aggregation_sink_operator.h index 9774d2b95e512a..0a7067ecb4130a 100644 --- a/be/src/exec/operator/aggregation_sink_operator.h +++ b/be/src/exec/operator/aggregation_sink_operator.h @@ -88,6 +88,10 @@ class AggSinkLocalState : public PipelineXSinkLocalState { uint32_t num_rows); void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, uint32_t num_rows); + + void _emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, uint32_t num_rows); + void _merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns, + const 
IColumn* merge_column, uint32_t num_rows); bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block, const std::vector& key_locs, ColumnRawPtrs& key_columns, uint32_t num_rows); diff --git a/be/src/exec/operator/aggregation_source_operator.cpp b/be/src/exec/operator/aggregation_source_operator.cpp index f94a64c4a3e5bc..a142f89e29633f 100644 --- a/be/src/exec/operator/aggregation_source_operator.cpp +++ b/be/src/exec/operator/aggregation_source_operator.cpp @@ -21,6 +21,7 @@ #include #include "common/exception.h" +#include "core/column/column_fixed_length_object.h" #include "exec/operator/operator.h" #include "exprs/vectorized_agg_fn.h" #include "exprs/vexpr_fwd.h" @@ -131,6 +132,76 @@ Status AggLocalState::_get_results_with_serialized_key(RuntimeState* state, Bloc const auto size = std::min(data.size(), size_t(state->batch_size())); using KeyType = std::decay_t::Key; std::vector keys(size); + + if (shared_state.use_simple_count) { + DCHECK_EQ(shared_state.aggregate_evaluators.size(), 1); + + value_data_types[0] = shared_state.aggregate_evaluators[0] + ->function() + ->get_serialized_type(); + if (mem_reuse) { + value_columns[0] = + std::move(*block->get_by_position(key_size).column) + .mutate(); + } else { + value_columns[0] = shared_state.aggregate_evaluators[0] + ->function() + ->create_serialize_column(); + } + + std::vector inline_counts(size); + uint32_t num_rows = 0; + { + SCOPED_TIMER(_hash_table_iterate_timer); + auto& it = agg_method.begin; + while (it != agg_method.end && num_rows < state->batch_size()) { + keys[num_rows] = it.get_first(); + inline_counts[num_rows] = + reinterpret_cast(it.get_second()); + ++it; + ++num_rows; + } + } + + { + SCOPED_TIMER(_insert_keys_to_column_timer); + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + } + + // Write inline counts to serialized column + // AggregateFunctionCountData = { UInt64 count }, same layout as inline + auto& count_col = + assert_cast(*value_columns[0]); + 
count_col.resize(num_rows); + auto* col_data = count_col.get_data().data(); + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(col_data + i * sizeof(UInt64)) = + inline_counts[i]; + } + + // Handle null key if present + if (agg_method.begin == agg_method.end) { + if (agg_method.hash_table->has_null_key_data()) { + DCHECK(key_columns.size() == 1); + DCHECK(key_columns[0]->is_nullable()); + if (num_rows < state->batch_size()) { + key_columns[0]->insert_data(nullptr, 0); + auto mapped = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + count_col.resize(num_rows + 1); + *reinterpret_cast(count_col.get_data().data() + + num_rows * sizeof(UInt64)) = + std::bit_cast(mapped); + *eos = true; + } + } else { + *eos = true; + } + } + return; + } + if (shared_state.values.size() < size + 1) { shared_state.values.resize(size + 1); } @@ -255,6 +326,52 @@ Status AggLocalState::_get_with_serialized_key_result(RuntimeState* state, Block const auto size = std::min(data.size(), size_t(state->batch_size())); using KeyType = std::decay_t::Key; std::vector keys(size); + + if (shared_state.use_simple_count) { + // Inline count: mapped slot stores UInt64 count directly + // (not a real AggregateDataPtr). Iterate hash table directly. 
+ DCHECK_EQ(value_columns.size(), 1); + auto& count_column = assert_cast(*value_columns[0]); + uint32_t num_rows = 0; + { + SCOPED_TIMER(_hash_table_iterate_timer); + auto& it = agg_method.begin; + while (it != agg_method.end && num_rows < state->batch_size()) { + keys[num_rows] = it.get_first(); + auto& mapped = it.get_second(); + count_column.insert_value(static_cast( + reinterpret_cast(mapped))); + ++it; + ++num_rows; + } + } + { + SCOPED_TIMER(_insert_keys_to_column_timer); + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + } + + // Handle null key if present + if (agg_method.begin == agg_method.end) { + if (agg_method.hash_table->has_null_key_data()) { + DCHECK(key_columns.size() == 1); + DCHECK(key_columns[0]->is_nullable()); + if (key_columns[0]->size() < state->batch_size()) { + key_columns[0]->insert_data(nullptr, 0); + auto mapped = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + count_column.insert_value( + static_cast(std::bit_cast(mapped))); + *eos = true; + } + } else { + *eos = true; + } + } + return; + } + + // Normal (non-simple-count) path if (shared_state.values.size() < size) { shared_state.values.resize(size); } diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp b/be/src/exec/operator/streaming_aggregation_operator.cpp index 5ea488cf7e7697..cae652678109a7 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.cpp +++ b/be/src/exec/operator/streaming_aggregation_operator.cpp @@ -24,8 +24,10 @@ #include "common/cast_set.h" #include "common/compiler_util.h" // IWYU pragma: keep +#include "core/column/column_fixed_length_object.h" #include "exec/operator/operator.h" #include "exec/operator/streaming_agg_min_reduction.h" +#include "exprs/aggregate/aggregate_function_count.h" #include "exprs/aggregate/aggregate_function_simple_factory.h" #include "exprs/vectorized_agg_fn.h" #include "exprs/vslot_ref.h" @@ -97,22 +99,36 @@ Status StreamingAggLocalState::open(RuntimeState* 
state) { RETURN_IF_ERROR(_init_hash_method(_probe_expr_ctxs)); - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) { - using HashTableType = std::decay_t; - using KeyType = typename HashTableType::Key; - - /// some aggregate functions (like AVG for decimal) have align issues. - _aggregate_data_container = std::make_unique( - sizeof(KeyType), ((p._total_size_of_aggregate_states + - p._align_aggregate_states - 1) / - p._align_aggregate_states) * - p._align_aggregate_states); - }}, - _agg_data->method_variant); + // Determine whether to use simple count aggregation. + // StreamingAgg only operates in update + serialize mode: input is raw data, output is serialized intermediate state. + // The serialization format of count is UInt64 itself, so it can be inlined into the hash table mapped slot. + if (_aggregate_evaluators.size() == 1 && + _aggregate_evaluators[0]->function()->is_simple_count()) { + _use_simple_count = true; +#ifndef NDEBUG + // Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion. + _use_simple_count = rand() % 2 == 0; +#endif + } + + std::visit( + Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) { + using HashTableType = std::decay_t; + using KeyType = typename HashTableType::Key; + + if (!_use_simple_count) { + /// some aggregate functions (like AVG for decimal) have align issues. 
+ _aggregate_data_container = std::make_unique( + sizeof(KeyType), ((p._total_size_of_aggregate_states + + p._align_aggregate_states - 1) / + p._align_aggregate_states) * + p._align_aggregate_states); + } + }}, + _agg_data->method_variant); limit = p._sort_limit; do_sort_limit = p._do_sort_limit; @@ -139,8 +155,11 @@ void StreamingAggLocalState::_update_memusage_with_serialized_key() { }, [&](auto& agg_method) -> void { auto& data = *agg_method.hash_table; - int64_t arena_memory_usage = _agg_arena_pool.size() + - _aggregate_data_container->memory_usage(); + int64_t arena_memory_usage = + _agg_arena_pool.size() + + (_aggregate_data_container + ? _aggregate_data_container->memory_usage() + : 0); int64_t hash_table_memory_usage = data.get_buffer_size_in_bytes(); COUNTER_SET(_memory_used_counter, @@ -274,23 +293,22 @@ bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) { const auto spill_streaming_agg_mem_limit = p._spill_streaming_agg_mem_limit; const bool used_too_much_memory = spill_streaming_agg_mem_limit > 0 && _memory_usage() > spill_streaming_agg_mem_limit; - std::visit( - Overload { - [&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); - }, - [&](auto& agg_method) { - auto& hash_tbl = *agg_method.hash_table; - /// If too much memory is used during the pre-aggregation stage, - /// it is better to output the data directly without performing further aggregation. 
- // do not try to do agg, just init and serialize directly return the out_block - if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && - !_should_expand_preagg_hash_tables())) { - SCOPED_TIMER(_streaming_agg_timer); - ret_flag = true; - } - }}, - _agg_data->method_variant); + std::visit(Overload {[&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) { + auto& hash_tbl = *agg_method.hash_table; + /// If too much memory is used during the pre-aggregation stage, + /// it is better to output the data directly without performing further aggregation. + // do not try to do agg, just init and serialize directly return the out_block + if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && + !_should_expand_preagg_hash_tables())) { + SCOPED_TIMER(_streaming_agg_timer); + ret_flag = true; + } + }}, + _agg_data->method_variant); return ret_flag; } @@ -388,7 +406,12 @@ Status StreamingAggLocalState::_pre_agg_with_serialized_key(doris::Block* in_blo } else { bool need_agg = true; if (need_do_sort_limit != 1) { - _emplace_into_hash_table(_places.data(), key_columns, rows); + if (_use_simple_count) { + _emplace_into_hash_table_inline_count(key_columns, rows); + need_agg = false; + } else { + _emplace_into_hash_table(_places.data(), key_columns, rows); + } } else { need_agg = _emplace_into_hash_table_limit(_places.data(), in_block, key_columns, rows); } @@ -456,6 +479,74 @@ Status StreamingAggLocalState::_get_results_with_serialized_key(RuntimeState* st const auto size = std::min(data.size(), size_t(state->batch_size())); using KeyType = std::decay_t::Key; std::vector keys(size); + + if (_use_simple_count) { + DCHECK_EQ(_aggregate_evaluators.size(), 1); + + value_data_types[0] = + _aggregate_evaluators[0]->function()->get_serialized_type(); + if (mem_reuse) { + value_columns[0] = + std::move(*block->get_by_position(key_size).column) + .mutate(); + } else { + 
value_columns[0] = _aggregate_evaluators[0] + ->function() + ->create_serialize_column(); + } + + std::vector inline_counts(size); + uint32_t num_rows = 0; + { + SCOPED_TIMER(_hash_table_iterate_timer); + auto& it = agg_method.begin; + while (it != agg_method.end && num_rows < state->batch_size()) { + keys[num_rows] = it.get_first(); + inline_counts[num_rows] = + reinterpret_cast(it.get_second()); + ++it; + ++num_rows; + } + } + + { + SCOPED_TIMER(_insert_keys_to_column_timer); + agg_method.insert_keys_into_columns(keys, key_columns, num_rows); + } + + // Write inline counts to serialized column + auto& count_col = + assert_cast(*value_columns[0]); + count_col.resize(num_rows); + auto* col_data = count_col.get_data().data(); + for (uint32_t i = 0; i < num_rows; ++i) { + *reinterpret_cast(col_data + i * sizeof(UInt64)) = + inline_counts[i]; + } + + // Handle null key if present + if (agg_method.begin == agg_method.end) { + if (agg_method.hash_table->has_null_key_data()) { + DCHECK(key_columns.size() == 1); + DCHECK(key_columns[0]->is_nullable()); + if (num_rows < state->batch_size()) { + key_columns[0]->insert_data(nullptr, 0); + auto mapped = + agg_method.hash_table->template get_null_key_data< + AggregateDataPtr>(); + count_col.resize(num_rows + 1); + *reinterpret_cast(count_col.get_data().data() + + num_rows * sizeof(UInt64)) = + std::bit_cast(mapped); + *eos = true; + } + } else { + *eos = true; + } + } + return; + } + if (_values.size() < size + 1) { _values.resize(size + 1); } @@ -728,6 +819,11 @@ bool StreamingAggLocalState::_do_limit_filter(size_t num_rows, ColumnRawPtrs& ke void StreamingAggLocalState::_emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, const uint32_t num_rows) { + if (_use_simple_count) { + _emplace_into_hash_table_inline_count(key_columns, num_rows); + return; + } + std::visit(Overload {[&](std::monostate& arg) -> void { throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); @@ -772,6 +868,40 
@@ void StreamingAggLocalState::_emplace_into_hash_table(AggregateDataPtr* places, _agg_data->method_variant); } +void StreamingAggLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns, + uint32_t num_rows) { + std::visit(Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin(key, origin, + _agg_arena_pool); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + SCOPED_TIMER(_hash_table_emplace_timer); + for (size_t i = 0; i < num_rows; ++i) { + auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, + creator_for_null_key); + ++reinterpret_cast(*mapped_ptr); + } + + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + }}, + _agg_data->method_variant); +} + StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id, const TPlanNode& tnode, const DescriptorTbl& descs) : StatefulOperatorX(pool, tnode, operator_id, descs), diff --git a/be/src/exec/operator/streaming_aggregation_operator.h b/be/src/exec/operator/streaming_aggregation_operator.h index cd4ab29b068180..cf1100f8dc126c 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.h +++ b/be/src/exec/operator/streaming_aggregation_operator.h @@ -68,6 +68,7 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public PipelineXLocalState _aggregate_data_container = nullptr; + bool _use_simple_count = false; bool _reach_limit = false; size_t _input_num_rows = 0; @@ -178,6 +180,11 @@ class StreamingAggLocalState MOCK_REMOVE(final) : public 
PipelineXLocalState void { + if (_use_simple_count) { + // Inline count: mapped slots hold UInt64, + // not real agg state pointers. Skip destroy. + return; + } auto& data = *agg_method.hash_table; data.for_each_mapped([&](auto& mapped) { if (mapped) { diff --git a/be/src/exec/pipeline/dependency.h b/be/src/exec/pipeline/dependency.h index 33c2ebfa2f3a33..a88d25e4c50721 100644 --- a/be/src/exec/pipeline/dependency.h +++ b/be/src/exec/pipeline/dependency.h @@ -320,6 +320,7 @@ struct AggSharedState : public BasicSharedState { bool enable_spill = false; bool reach_limit = false; + bool use_simple_count = false; int64_t limit = -1; bool do_sort_limit = false; MutableColumns limit_columns; @@ -392,6 +393,11 @@ struct AggSharedState : public BasicSharedState { // Do nothing }, [&](auto& agg_method) -> void { + if (use_simple_count) { + // Inline count: mapped slots hold UInt64, + // not real agg state pointers. Skip destroy. + return; + } auto& data = *agg_method.hash_table; data.for_each_mapped([&](auto& mapped) { if (mapped) { diff --git a/be/src/exprs/aggregate/aggregate_function.h b/be/src/exprs/aggregate/aggregate_function.h index 475439cd39ce1d..d7c97a3f944d56 100644 --- a/be/src/exprs/aggregate/aggregate_function.h +++ b/be/src/exprs/aggregate/aggregate_function.h @@ -263,6 +263,8 @@ class IAggregateFunction { virtual bool is_blockable() const { return false; } + virtual bool is_simple_count() const { return false; } + /** * Executes the aggregate function in incremental mode. * This is a virtual function that should be overridden by aggregate functions supporting incremental calculation. 
diff --git a/be/src/exprs/aggregate/aggregate_function_count.h b/be/src/exprs/aggregate/aggregate_function_count.h index 35317a6240ac77..3bc825a4a5a99e 100644 --- a/be/src/exprs/aggregate/aggregate_function_count.h +++ b/be/src/exprs/aggregate/aggregate_function_count.h @@ -57,6 +57,7 @@ class AggregateFunctionCount final AggregateFunctionCount(const DataTypes& argument_types_) : IAggregateFunctionDataHelper(argument_types_) {} + bool is_simple_count() const override { return true; } String get_name() const override { return "count"; } DataTypePtr get_return_type() const override { return std::make_shared(); } From 0ec4bbb12fa56c39f7554d342f693f825407b403 Mon Sep 17 00:00:00 2001 From: Mryange Date: Thu, 12 Mar 2026 16:21:51 +0800 Subject: [PATCH 2/4] format --- .../streaming_aggregation_operator.cpp | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp b/be/src/exec/operator/streaming_aggregation_operator.cpp index cae652678109a7..d419e015ffcf34 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.cpp +++ b/be/src/exec/operator/streaming_aggregation_operator.cpp @@ -293,22 +293,23 @@ bool StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) { const auto spill_streaming_agg_mem_limit = p._spill_streaming_agg_mem_limit; const bool used_too_much_memory = spill_streaming_agg_mem_limit > 0 && _memory_usage() > spill_streaming_agg_mem_limit; - std::visit(Overload {[&](std::monostate& arg) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) { - auto& hash_tbl = *agg_method.hash_table; - /// If too much memory is used during the pre-aggregation stage, - /// it is better to output the data directly without performing further aggregation. 
- // do not try to do agg, just init and serialize directly return the out_block - if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && - !_should_expand_preagg_hash_tables())) { - SCOPED_TIMER(_streaming_agg_timer); - ret_flag = true; - } - }}, - _agg_data->method_variant); + std::visit( + Overload { + [&](std::monostate& arg) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) { + auto& hash_tbl = *agg_method.hash_table; + /// If too much memory is used during the pre-aggregation stage, + /// it is better to output the data directly without performing further aggregation. + // do not try to do agg, just init and serialize directly return the out_block + if (used_too_much_memory || (hash_tbl.add_elem_size_overflow(rows) && + !_should_expand_preagg_hash_tables())) { + SCOPED_TIMER(_streaming_agg_timer); + ret_flag = true; + } + }}, + _agg_data->method_variant); return ret_flag; } From 5ce94ed68cfcd9ff50ec9940bfb55094490fbb85 Mon Sep 17 00:00:00 2001 From: Mryange Date: Wed, 18 Mar 2026 18:58:02 +0800 Subject: [PATCH 3/4] use lazy_emplace_batch --- .../operator/aggregation_sink_operator.cpp | 19 +++++++++---------- .../streaming_aggregation_operator.cpp | 11 +++++------ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/be/src/exec/operator/aggregation_sink_operator.cpp b/be/src/exec/operator/aggregation_sink_operator.cpp index adb35759ffed8d..c02d4186c32f39 100644 --- a/be/src/exec/operator/aggregation_sink_operator.cpp +++ b/be/src/exec/operator/aggregation_sink_operator.cpp @@ -639,11 +639,10 @@ void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; SCOPED_TIMER(_hash_table_emplace_timer); - for (size_t i = 0; i < num_rows; ++i) { - auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, - creator_for_null_key); - ++reinterpret_cast(*mapped_ptr); - } + 
lazy_emplace_batch(agg_method, state, num_rows, creator, + creator_for_null_key, [&](uint32_t, auto& mapped) { + ++reinterpret_cast(mapped); + }); COUNTER_UPDATE(_hash_table_input_counter, num_rows); }}, @@ -680,11 +679,11 @@ void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs& key_c auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; SCOPED_TIMER(_hash_table_emplace_timer); - for (size_t i = 0; i < num_rows; ++i) { - auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, - creator_for_null_key); - reinterpret_cast(*mapped_ptr) += col_data[i].count; - } + lazy_emplace_batch( + agg_method, state, num_rows, creator, creator_for_null_key, + [&](uint32_t i, auto& mapped) { + reinterpret_cast(mapped) += col_data[i].count; + }); COUNTER_UPDATE(_hash_table_input_counter, num_rows); }}, diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp b/be/src/exec/operator/streaming_aggregation_operator.cpp index d419e015ffcf34..a5fa875e7cccfe 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.cpp +++ b/be/src/exec/operator/streaming_aggregation_operator.cpp @@ -103,7 +103,7 @@ Status StreamingAggLocalState::open(RuntimeState* state) { // StreamingAgg only operates in update + serialize mode: input is raw data, output is serialized intermediate state. // The serialization format of count is UInt64 itself, so it can be inlined into the hash table mapped slot. if (_aggregate_evaluators.size() == 1 && - _aggregate_evaluators[0]->function()->is_simple_count()) { + _aggregate_evaluators[0]->function()->is_simple_count() && p._sort_limit == -1 /* member `limit` is assigned from p._sort_limit only later in open() */) { _use_simple_count = true; #ifndef NDEBUG // Randomly enable/disable in debug mode to verify correctness of multi-phase agg promotion/demotion. 
@@ -892,11 +892,10 @@ void StreamingAggLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; SCOPED_TIMER(_hash_table_emplace_timer); - for (size_t i = 0; i < num_rows; ++i) { - auto* mapped_ptr = agg_method.lazy_emplace(state, i, creator, - creator_for_null_key); - ++reinterpret_cast(*mapped_ptr); - } + lazy_emplace_batch(agg_method, state, num_rows, creator, + creator_for_null_key, [&](uint32_t, auto& mapped) { + ++reinterpret_cast(mapped); + }); COUNTER_UPDATE(_hash_table_input_counter, num_rows); }}, From 7fa5b91bec7738488ef5d28594ffe843bb784abd Mon Sep 17 00:00:00 2001 From: Mryange Date: Wed, 18 Mar 2026 19:09:35 +0800 Subject: [PATCH 4/4] format --- .../operator/aggregation_sink_operator.cpp | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/be/src/exec/operator/aggregation_sink_operator.cpp b/be/src/exec/operator/aggregation_sink_operator.cpp index c02d4186c32f39..0f7505423a3552 100644 --- a/be/src/exec/operator/aggregation_sink_operator.cpp +++ b/be/src/exec/operator/aggregation_sink_operator.cpp @@ -652,42 +652,42 @@ void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs& key void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns, const IColumn* merge_column, uint32_t num_rows) { - std::visit(Overload {[&](std::monostate& arg) -> void { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "uninited hash table"); - }, - [&](auto& agg_method) -> void { - SCOPED_TIMER(_hash_table_compute_timer); - using HashMethodType = std::decay_t; - using AggState = typename HashMethodType::State; - AggState state(key_columns); - agg_method.init_serialized_keys(key_columns, num_rows); - - const auto& col = - assert_cast(*merge_column); - const auto* col_data = - reinterpret_cast( - col.get_data().data()); - - auto creator = [&](const auto& ctor, auto& key, auto& origin) { - 
HashMethodType::try_presis_key_and_origin( - key, origin, Base::_shared_state->agg_arena_pool); - AggregateDataPtr mapped = nullptr; - ctor(key, mapped); - }; - - auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; - - SCOPED_TIMER(_hash_table_emplace_timer); - lazy_emplace_batch( - agg_method, state, num_rows, creator, creator_for_null_key, - [&](uint32_t i, auto& mapped) { - reinterpret_cast(mapped) += col_data[i].count; - }); - - COUNTER_UPDATE(_hash_table_input_counter, num_rows); - }}, - _agg_data->method_variant); + std::visit( + Overload {[&](std::monostate& arg) -> void { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table"); + }, + [&](auto& agg_method) -> void { + SCOPED_TIMER(_hash_table_compute_timer); + using HashMethodType = std::decay_t; + using AggState = typename HashMethodType::State; + AggState state(key_columns); + agg_method.init_serialized_keys(key_columns, num_rows); + + const auto& col = + assert_cast(*merge_column); + const auto* col_data = + reinterpret_cast( + col.get_data().data()); + + auto creator = [&](const auto& ctor, auto& key, auto& origin) { + HashMethodType::try_presis_key_and_origin( + key, origin, Base::_shared_state->agg_arena_pool); + AggregateDataPtr mapped = nullptr; + ctor(key, mapped); + }; + + auto creator_for_null_key = [&](auto& mapped) { mapped = nullptr; }; + + SCOPED_TIMER(_hash_table_emplace_timer); + lazy_emplace_batch(agg_method, state, num_rows, creator, + creator_for_null_key, [&](uint32_t i, auto& mapped) { + reinterpret_cast(mapped) += + col_data[i].count; + }); + + COUNTER_UPDATE(_hash_table_input_counter, num_rows); + }}, + _agg_data->method_variant); } bool AggSinkLocalState::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,