From 22733ff7f6c9546ebcac719a0d0c1025113e2442 Mon Sep 17 00:00:00 2001 From: Jasmine-ge Date: Mon, 9 Dec 2024 18:53:45 +0800 Subject: [PATCH] combinator function --- .../AggregateFunctionAvgWeighted.h | 2 +- .../AggregateFunctionFactory.cpp | 6 +- .../AggregateFunctionTimeWeighted.cpp | 38 +- .../AggregateFunctionTimeWeighted.h | 413 +++++++++--------- .../Streaming/SubstituteStreamingFunction.cpp | 68 --- .../0_stateless/99010_time_weighted_avg.sql | 10 +- .../0035_streaming_func.json | 6 +- 7 files changed, 235 insertions(+), 308 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h index 71c1cb149af..580a958ae36 100644 --- a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h +++ b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h @@ -30,7 +30,7 @@ class AggregateFunctionAvgWeighted final : using Numerator = typename Base::Numerator; using Denominator = typename Base::Denominator; - using Fraction = typename Base::Fraction; + using Fraction = typename Base::Fraction; void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { diff --git a/src/AggregateFunctions/AggregateFunctionFactory.cpp b/src/AggregateFunctions/AggregateFunctionFactory.cpp index 3c3fc4e303b..633aae3118b 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -183,8 +183,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( if (combinator_name == "_time_weighted" && nested_name == "avg") nested_name = "avg_weighted"; else if (combinator_name == "_time_weighted" && nested_name == "median") - nested_name = "median_timing"; - + nested_name = "median_timing_weighted"; + else + throw Exception(ErrorCodes::ILLEGAL_AGGREGATION, "Combinator '{}' with {} is not supported", combinator_name, nested_name); + /// Nested identical combinators (i.e. uniqCombinedIfIf) is not /// supported (since they don't work -- silently). /// diff --git a/src/AggregateFunctions/AggregateFunctionTimeWeighted.cpp b/src/AggregateFunctions/AggregateFunctionTimeWeighted.cpp index dd039dc01f6..f17ae20ab3f 100644 --- a/src/AggregateFunctions/AggregateFunctionTimeWeighted.cpp +++ b/src/AggregateFunctions/AggregateFunctionTimeWeighted.cpp @@ -32,21 +32,7 @@ class AggregateFunctionCombinatorTimeWeighted final : public IAggregateFunctionC throw Exception("Incorrect number of arguments for aggregate function with " + getName() + " suffix", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - DataTypes nested_arguments; - nested_arguments.push_back(arguments[0]); - //delete - if (isDate(*arguments.begin())) - return nested_arguments; - if (isDate(arguments.back())) - nested_arguments.push_back(std::make_shared()); - else if(isDate32(arguments.back())) - nested_arguments.push_back(std::make_shared()); - else if(isDateTime(arguments.back())) - nested_arguments.push_back(std::make_shared()); - else if(isDateTime64(arguments.back())) - nested_arguments.push_back(std::make_shared()); - - return nested_arguments; + return {arguments[0], std::make_shared()}; } /// Decimal128 and Decimal256 aren't supported @@ -97,29 +83,17 @@ class AggregateFunctionCombinatorTimeWeighted final : public IAggregateFunctionC const auto data_type = static_cast(arguments[0]); const auto data_type_time_weight = static_cast(arguments[1]); - const WhichDataType dt(data_type), t_dt(data_type_time_weight); + const WhichDataType t_dt(data_type_time_weight); - if ((dt.isInt() || dt.isUInt() || dt.isFloat() || dt.isDecimal()) && (t_dt.isDateOrDate32() || t_dt.isDateTime()|| t_dt.isDateTime64())) - else - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Types {} and {} are non-conforming as arguments for aggregate function {}", data_type->getName(), data_type_time_weight->getName(), this->getName()); + if (!t_dt.isDateOrDate32() && !t_dt.isDateTime() && !t_dt.isDateTime64()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Types {} are non-conforming as time weighted arguments for aggregate function {}", data_type_time_weight->getName(), this->getName()); if (arguments.size() == 3) { const auto data_type_third_arg = static_cast(arguments[2]); - if(data_type_third_arg != data_type_time_weight) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The second and the third argument should be the same for aggregate function {}", this->getName()); - // AggregateFunctionPtr ptr; - - // // const bool left_decimal = isDecimal(data_type); - // // data_type_time_weight = UInt64; - // // auto data_type_uint64 = std::make_shared(); - // // if (left_decimal) - // // ptr.reset(create(*data_type, *data_type_time_weight, nested_function, arguments, params, - // // getDecimalScale(*data_type))); - // // else - // ptr.reset(create(*data_type, *data_type_time_weight, nested_function, arguments, params)); - // return ptr; + if(data_type_third_arg->getTypeId() != data_type_time_weight->getTypeId()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The second and the third argument should be the same for aggregate function {}", this->getName()); } AggregateFunctionPtr ptr; ptr.reset(create(*data_type, *data_type_time_weight, nested_function, arguments, params)); diff --git a/src/AggregateFunctions/AggregateFunctionTimeWeighted.h b/src/AggregateFunctions/AggregateFunctionTimeWeighted.h index 831a39a3075..07b7935fe81 100644 --- a/src/AggregateFunctions/AggregateFunctionTimeWeighted.h +++ b/src/AggregateFunctions/AggregateFunctionTimeWeighted.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -23,38 +24,24 @@ namespace ErrorCodes struct Settings; -template constexpr bool DecimalOrExtendedInt = - is_decimal - || std::is_same_v - || std::is_same_v - || std::is_same_v - || std::is_same_v; - -/** - * Helper class to encapsulate values conversion for avg and avgWeighted. - */ -template -struct AvgTimeFraction +template +struct TimeWeightedData { struct Last { - Numerator last_value; - Denominator last_time; + Field last_value; + TimeType last_time; }; std::optional last; - std::optional current_time; + std::optional start_time; + std::optional end_time; }; - -// template -// using MaxFieldType = std::conditional_t<(sizeof(AvgTimeWeightedFieldType) > sizeof(AvgTimeWeightedFieldType)), -// AvgTimeWeightedFieldType, AvgTimeWeightedFieldType>; - template class AggregateFunctionTimeWeighted: - public IAggregateFunctionDataHelper>, + public IAggregateFunctionDataHelper, AggregateFunctionTimeWeighted> { @@ -74,12 +61,9 @@ class AggregateFunctionTimeWeighted: return place + prefix_size; } public: - using Base = IAggregateFunctionDataHelper>, + using Base = IAggregateFunctionDataHelper, AggregateFunctionTimeWeighted>; - using Numerator = Value; - using Denominator = NearestFieldType; - using Fraction = AvgTimeFraction; AggregateFunctionTimeWeighted(AggregateFunctionPtr nested_func_, const DataTypes & arguments, const Array & params_) : Base(arguments, params_) , nested_func(nested_func_) @@ -87,12 +71,49 @@ class AggregateFunctionTimeWeighted: , logger(&Poco::Logger::get("AggregateFunctionTimeWeighted")) { size_t nested_size = nested_func->alignOfData(); - prefix_size = (sizeof(AvgTimeFraction>) + nested_size - 1) / nested_size * nested_size; + prefix_size = (sizeof(TimeWeightedData) + nested_size - 1) / nested_size * nested_size; + } + + void last_time_calculation(size_t row_begin, size_t row_end, AggregateDataPtr __restrict place, const IColumn ** columns, Arena * arena) const + { + auto & data = this->data(place); + const auto & time_data = assert_cast &>(*columns[1]).getData(); + /// last time caculation + if (data.last.has_value()) + { + MutableColumnPtr value_column, weight_column; + value_column = this->argument_types[0]->createColumn(); + weight_column = ColumnUInt64::create(); + if (time_data[row_begin] >= data.last->last_time) [[likely]] + { + value_column->insert(data.last->last_value); + weight_column->insert(static_cast(time_data[row_begin] - data.last->last_time)); + } + else + { + LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,data.last->last_time ,time_data[row_begin]); + } + ColumnRawPtrs raw_columns{value_column.get(), weight_column.get()}; + nested_func->add(getNestedPlace(place), raw_columns.data(), 0, arena); + } + + const auto & value_data = assert_cast &>(*columns[0]).getData(); + auto last_row_pos = row_end - 1; + data.last = { + static_cast(value_data[last_row_pos]), + static_cast(time_data[last_row_pos]) + }; + /// remember start time + data.start_time = time_data[row_begin]; + /// remember current time + if (this->argument_types.size() == 3) + data.end_time = assert_cast &>(*columns[2]).getData()[last_row_pos]; } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "merge() function isn't implemented for {}", getName()); + + last_time_calculation(row_num, row_num + 1, place, columns, arena); } void addBatchSinglePlace( @@ -104,72 +125,29 @@ class AggregateFunctionTimeWeighted: ssize_t if_argument_pos, const IColumn * delta_col [[maybe_unused]]) const final { - auto & data = this->data(place); - MutableColumnPtr value_column, weight_column; - const auto & value_data = assert_cast &>(*columns[0]).getData(); - const auto & time_data = assert_cast &>(*columns[1]).getData(); - auto last_row_pos = row_end - 1; + if (if_argument_pos >= 0 || delta_col != nullptr) + return nested_func->addBatchSinglePlace(row_begin, row_end, place, columns, arena, if_argument_pos, delta_col); + else if (row_end - row_begin == 1) + return add(place, columns, 0, arena); - /// last time caculation - if (data.last.has_value()) - { - value_column = this->argument_types[0]->createColumn(); - weight_column = ColumnUInt64::create(); - value_column->insert(data.last->last_value); - weight_column->insert(static_cast(time_data[0] - data.last->last_time)); - - ColumnRawPtrs raw_columns{value_column.get(), weight_column.get()}; - nested_func->add(getNestedPlace(place), raw_columns.data(), 0, arena); - } + const auto & time_data = assert_cast &>(*columns[1]).getData(); + last_time_calculation(row_begin, row_end, place, columns, arena); + auto last_row_pos = row_end - 1; /// caculate time - weight_column = ColumnUInt64::create(); - if (if_argument_pos >= 0) + MutableColumnPtr weight_column = ColumnUInt64::create(); + for (size_t i = row_begin; i < last_row_pos; i++) { - const auto & flags = assert_cast(*columns[if_argument_pos]).getData(); - for (size_t i = row_begin; i < last_row_pos; i++) - { - if (flags[i]) - { - if (time_data[i + 1] < time_data[i]) - LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,time_data[i] ,time_data[i + 1]); - else - weight_column->insert(static_cast(time_data[i + 1] - time_data[i])); - } - } + if (time_data[i + 1] >= time_data[i]) [[likely]] + weight_column->insert(static_cast(time_data[i + 1] - time_data[i])); + else + LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,time_data[i] ,time_data[i + 1]); } - else - { - for (size_t i = row_begin; i < last_row_pos; i++) - { - if (time_data[i + 1] < time_data[i]) - LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,time_data[i] ,time_data[i + 1]); - else - weight_column->insert(static_cast(time_data[i + 1] - time_data[i])); - } - } - - //weight_column->insertDefault(); /// prepare data - ColumnRawPtrs raw_columns{*columns[0].get(), weight_column.get()}; - - nested_func-> addBatchSinglePlace(row_begin, last_row_pos, getNestedPlace(place), raw_columns.data(), arena, if_argument_pos); - - // if (data.last_value.has_value()) - // { - // data.last_value.value() = static_cast(value_data[row_end - 1]); - // data.last_time.value() = static_cast(time_data[row_end - 1]); - // } - // else - // { - data.last->last_value = static_cast(value_data[last_row_pos]); - data.last->last_time = static_cast(time_data[last_row_pos]); - /// remember current time - if (this->argument_types.size() == 3) - data.current_time = assert_cast &>(*columns[2]).getData()[last_row_pos]; + ColumnRawPtrs raw_columns{columns[0], weight_column.get()}; - // } + nested_func->addBatchSinglePlace(row_begin, last_row_pos, getNestedPlace(place), raw_columns.data(), arena, if_argument_pos, delta_col); } void addBatchSinglePlaceNotNull( @@ -183,139 +161,180 @@ class AggregateFunctionTimeWeighted: const IColumn * delta_col [[maybe_unused]]) const final { - auto & data = this->data(place); - MutableColumnPtr value_column, weight_column; - const auto & value_data = assert_cast &>(*columns[0]).getData(); + if (if_argument_pos >= 0 || delta_col != nullptr) + return nested_func->addBatchSinglePlaceNotNull(row_begin, row_end, place, columns, null_map, arena, if_argument_pos, delta_col); + else if (row_end - row_begin == 1) + return add(place, columns, 0, arena); + + // const auto & value_data = assert_cast &>(*columns[0]).getData(); const auto & time_data = assert_cast &>(*columns[1]).getData(); - auto last_row_pos = row_end - 1; - /// last time caculation - if (data.last.has_value()) - { - value_column = this->argument_types[0]->createColumn(); - weight_column = ColumnUInt64::create(); - value_column->insert(data.last->last_value); - weight_column->insert(static_cast(time_data[0] - data.last->last_time)); - - ColumnRawPtrs raw_columns{value_column.get(), weight_column.get()}; - nested_func->add(getNestedPlace(place), raw_columns.data(), 0, arena); - } + last_time_calculation(row_begin, row_end, place, columns, arena); + auto last_row_pos = row_end - 1; /// caculate time - weight_column = ColumnUInt64::create(); - if (if_argument_pos >= 0) + MutableColumnPtr weight_column = ColumnUInt64::create(); + for (size_t i = row_begin; i < row_end - 1; i++) { - const auto & flags = assert_cast(*columns[if_argument_pos]).getData(); - for (size_t i = row_begin; i < row_end - 1; i++) + if (!null_map[i]) { - if (flags[i] && !null_map[i]) - { - if (value_column[i + 1] < value_column[i]) - LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,value_column[i] ,value_column[i + 1]); - else - weight_column->insert(static_cast(value_column[i + 1] - value_column[i])); - } + if (time_data[i + 1] < time_data[i]) + LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,time_data[i] ,time_data[i + 1]); + else + weight_column->insert(static_cast(time_data[i + 1] - time_data[i])); } } - else + ColumnRawPtrs raw_columns{columns[0], weight_column.get()}; + + nested_func-> addBatchSinglePlaceNotNull(row_begin, last_row_pos, getNestedPlace(place), raw_columns.data(), null_map, arena, if_argument_pos, delta_col); + } + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + /// FIXME, time disorder may happen, the outcome might not be accurate + auto & data = this->data(place); + auto & rhs_data = this->data(rhs); + if (data.last.has_value()) { - for (size_t i = row_begin; i < row_end - 1; i++) + if (rhs_data.start_time.has_value()) { - if (!null_map[i]) + MutableColumnPtr value_column, weight_column; + value_column = this->argument_types[0]->createColumn(); + weight_column = ColumnUInt64::create(); + if (rhs_data.start_time.value() >= data.last->last_time) + { + value_column->insert(data.last->last_value); + weight_column->insert(static_cast(rhs_data.start_time.value() - data.last->last_time)); + if (rhs_data.last.has_value()) + data.last = rhs_data.last; + if (rhs_data.end_time.has_value()) + data.end_time = rhs_data.end_time; + } + else { - if (value_column[i + 1] < value_column[i]) - LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,value_column[i] ,value_column[i + 1]); - else - weight_column->insert(static_cast(value_column[i + 1] - value_column[i])); + if (data.start_time.has_value()) + { + if (data.start_time.value() >= rhs_data.last->last_time) + { + value_column->insert(rhs_data.last->last_value); + weight_column->insert(static_cast(data.last->last_time - rhs_data.start_time.value())); + data.start_time.value() = rhs_data.start_time.value(); + } + else + { + LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,data.last->last_time ,rhs_data.start_time.value()); + } + } } + ColumnRawPtrs raw_columns{value_column.get(), weight_column.get()}; + nested_func->add(getNestedPlace(place), raw_columns.data(), 0, arena); } } - //weight_column->insertDefault(); - - ColumnRawPtrs raw_columns{columns[0].get(), weight_column.get()}; - - nested_func-> addBatchSinglePlace(row_begin, last_row_pos, getNestedPlace(place), raw_columns.data(), arena, if_argument_pos); - - // if (data.last_value.has_value()) - // { - // data.last_value.value() = static_cast(value_data[row_end - 1]); - // data.last_time.value() = static_cast(time_data[row_end - 1]); - // } - // else - // { - data.last->last_value = static_cast(value_data[last_row_pos]); - data.last->last_time = static_cast(time_data[last_row_pos]); - /// remember current time - if (this->argument_types.size() == 3) - data.current_time = assert_cast &>(*columns[2]).getData()[last_row_pos]; - - // } + else + { + if (rhs_data.last.has_value()) + data.last = rhs_data.last; + if (rhs_data.start_time.has_value()) + data.start_time = rhs_data.start_time; + if (rhs_data.end_time.has_value()) + data.end_time = rhs_data.end_time; + } - } - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override - { - /// FIXME, time disorder may happen, the outcome might not be accurate - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "merge() function isn't implemented for {}", getName()); + nested_func->merge(getNestedPlace(place), rhs, arena); } void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override { - // if (this->data(place).last_value.has_value() && this->data(place).last_time.has_value()) - // { - // writeBinary(this->data(place).last_value.value(), buf); - // if constexpr (std::is_unsigned_v) - // writeVarUInt(this->data(place).last_time.value(), buf); - // else - // writeBinary(this->data(place).last_time.value(), buf); - // } - // if (this->data(place).current_time.has_value()) - // { - // if constexpr (std::is_unsigned_v) - // writeVarUInt(this->data(place).current_time.value(), buf); - // else - // writeBinary(this->data(place).current_time.value(), buf); - // } - // nested_func->serialize(getNestedPlace(place), buf); + writeBinary(this->data(place).last.has_value(), buf); + if (this->data(place).last.has_value()) + { + writeBinary(true, buf); + writeFieldBinary(this->data(place).last->last_value, buf); + if constexpr (std::is_unsigned_v) + writeVarUInt(this->data(place).last->last_time, buf); + else + writeBinary(this->data(place).last->last_time, buf); + } + + writeBinary(this->data(place).start_time.has_value(), buf); + if (this->data(place).start_time.has_value()) + { + writeBinary(true, buf); + if constexpr (std::is_unsigned_v) + writeVarUInt(this->data(place).start_time.value(), buf); + else + writeBinary(this->data(place).start_time.value(), buf); + } + + writeBinary(this->data(place).end_time.has_value(), buf); + if (this->data(place).end_time.has_value()) + { + writeBinary(true, buf); + if constexpr (std::is_unsigned_v) + writeVarUInt(this->data(place).end_time.value(), buf); + else + writeBinary(this->data(place).end_time.value(), buf); + } + nested_func->serialize(getNestedPlace(place), buf); } void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena * arena) const override { - // if (this->data(place).last_value.has_value() && this->data(place).last_time.has_value()) - // { - // readBinary(this->data(place).last_value.value(), buf); - // if constexpr (std::is_unsigned_v) - // readVarUInt(this->data(place).last_time.value(), buf); - // else /// Floating point denominator type can be used - // readBinary(this->data(place).last_time.value(), buf); - // } - // if (this->data(place).current_time.has_value()) - // { - // if constexpr (std::is_unsigned_v) - // readVarUInt(this->data(place).current_time.value(), buf); - // else - // readBinary(this->data(place).current_time.value(), buf); - // } - // nested_func->deserialize(getNestedPlace(place), buf, std::nullopt /* version */, arena); + bool last_has_value, start_has_value, end_has_value; + readBinary(last_has_value, buf); + if(last_has_value) + { + this->data(place).last->last_value = readFieldBinary(buf); + if constexpr (std::is_unsigned_v) + readVarUInt(this->data(place).last->last_time, buf); + else /// Floating point TimeWeight type can be used + readBinary(this->data(place).last->last_time, buf); + } + + readBinary(start_has_value, buf); + if(start_has_value) + { + if constexpr (std::is_unsigned_v) + readVarUInt(this->data(place).start_time.value(), buf); + else + readBinary(this->data(place).start_time.value(), buf); + } + + readBinary(end_has_value, buf); + if(end_has_value) + { + if constexpr (std::is_unsigned_v) + readVarUInt(this->data(place).end_time.value(), buf); + else + readBinary(this->data(place).end_time.value(), buf); + } + + nested_func->deserialize(getNestedPlace(place), buf, std::nullopt /* version */, arena); } - void insertResultIntoImpl(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override { auto & data = this->data(place); - if (data.current_time.has_value()) + if (data.end_time.has_value()) { MutableColumnPtr value_column, weight_column; - ColumnRawPtrs argument_raw_columns(2); chassert(data.last.has_value()); value_column = this->argument_types[0]->createColumn(); weight_column = ColumnUInt64::create(); - value_column->insert(data.last->last_value); - weight_column->insert(data.current_time.value() - data.last->last_value); + if (data.end_time.value() >= data.last->last_time) [[likely]] + { + value_column->insert(data.last->last_value); + weight_column->insert(static_cast(data.end_time.value() - data.last->last_time)); + } + else + { + LOG_WARNING(logger, "Illegal time argument, should be in ascending order, {}, {}" ,data.last->last_time ,data.end_time.value()); + } + - for (size_t i = 0; i < argument_columns.size(); ++i) - argument_raw_columns[i] = argument_columns[i].get(); + ColumnRawPtrs raw_columns{value_column.get(), weight_column.get()}; - nested_func -> add(getNestedPlace(place), argument_raw_columns.data(), 0, arena); + nested_func -> add(getNestedPlace(place), raw_columns.data(), 0, arena); } // assert(!data.arguments.empty()); @@ -323,15 +342,15 @@ class AggregateFunctionTimeWeighted: nested_func->insertResultInto(getNestedPlace(place), to, arena); } - void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override - { - insertResultIntoImpl(place, to, arena); - } + // void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + // { + // insertResultIntoImpl(place, to, arena); + // } - void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override - { - insertResultIntoImpl(place, to, arena); - } + // void insertMergeResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override + // { + // insertResultIntoImpl(place, to, arena); + // } size_t sizeOfData() const override { @@ -340,30 +359,30 @@ class AggregateFunctionTimeWeighted: void create(AggregateDataPtr __restrict place) const override { - new (place) AvgTimeFraction>; + new (place) TimeWeightedData; nested_func->create(getNestedPlace(place)); } void destroy(AggregateDataPtr __restrict place) const noexcept override { - this->data(place).~AvgTimeFraction>(); + this->data(place).~TimeWeightedData(); nested_func->destroy(getNestedPlace(place)); } bool hasTrivialDestructor() const override { - return std::is_trivially_destructible_v>> && nested_func->hasTrivialDestructor(); + return std::is_trivially_destructible_v> && nested_func->hasTrivialDestructor(); } void destroyUpToState(AggregateDataPtr __restrict place) const noexcept override { - this->data(place).~AvgTimeFraction>(); + this->data(place).~TimeWeightedData(); nested_func->destroyUpToState(getNestedPlace(place)); } String getName() const override { - return nested_func->getName() + "_time"; + return nested_func->getName() + "_time_weighted"; } DataTypePtr getReturnType() const override @@ -399,4 +418,4 @@ class AggregateFunctionTimeWeighted: AggregateFunctionPtr getNestedFunction() const override { return nested_func; } }; -} \ No newline at end of file +} diff --git a/src/Interpreters/Streaming/SubstituteStreamingFunction.cpp b/src/Interpreters/Streaming/SubstituteStreamingFunction.cpp index 29d2e0f7a98..94bb5033a3a 100644 --- a/src/Interpreters/Streaming/SubstituteStreamingFunction.cpp +++ b/src/Interpreters/Streaming/SubstituteStreamingFunction.cpp @@ -134,8 +134,6 @@ void StreamingFunctionData::visit(DB::ASTFunction & func, DB::ASTPtr) return; } - translateTimeWeightedFunc(func); - if (streaming) { auto func_name_lower = Poco::toLower(func.name); @@ -224,71 +222,5 @@ void substitueFunction(ASTFunction & func, const String & new_name) func.name = new_name; } -bool translateTimeWeightedFunc(ASTFunction & func) -{ - static const std::unordered_map map = { - {"time_weighted_avg", "avg_weighted"}, - {"time_weighted_median", "median_timing_weighted"} - }; - - /// time_weighted_median(val, _tp_time) -> quantile_timing_weighted(lag(val, 1, val), cast(date_diff('millisecond', lag(_tp_time, 1, _tp_time), _tp_time), 'uint64')) - /// time_weighted_avg(val, _tp_time) -> avg_weighted(lag(val, 1, val), cast(date_diff('millisecond', lag(_tp_time, 1, _tp_time), _tp_time), 'uint64')) - if (!map.contains(func.name)) - return false; - - auto num_args = func.arguments->children.size(); - if (num_args != 2 && num_args != 3) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", func.name, num_args); - - String func_name = map.at(func.name); - auto val_arg = func.arguments->children[0]; - auto time_arg = func.arguments->children[1]; - String algorithm = "locf"; /// LOCF or Linear - - if (num_args == 3) - { - const auto * literal = func.arguments->children[2]->as(); - if (literal) - algorithm = Poco::toLower(literal->value.safeGet()); - if (!literal || (algorithm != "linear" && algorithm != "locf")) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Third argument must be literal string of algorithm: 'linear' or 'locf', but given '{}'" , algorithm); - } - if (algorithm == "linear") - { - func.name = func_name; - func.arguments = std::make_shared(); - func.children.push_back(func.arguments); - func.arguments->children = { - makeASTFunction("divide", - makeASTFunction("plus", - makeASTFunction("lag", val_arg->clone(), std::make_shared(Int8(1)), val_arg->clone()), - val_arg->clone()), - std::make_shared(2)), - makeASTFunction("cast", - makeASTFunction("date_diff", - std::make_shared("millisecond"), - makeASTFunction("lag", time_arg->clone(), std::make_shared(Int8(1)), time_arg->clone()), - time_arg->clone()), - std::make_shared("uint64")) - }; - } - else - { - func.name = func_name; - func.arguments = std::make_shared(); - func.children.push_back(func.arguments); - func.arguments->children = { - makeASTFunction("lag", val_arg->clone(), std::make_shared(Int8(1)), val_arg->clone()), - makeASTFunction("cast", - makeASTFunction("date_diff", - std::make_shared("millisecond"), - makeASTFunction("lag", time_arg->clone(), std::make_shared(Int8(1)), time_arg->clone()), - time_arg->clone()), - std::make_shared("uint64")) - }; - } - - return true; -} } } diff --git a/tests/queries_ported/0_stateless/99010_time_weighted_avg.sql b/tests/queries_ported/0_stateless/99010_time_weighted_avg.sql index 11442f4fbe0..87c4204026e 100644 --- a/tests/queries_ported/0_stateless/99010_time_weighted_avg.sql +++ b/tests/queries_ported/0_stateless/99010_time_weighted_avg.sql @@ -1,4 +1,4 @@ -DROP STREAM IF EXISTS test_99010; + DROP STREAM IF EXISTS test_99010; CREATE STREAM test_99010 (val int, a DateTime, b Date, c Date32, d DateTime64); @@ -10,9 +10,9 @@ INSERT INTO test_99010(val, a, b, c, d) VALUES (5, to_datetime('2024-11-29 12:12 INSERT INTO test_99010(val, a, b, c, d) VALUES (6, to_datetime('2024-11-29 12:12:25'), '2024-12-29', '2024-12-29', to_datetime64('2024-11-29 12:12:13.135', 3)); SELECT sleep(3); -SELECT time_weighted_avg(val, a) FROM (SELECT * FROM table(test_99010) ORDER BY a); -SELECT time_weighted_avg(val, b) FROM (SELECT * FROM table(test_99010) ORDER BY b); -SELECT time_weighted_avg(val, c) FROM (SELECT * FROM table(test_99010) ORDER BY c); -SELECT time_weighted_avg(val, d) FROM (SELECT * FROM table(test_99010) ORDER BY d); +SELECT avg_time_weighted(val, a) FROM (SELECT * FROM table(test_99010) ORDER BY a); +SELECT avg_time_weighted(val, b) FROM (SELECT * FROM table(test_99010) ORDER BY b); +SELECT avg_time_weighted(val, c) FROM (SELECT * FROM table(test_99010) ORDER BY c); +SELECT avg_time_weighted(val, d) FROM (SELECT * FROM table(test_99010) ORDER BY d); DROP STREAM IF EXISTS test_99010; diff --git a/tests/stream/test_stream_smoke/0035_streaming_func.json b/tests/stream/test_stream_smoke/0035_streaming_func.json index 670c8f15bd3..d92300e1c97 100644 --- a/tests/stream/test_stream_smoke/0035_streaming_func.json +++ b/tests/stream/test_stream_smoke/0035_streaming_func.json @@ -13,14 +13,14 @@ { "id": 1, "tags": ["query_state"], - "name": "global_aggr_with_fun_time_weighted_avg", - "description": "global aggregation with function time_weighted_avg state checkpoint", + "name": "global_aggr_with_fun_avg_time_weighted", + "description": "global aggregation with function avg_time_weighted state checkpoint", "steps":[ { "statements": [ {"client":"python", "query_type": "table", "query":"drop stream if exists test35_state_stream1"}, {"client":"python", "query_type": "table", "exist":"test35_state_stream1", "exist_wait":2, "wait":1, "query":"create stream test35_state_stream1 (val int32, timestamp datetime64(3) default now64(3))"}, - {"client":"python", "query_type": "stream", "query_id":"3600", "depends_on_stream":"test35_state_stream1", "wait":1, "terminate":"manual", "query":"subscribe to select time_weighted_avg(val, timestamp) from test35_state_stream1 emit periodic 1s settings checkpoint_interval=1"}, + {"client":"python", "query_type": "stream", "query_id":"3600", "depends_on_stream":"test35_state_stream1", "wait":1, "terminate":"manual", "query":"subscribe to select avg_time_weighted(val, timestamp) from test35_state_stream1 emit periodic 1s settings checkpoint_interval=1"}, {"client":"python", "query_type": "table", "depends_on":"3600", "kill":"3600", "kill_wait":5, "wait":3, "query": "insert into test35_state_stream1(val, timestamp) values (1, '2020-02-02 20:00:00'), (2, '2020-02-02 20:00:01'), (3, '2020-02-02 20:00:03'), (3, '2020-02-02 20:00:04'), (3, '2020-02-02 20:00:05')"}, {"client":"python", "query_type": "table", "wait":1, "query":"unsubscribe to '3600'"} ]