Skip to content

Commit

Permalink
support time_weighted functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasmine-ge committed Dec 5, 2024
1 parent 0cc5119 commit 33328d0
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 0 deletions.
70 changes: 70 additions & 0 deletions src/Interpreters/Streaming/SubstituteStreamingFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/formatAST.h>
Expand All @@ -16,6 +17,7 @@ namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int FUNCTION_NOT_ALLOWED;
extern const int ILLEGAL_CODEC_PARAMETER;
}

namespace Streaming
Expand Down Expand Up @@ -131,6 +133,8 @@ void StreamingFunctionData::visit(DB::ASTFunction & func, DB::ASTPtr)
return;
}

translateTimeWeightedFunc(func, streaming);

if (streaming)
{
auto func_name_lower = Poco::toLower(func.name);
Expand Down Expand Up @@ -218,5 +222,71 @@ void substitueFunction(ASTFunction & func, const String & new_name)

func.name = new_name;
}

bool translateTimeWeightedFunc(ASTFunction & func, bool streaming)
{
static const std::unordered_map<String, String> map = {
{"time_weighted_avg", "avg_weighted"},
{"time_weighted_median", "median_timing_weighted"}
};

/// time_weighted_median(val, _tp_time) -> quantile_timing_weighted(__streaming_neighbor(val, -1, val), cast(date_diff('millisecond', neighbor(_tp_time, -1, _tp_time), _tp_time), 'uint64'))
/// time_weighted_avg(val, _tp_time) -> avg_weighted(__streaming_neighbor(val, -1, val), cast(date_diff('millisecond', neighbor(_tp_time, -1, _tp_time), _tp_time), 'uint64'))
if (!map.contains(func.name))
return false;

if (func.arguments->children.size() != 3 && func.arguments->children.size() != 2)
throw Exception(ErrorCodes::FUNCTION_NOT_ALLOWED, "{} aggregation function need two or three arguments", func.name);

String func_name = map.at(func.name);
auto val_arg = func.arguments->children[0];
auto time_arg = func.arguments->children[1];
String extra_arg = "";

if (func.arguments->children.size() == 3)
{
const auto * literal = func.arguments->children[2]->as<ASTLiteral>();
if (literal)
{
extra_arg = literal->value.safeGet<String>();
extra_arg = Poco::toLower(extra_arg);
if (extra_arg != "linear" && extra_arg != "locf")
throw Exception("Type argument can be linear or locf, given " + extra_arg, ErrorCodes::ILLEGAL_CODEC_PARAMETER);
}
}

auto neighbor_name = streaming ? "__streaming_neighbor" : "neighbor";
if (extra_arg == "linear")
{
auto translated_func = makeASTFunction(func_name,
makeASTFunction("divide",
makeASTFunction("plus",
makeASTFunction(neighbor_name, val_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), val_arg->clone()),
val_arg->clone()),
std::make_shared<ASTLiteral>(2)),
makeASTFunction("cast",
makeASTFunction("date_diff",
std::make_shared<ASTLiteral>("millisecond"),
makeASTFunction("neighbor", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), time_arg->clone()),
time_arg->clone()),
std::make_shared<ASTLiteral>("uint64")));
func = *translated_func;
}
else
{
auto translated_func = makeASTFunction(func_name,
makeASTFunction(neighbor_name, val_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), val_arg->clone()),
makeASTFunction("cast",
makeASTFunction("date_diff",
std::make_shared<ASTLiteral>("millisecond"),
makeASTFunction("neighbor", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), time_arg->clone()),
time_arg->clone()),
std::make_shared<ASTLiteral>("uint64")));
func = *translated_func;
}


return true;
}
}
}
2 changes: 2 additions & 0 deletions src/Interpreters/Streaming/SubstituteStreamingFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ using SubstituteStreamingFunctionVisitor


void substitueFunction(ASTFunction & func, const String & new_name);
bool translateTimeWeightedFunc(ASTFunction & func, bool streaming);


struct SubstituteFunctionsData
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
0
3.5
3.8
3.8
3.5
18 changes: 18 additions & 0 deletions tests/queries_ported/0_stateless/99010_time_weighted_avg.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
DROP STREAM IF EXISTS test_99010;

