Merged
19 changes: 9 additions & 10 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
use bumpalo::Bump;
use databend_common_exception::Result;

use super::group_hash_columns;
use super::group_hash_entries;
use super::hash_index::AdapterImpl;
use super::hash_index::HashIndex;
use super::partitioned_payload::PartitionedPayload;
@@ -29,6 +29,7 @@ use super::probe_state::ProbeState;
use super::Entry;
use super::HashTableConfig;
use super::Payload;
use super::BATCH_SIZE;
use super::LOAD_FACTOR;
use super::MAX_PAGE_SIZE;
use crate::types::DataType;
@@ -37,8 +38,6 @@ use crate::BlockEntry;
use crate::ColumnBuilder;
use crate::ProjectedBlock;

const BATCH_ADD_SIZE: usize = 2048;

pub struct AggregateHashTable {
pub payload: PartitionedPayload,
// use for append rows directly during deserialize
@@ -129,12 +128,12 @@ impl AggregateHashTable {
agg_states: ProjectedBlock,
row_count: usize,
) -> Result<usize> {
if row_count <= BATCH_ADD_SIZE {
if row_count <= BATCH_SIZE {
self.add_groups_inner(state, group_columns, params, agg_states, row_count)
} else {
let mut new_count = 0;
for start in (0..row_count).step_by(BATCH_ADD_SIZE) {
let end = (start + BATCH_ADD_SIZE).min(row_count);
for start in (0..row_count).step_by(BATCH_SIZE) {
let end = (start + BATCH_SIZE).min(row_count);
let step_group_columns = group_columns
.iter()
.map(|entry| entry.slice(start..end))
@@ -188,11 +187,11 @@ impl AggregateHashTable {
}

state.row_count = row_count;
group_hash_columns(group_columns, &mut state.group_hashes);
group_hash_entries(group_columns, &mut state.group_hashes[..row_count]);

let new_group_count = if self.direct_append {
for idx in 0..row_count {
state.empty_vector[idx] = idx;
for i in 0..row_count {
state.empty_vector[i] = i.into();
}
self.payload.append_rows(state, row_count, group_columns);
row_count
@@ -232,7 +231,7 @@ impl AggregateHashTable {

if self.config.partial_agg {
// check size
if self.hash_index.count + BATCH_ADD_SIZE > self.hash_index.resize_threshold()
if self.hash_index.count + BATCH_SIZE > self.hash_index.resize_threshold()
&& self.hash_index.capacity >= self.config.max_partial_capacity
{
self.clear_ht();
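The change above folds the file-local BATCH_ADD_SIZE constant into the shared BATCH_SIZE imported from the aggregate module, so the add_groups chunking and the hash-index resize headroom check now agree on one batch size. Below is a minimal sketch (not from the PR) of the chunking arithmetic in add_groups; the BATCH_SIZE value of 2048 is an assumption carried over from the deleted BATCH_ADD_SIZE, since the shared constant's definition is not shown in this diff.

// Hypothetical stand-in for super::BATCH_SIZE (value assumed from the old
// BATCH_ADD_SIZE; the real definition is not part of this diff).
const BATCH_SIZE: usize = 2048;

// Mirrors the step_by/min loop in add_groups: split row_count rows into
// [start, end) slices of at most BATCH_SIZE rows each.
fn batch_ranges(row_count: usize) -> impl Iterator<Item = (usize, usize)> {
    (0..row_count)
        .step_by(BATCH_SIZE)
        .map(move |start| (start, (start + BATCH_SIZE).min(row_count)))
}

fn main() {
    // 5000 rows -> three slices; only the last one is short.
    let ranges: Vec<_> = batch_ranges(5000).collect();
    assert_eq!(ranges, vec![(0, 2048), (2048, 4096), (4096, 5000)]);
}

If BATCH_SIZE differs upstream, only the slice boundaries change; the pattern stays the same, and the per-batch ProbeState buffers presumably stay valid for every iteration.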
249 changes: 209 additions & 40 deletions src/query/expression/src/aggregate/group_hash.rs
@@ -19,38 +19,13 @@ use databend_common_column::types::Index;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;

use crate::types::i256;
use crate::types::number::Number;
use crate::types::AccessType;
use crate::types::AnyType;
use crate::types::BinaryColumn;
use crate::types::BinaryType;
use crate::types::BitmapType;
use crate::types::BooleanType;
use crate::types::DataType;
use crate::types::DateType;
use crate::types::DecimalColumn;
use crate::types::DecimalDataKind;
use crate::types::DecimalScalar;
use crate::types::DecimalView;
use crate::types::GeographyColumn;
use crate::types::GeographyType;
use crate::types::GeometryType;
use crate::types::NullableColumn;
use crate::types::NumberColumn;
use crate::types::NumberDataType;
use crate::types::NumberScalar;
use crate::types::NumberType;
use crate::types::OpaqueScalarRef;
use crate::types::StringColumn;
use crate::types::StringType;
use crate::types::TimestampType;
use crate::types::ValueType;
use crate::types::VariantType;
use crate::types::decimal::Decimal;
use crate::types::*;
use crate::visitor::ValueVisitor;
use crate::with_decimal_mapped_type;
use crate::with_number_mapped_type;
use crate::with_number_type;
use crate::BlockEntry;
use crate::Column;
use crate::ProjectedBlock;
use crate::Scalar;
@@ -59,23 +34,101 @@ use crate::Value;

const NULL_HASH_VAL: u64 = 0xd1cefa08eb382d69;

pub fn group_hash_columns(cols: ProjectedBlock, values: &mut [u64]) {
debug_assert!(!cols.is_empty());
for (i, entry) in cols.iter().enumerate() {
if i == 0 {
combine_group_hash_column::<true>(&entry.to_column(), values);
} else {
combine_group_hash_column::<false>(&entry.to_column(), values);
pub fn group_hash_entries(entries: ProjectedBlock, values: &mut [u64]) {
debug_assert!(!entries.is_empty());
for (i, entry) in entries.iter().enumerate() {
debug_assert_eq!(entry.len(), values.len());
match entry {
BlockEntry::Const(scalar, data_type, _) => {
if i == 0 {
combine_group_hash_const::<true>(scalar, data_type, values);
} else {
combine_group_hash_const::<false>(scalar, data_type, values);
}
}
BlockEntry::Column(column) => {
if i == 0 {
combine_group_hash_column::<true>(column, values);
} else {
combine_group_hash_column::<false>(column, values);
}
}
}
}
}
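// Editor's note (not part of the diff): the i == 0 split makes the first grouping
// entry initialize values[r] with its own hash, while every later entry folds its
// hash in (conceptually values[r] = merge_hash(values[r], h_i(r))); the IS_FIRST
// const parameter selects between those two modes on both the const and column paths.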

pub fn combine_group_hash_column<const IS_FIRST: bool>(c: &Column, values: &mut [u64]) {
fn combine_group_hash_column<const IS_FIRST: bool>(c: &Column, values: &mut [u64]) {
HashVisitor::<IS_FIRST> { values }
.visit_column(c.clone())
.unwrap()
}

fn combine_group_hash_const<const IS_FIRST: bool>(
scalar: &Scalar,
data_type: &DataType,
values: &mut [u64],
) {
match data_type {
DataType::Null | DataType::EmptyArray | DataType::EmptyMap => {}
DataType::Nullable(inner) => {
if scalar.is_null() {
apply_const_hash::<IS_FIRST>(values, NULL_HASH_VAL);
} else {
combine_group_hash_const_nonnull::<IS_FIRST>(scalar, inner, values);
}
}
_ => combine_group_hash_const_nonnull::<IS_FIRST>(scalar, data_type, values),
}
}
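// Editor's note (not part of the diff): a NULL constant contributes NULL_HASH_VAL,
// which -- as the const-vs-column equality test below exercises -- matches what the
// materialized nullable-column path produces, so a Const null entry and an all-null
// column hash alike.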

fn combine_group_hash_const_nonnull<const IS_FIRST: bool>(
scalar: &Scalar,
_data_type: &DataType,
values: &mut [u64],
) {
let hash = match scalar {
Scalar::Null => unreachable!(),
Scalar::EmptyArray | Scalar::EmptyMap => return,
Scalar::Number(v) => with_number_type!(|NUM_TYPE| match v {
NumberScalar::NUM_TYPE(value) => value.agg_hash(),
}),
Scalar::Decimal(v) => {
with_decimal_mapped_type!(|F| match v {
DecimalScalar::F(v, size) => {
with_decimal_mapped_type!(|T| match size.data_kind() {
DecimalDataKind::T => {
v.as_decimal::<T>().agg_hash()
}
})
}
})
}
Scalar::Timestamp(value) => value.agg_hash(),
Scalar::Date(value) => value.agg_hash(),
Scalar::Boolean(value) => value.agg_hash(),
Scalar::String(value) => value.as_bytes().agg_hash(),
Scalar::Binary(value)
| Scalar::Bitmap(value)
| Scalar::Variant(value)
| Scalar::Geometry(value) => value.agg_hash(),
Scalar::Geography(value) => value.0.agg_hash(),
_ => scalar.as_ref().agg_hash(),
};
apply_const_hash::<IS_FIRST>(values, hash);
}

fn apply_const_hash<const IS_FIRST: bool>(values: &mut [u64], hash: u64) {
if IS_FIRST {
for val in values.iter_mut() {
*val = hash;
}
} else {
for val in values.iter_mut() {
*val = merge_hash(*val, hash);
}
}
}
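// Editor's note (not part of the diff): this broadcast is the point of the const
// fast path -- the scalar is hashed exactly once and then written (IS_FIRST) or
// merged into every row slot, instead of materializing the constant into a full
// column and hashing it row by row.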

struct HashVisitor<'a, const IS_FIRST: bool> {
values: &'a mut [u64],
}
@@ -101,7 +154,7 @@ impl<const IS_FIRST: bool> ValueVisitor for HashVisitor<'_, IS_FIRST> {
fn visit_any_number(&mut self, column: NumberColumn) -> Result<()> {
with_number_mapped_type!(|NUM_TYPE| match column.data_type() {
NumberDataType::NUM_TYPE => {
let c = NUM_TYPE::try_downcast_column(&column).unwrap();
let c = <NUM_TYPE as Number>::try_downcast_column(&column).unwrap();
self.combine_group_hash_type_column::<NumberType<NUM_TYPE>>(&c)
}
});
@@ -573,22 +626,138 @@ impl AggHash for ScalarRef<'_> {
#[cfg(test)]
mod tests {
use databend_common_column::bitmap::Bitmap;
use databend_common_column::types::months_days_micros;
use databend_common_column::types::timestamp_tz;

use super::*;
use crate::types::geography::Geography;
use crate::types::ArgType;
use crate::types::Int32Type;
use crate::types::DecimalSize;
use crate::types::NullableColumn;
use crate::types::NullableType;
use crate::types::StringType;
use crate::types::NumberScalar;
use crate::types::OpaqueScalar;
use crate::types::VectorDataType;
use crate::types::VectorScalar;
use crate::BlockEntry;
use crate::DataBlock;
use crate::FromData;
use crate::ProjectedBlock;
use crate::Value;

fn merge_hash_slice(ls: &[u64]) -> u64 {
ls.iter().cloned().reduce(merge_hash).unwrap()
}

#[test]
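    // Editor's note (not part of the diff): this test hashes every sliding pair of
    // columns plus the full projection, once against the Const block and once against
    // its convert_to_full() materialization, and asserts both paths produce identical
    // group hashes.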
fn test_group_hash_entries_const_equals_column() {
let num_rows = 5;
let block = sample_block(num_rows);
let full_block = block.convert_to_full();

for projection in (0..block.num_columns())
.map_windows(|[a, b]| vec![*a, *b])
.chain(Some((0..block.num_columns()).collect()))
.collect::<Vec<_>>()
{
let mut const_hashes = vec![0; block.num_rows()];
let mut col_hashes = vec![0; full_block.num_rows()];
group_hash_entries(
ProjectedBlock::project(&projection, &block),
&mut const_hashes,
);
group_hash_entries(
ProjectedBlock::project(&projection, &full_block),
&mut col_hashes,
);
assert_eq!(const_hashes, col_hashes);
}
}

fn sample_block(num_rows: usize) -> DataBlock {
let cases = [
(DataType::Null, Scalar::Null),
(DataType::EmptyArray, Scalar::EmptyArray),
(DataType::EmptyMap, Scalar::EmptyMap),
(DataType::Boolean, Scalar::Boolean(true)),
(DataType::Binary, Scalar::Binary(vec![1, 2, 3, 4])),
(DataType::String, Scalar::String("const_str".to_string())),
(
Int32Type::data_type(),
Scalar::Number(NumberScalar::Int32(123)),
),
(
DataType::Number(NumberDataType::Float64),
Scalar::Number(NumberScalar::Float64(OrderedFloat(1.25))),
),
{
let decimal_size = DecimalSize::new(20, 3).unwrap();
(
DataType::Decimal(decimal_size),
Scalar::Decimal(DecimalScalar::Decimal128(123456_i128, decimal_size)),
)
},
(DataType::Timestamp, Scalar::Timestamp(123_456_789)),
(
DataType::TimestampTz,
Scalar::TimestampTz(timestamp_tz::new(123_456, 3_600)),
),
(DataType::Date, Scalar::Date(42)),
(
DataType::Interval,
Scalar::Interval(months_days_micros::new(1, 2, 3)),
),
(DataType::Bitmap, Scalar::Bitmap(vec![0b10101010])),
(DataType::Variant, Scalar::Variant(vec![1, 2, 3, 4])),
(DataType::Geometry, Scalar::Geometry(vec![9, 9, 9])),
(
DataType::Geography,
Scalar::Geography(Geography(vec![1, 2, 3])),
),
(
DataType::Vector(VectorDataType::Int8(2)),
Scalar::Vector(VectorScalar::Int8(vec![1, 2])),
),
(
DataType::Opaque(2),
Scalar::Opaque(OpaqueScalar::Opaque2([1, 2])),
),
{
let array_ty = DataType::Array(Box::new(Int32Type::data_type()));
(array_ty.clone(), Scalar::default_value(&array_ty))
},
{
let map_ty = DataType::Map(Box::new(DataType::Tuple(vec![
DataType::String,
Int32Type::data_type(),
])));
let val = Scalar::default_value(&map_ty);
(map_ty, val)
},
{
let tuple_ty = DataType::Tuple(vec![DataType::String, Int32Type::data_type()]);
(
tuple_ty,
Scalar::Tuple(vec![
Scalar::String("tuple_0".to_string()),
Scalar::Number(NumberScalar::Int32(0)),
]),
)
},
(
Int32Type::data_type().wrap_nullable(),
Scalar::Number(NumberScalar::Int32(999)),
),
(Int32Type::data_type().wrap_nullable(), Scalar::Null),
];

DataBlock::from_iter(
cases.into_iter().map(|(data_type, scalar)| {
BlockEntry::new_const_column(data_type, scalar, num_rows)
}),
num_rows,
)
}

#[test]
fn test_value_spread() -> Result<()> {
let data = DataBlock::new(