Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasmine-ge committed Dec 5, 2024
1 parent 33328d0 commit c622510
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 30 deletions.
60 changes: 31 additions & 29 deletions src/Interpreters/Streaming/SubstituteStreamingFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
extern const int FUNCTION_NOT_ALLOWED;
extern const int ILLEGAL_CODEC_PARAMETER;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

namespace Streaming
Expand Down Expand Up @@ -133,7 +134,7 @@ void StreamingFunctionData::visit(DB::ASTFunction & func, DB::ASTPtr)
return;
}

translateTimeWeightedFunc(func, streaming);
translateTimeWeightedFunc(func);

if (streaming)
{
Expand Down Expand Up @@ -223,69 +224,70 @@ void substitueFunction(ASTFunction & func, const String & new_name)
func.name = new_name;
}

bool translateTimeWeightedFunc(ASTFunction & func, bool streaming)
bool translateTimeWeightedFunc(ASTFunction & func)
{
static const std::unordered_map<String, String> map = {
{"time_weighted_avg", "avg_weighted"},
{"time_weighted_median", "median_timing_weighted"}
};

/// time_weighted_median(val, _tp_time) -> quantile_timing_weighted(__streaming_neighbor(val, -1, val), cast(date_diff('millisecond', neighbor(_tp_time, -1, _tp_time), _tp_time), 'uint64'))
/// time_weighted_avg(val, _tp_time) -> avg_weighted(__streaming_neighbor(val, -1, val), cast(date_diff('millisecond', neighbor(_tp_time, -1, _tp_time), _tp_time), 'uint64'))
/// time_weighted_median(val, _tp_time) -> quantile_timing_weighted(lag(val, 1, val), cast(date_diff('millisecond', lag(_tp_time, 1, _tp_time), _tp_time), 'uint64'))
/// time_weighted_avg(val, _tp_time) -> avg_weighted(lag(val, 1, val), cast(date_diff('millisecond', lag(_tp_time, 1, _tp_time), _tp_time), 'uint64'))
if (!map.contains(func.name))
return false;

if (func.arguments->children.size() != 3 && func.arguments->children.size() != 2)
throw Exception(ErrorCodes::FUNCTION_NOT_ALLOWED, "{} aggregation function need two or three arguments", func.name);

auto num_args = func.arguments->children.size();
if (num_args != 2 && num_args != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", func.name, num_args);

String func_name = map.at(func.name);
auto val_arg = func.arguments->children[0];
auto time_arg = func.arguments->children[1];
String extra_arg = "";
String algorithm = "locf"; /// LOCF or Linear

if (func.arguments->children.size() == 3)
if (num_args == 3)
{
const auto * literal = func.arguments->children[2]->as<ASTLiteral>();
if (literal)
{
extra_arg = literal->value.safeGet<String>();
extra_arg = Poco::toLower(extra_arg);
if (extra_arg != "linear" && extra_arg != "locf")
throw Exception("Type argument can be linear or locf, given " + extra_arg, ErrorCodes::ILLEGAL_CODEC_PARAMETER);
}
algorithm = Poco::toLower(literal->value.safeGet<String>());
if (!literal || (algorithm != "linear" && algorithm != "locf"))
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Third argument must be literal string of algorithm: 'linear' or 'locf', but given '{}'" , algorithm);
}

auto neighbor_name = streaming ? "__streaming_neighbor" : "neighbor";
if (extra_arg == "linear")
if (algorithm == "linear")
{
auto translated_func = makeASTFunction(func_name,
func.name = func_name;
func.arguments = std::make_shared<ASTExpressionList>();
func.children.push_back(func.arguments);
func.arguments->children = {
makeASTFunction("divide",
makeASTFunction("plus",
makeASTFunction(neighbor_name, val_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), val_arg->clone()),
makeASTFunction("lag", val_arg->clone(), std::make_shared<ASTLiteral>(Int8(1)), val_arg->clone()),
val_arg->clone()),
std::make_shared<ASTLiteral>(2)),
makeASTFunction("cast",
makeASTFunction("date_diff",
std::make_shared<ASTLiteral>("millisecond"),
makeASTFunction("neighbor", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), time_arg->clone()),
makeASTFunction("lag", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(1)), time_arg->clone()),
time_arg->clone()),
std::make_shared<ASTLiteral>("uint64")));
func = *translated_func;
std::make_shared<ASTLiteral>("uint64"))
};
}
else
{
auto translated_func = makeASTFunction(func_name,
makeASTFunction(neighbor_name, val_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), val_arg->clone()),
func.name = func_name;
func.arguments = std::make_shared<ASTExpressionList>();
func.children.push_back(func.arguments);
func.arguments->children = {
makeASTFunction("lag", val_arg->clone(), std::make_shared<ASTLiteral>(Int8(1)), val_arg->clone()),
makeASTFunction("cast",
makeASTFunction("date_diff",
std::make_shared<ASTLiteral>("millisecond"),
makeASTFunction("neighbor", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(-1)), time_arg->clone()),
makeASTFunction("lag", time_arg->clone(), std::make_shared<ASTLiteral>(Int8(1)), time_arg->clone()),
time_arg->clone()),
std::make_shared<ASTLiteral>("uint64")));
func = *translated_func;
std::make_shared<ASTLiteral>("uint64"))
};
}


return true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Interpreters/Streaming/SubstituteStreamingFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using SubstituteStreamingFunctionVisitor


void substitueFunction(ASTFunction & func, const String & new_name);
bool translateTimeWeightedFunc(ASTFunction & func, bool streaming);
bool translateTimeWeightedFunc(ASTFunction & func);


struct SubstituteFunctionsData
Expand Down

0 comments on commit c622510

Please sign in to comment.