CREATE STREAM test_99010 (val int, a DateTime, b Date, c Date32, d DateTime64);

INSERT INTO test_99010(val, a, b, c, d) VALUES (1, to_datetime('2024-11-29 12:12:13'), '2024-11-29', '2024-11-29', to_datetime64('2024-11-29 12:12:13.123', 3));
INSERT INTO test_99010(val, a, b, c, d) VALUES (2, to_datetime('2024-11-29 12:12:16'), '2024-11-30', '2024-11-30', to_datetime64('2024-11-29 12:12:13.126', 3));
INSERT INTO test_99010(val, a, b, c, d) VALUES (3, to_datetime('2024-11-29 12:12:17'), '2024-12-01', '2024-12-01', to_datetime64('2024-11-29 12:12:13.127', 3));
INSERT INTO test_99010(val, a, b, c, d) VALUES (4, to_datetime('2024-11-29 12:12:18'), '2024-12-03', '2024-12-03', to_datetime64('2024-11-29 12:12:13.128', 3));
INSERT INTO test_99010(val, a, b, c, d) VALUES (5, to_datetime('2024-11-29 12:12:19'), '2024-12-28', '2024-12-28', to_datetime64('2024-11-29 12:12:13.129', 3));
INSERT INTO test_99010(val, a, b, c, d) VALUES (6, to_datetime('2024-11-29 12:12:25'), '2024-12-29', '2024-12-29', to_datetime64('2024-11-29 12:12:13.135', 3));
SELECT sleep(3);

SELECT time_weighted_avg(val, a) FROM (SELECT * FROM table(test_99010) ORDER BY a);
SELECT time_weighted_avg(val, b) FROM (SELECT * FROM table(test_99010) ORDER BY b);
SELECT time_weighted_avg(val, c) FROM (SELECT * FROM table(test_99010) ORDER BY c);
SELECT time_weighted_avg(val, d) FROM (SELECT * FROM table(test_99010) ORDER BY d);
DROP STREAM IF EXISTS test_99010;

39 changes: 39 additions & 0 deletions tests/stream/test_stream_smoke/0035_streaming_func.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"test_suite_name": "streaming_func",
"tag": "smoke",
"test_suite_config":{
"setup": {
"statements": [
]
},
"tests_2_run": {"ids_2_run": ["all"], "tags_2_run":[], "tags_2_skip":{"default":["todo", "to_support", "change", "bug", "sample"],"cluster": ["view", "cluster_table_bug"]}}
},
"comments": "Tests covering query state checkpointing smoke test cases",
"tests": [
{
"id": 1,
"tags": ["query_state"],
"name": "global_aggr_with_fun_time_weighted_avg",
"description": "global aggregation with function time_weighted_avg state checkpoint",
"steps":[
{
"statements": [
{"client":"python", "query_type": "table", "query":"drop stream if exists test35_state_stream1"},
{"client":"python", "query_type": "table", "exist":"test35_state_stream1", "exist_wait":2, "wait":1, "query":"create stream test35_state_stream1 (val int32, timestamp datetime64(3) default now64(3))"},
{"client":"python", "query_type": "stream", "query_id":"3600", "depends_on_stream":"test35_state_stream1", "wait":1, "terminate":"manual", "query":"subscribe to select time_weighted_avg(val, timestamp) from test35_state_stream1 emit periodic 1s settings checkpoint_interval=1"},
{"client":"python", "query_type": "table", "depends_on":"3600", "kill":"3600", "kill_wait":5, "wait":3, "query": "insert into test35_state_stream1(val, timestamp) values (1, '2020-02-02 20:00:00'), (2, '2020-02-02 20:00:01'), (3, '2020-02-02 20:00:03'), (3, '2020-02-02 20:00:04'), (3, '2020-02-02 20:00:05')"},
{"client":"python", "query_type": "table", "wait":1, "query":"unsubscribe to '3600'"}
]
}
],
"expected_results": [
{
"query_id":"3600",
"expected_results":[
["2.2"]
]
}
]
}
]
}

0 comments on commit 33328d0

Please sign in to comment.