Skip to content

Commit

Permalink
Merge pull request #440 from mk3008/439-typesafe-query-builder-review…
Browse files Browse the repository at this point in the history
…-window-functions

Supports window functions
  • Loading branch information
mk3008 authored Jun 10, 2024
2 parents 426d146 + e30d063 commit 208e32b
Show file tree
Hide file tree
Showing 7 changed files with 568 additions and 152 deletions.
44 changes: 44 additions & 0 deletions src/Carbunql.TypeSafe/Extensions/ExpresionExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System.Linq.Expressions;

namespace Carbunql.TypeSafe.Extensions;

public static class ExpresionExtension
{
public static string ToValue(this Expression exp, Func<string, object?, string> addParameter)
{
if (exp is MemberExpression mem)
{
return mem.ToValue(ToValue, addParameter);
}
else if (exp is ConstantExpression ce)
{
return ce.ToValue(ToValue, addParameter);
}
else if (exp is NewExpression ne)
{
return ne.ToValue(ToValue, addParameter);
}
else if (exp is BinaryExpression be)
{
return be.ToValue(ToValue, addParameter);
}
else if (exp is UnaryExpression ue)
{
return ue.ToValue(ToValue, addParameter);
}
else if (exp is MethodCallExpression mce)
{
return mce.ToValue(ToValue, addParameter);
}
else if (exp is ConditionalExpression cnd)
{
return cnd.ToValue(ToValue, addParameter);
}
else if (exp is ParameterExpression prm)
{
return prm.ToValue(ToValue, addParameter);
}

throw new InvalidProgramException(exp.ToString());
}
}
169 changes: 116 additions & 53 deletions src/Carbunql.TypeSafe/Extensions/MethodCallExpressionExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Carbunql.Annotations;
using Carbunql.Building;
using Carbunql.Extensions;
using Carbunql.TypeSafe.Extensions;
using Carbunql.Values;
using System.Collections;
using System.Linq.Expressions;
Expand Down Expand Up @@ -133,52 +134,8 @@ private static string CreateSqlCommand(this MethodCallExpression mce
break;

case nameof(Sql.RowNumber):
if (mce.Arguments.Count == 0)
{
try
{
return DbmsConfiguration.GetRowNumberCommandLogic();
}
catch (Exception ex)
{
throw new InvalidOperationException("Failed to get RowNumber command logic.", ex);
}
}
if (mce.Arguments.Count == 2)
{
try
{
var argList = mce.Arguments.ToList();
if (argList[0] is NewExpression arg1st && argList[1] is NewExpression arg2nd)
{
var arg1stText = string.Join(",", arg1st.Arguments.Select(x => mainConverter(x, addParameter)));
var arg2ndText = string.Join(",", arg2nd.Arguments.Select(x => mainConverter(x, addParameter)));

return DbmsConfiguration.GetRowNumberPartitionByOrderByCommandLogic(arg1stText, arg2ndText);
}
}
catch (Exception ex)
{
throw new InvalidOperationException("Failed to process RowNumber with parameters.", ex);
}
}
throw new ArgumentException("Invalid arguments count for RowNumber.");

case nameof(Sql.RowNumberOrderBy):
if (mce.Arguments.First() is NewExpression argOrderbyBy)
{
var argOrderbyByText = string.Join(",", argOrderbyBy.Arguments.Select(x => mainConverter(x, addParameter)));
return DbmsConfiguration.GetRowNumberOrderByCommandLogic(argOrderbyByText);
}
break;
return Aggregate(mce, mainConverter, addParameter, "row_number");

case nameof(Sql.RowNumberPartitionBy):
if (mce.Arguments.First() is NewExpression argPartitionBy)
{
var argPartitionByText = string.Join(",", argPartitionBy.Arguments.Select(x => mainConverter(x, addParameter)));
return DbmsConfiguration.GetRowNumberPartitionByCommandLogic(argPartitionByText);
}
break;

case nameof(Sql.Exists):
case nameof(Sql.NotExists):
Expand All @@ -205,35 +162,141 @@ private static string CreateSqlCommand(this MethodCallExpression mce
}

private static string Aggregate(MethodCallExpression mce
, Func<Expression, Func<string, object?, string>, string> mainConverter
, Func<string, object?, string> addParameter
, string aggregateFunction)
, Func<Expression, Func<string, object?, string>, string> mainConverter
, Func<string, object?, string> addParameter
, string aggregateFunction)
{
#if DEBUG
// Analyze the expression tree for debugging purposes
var analyze = ExpressionReader.Analyze(mce);
#endif

// Extract the main aggregate function
string value;
if (aggregateFunction.IsEqualNoCase("count"))
{
return $"{aggregateFunction}(*)";
value = "count(*)";
}
else if (aggregateFunction.IsEqualNoCase("row_number"))
{
value = "row_number()";
}
else
{
value = ExtractFunction(mce, mainConverter, addParameter, aggregateFunction, mce.Arguments[0]);
}

// Determine the argument indices for partition and order
int partitionArgumentIndex = (aggregateFunction.IsEqualNoCase("count") || aggregateFunction.IsEqualNoCase("row_number")) ? 0 : 1;
int orderArgumentIndex = partitionArgumentIndex + 1;

// Extract the partition and order clauses
// The arguments can be:
// - Main function argument, partition, order
// - Partition, order
// There are no functions that only have a partition or only have an order.
string partitionby = mce.Arguments.Count <= partitionArgumentIndex ? string.Empty : ExtractPartition(mce, mainConverter, addParameter, mce.Arguments[partitionArgumentIndex]);
string orderby = mce.Arguments.Count <= orderArgumentIndex ? string.Empty : ExtractOrder(mce, mainConverter, addParameter, mce.Arguments[orderArgumentIndex]);

// Construct the final SQL function string with the over clause
if (!string.IsNullOrEmpty(partitionby) && !string.IsNullOrEmpty(orderby))
{
value += $" over({partitionby} {orderby})";
}
else if (!string.IsNullOrEmpty(partitionby))
{
value += $" over({partitionby})";
}
else if (!string.IsNullOrEmpty(orderby))
{
value += $" over({orderby})";
}

var ue = (UnaryExpression)mce.Arguments[0];
return value;
}

private static string ExtractFunction(MethodCallExpression mce
, Func<Expression, Func<string, object?, string>, string> mainConverter
, Func<string, object?, string> addParameter
, string functionName
, Expression? argument)
{
if (argument == null) throw new NotSupportedException();

var ue = (UnaryExpression)argument;
var expression = (LambdaExpression)ue.Operand;

if (expression.Body is BinaryExpression be)
{
var value = be.ToValue(mainConverter, addParameter);
return $"{aggregateFunction}({value})";
return $"{functionName}({value})";
}
if (expression.Body is MemberExpression me)
{
var value = me.ToValue(mainConverter, addParameter);
return $"{aggregateFunction}({value})";
return $"{functionName}({value})";
}

throw new NotSupportedException();
}

private static string ExtractPartition(MethodCallExpression mce
, Func<Expression, Func<string, object?, string>, string> mainConverter
, Func<string, object?, string> addParameter
, Expression argument)
{
var functionName = "partition by";

var ue = (UnaryExpression)argument;
var expression = (LambdaExpression)ue.Operand;

if (expression.Body is ConstantExpression) return string.Empty;

if (expression.Body is NewExpression ne && ne.Members != null)
{
var value = string.Join(",", ne.Arguments.Select(x => x.ToValue(addParameter)));
return $"{functionName} {value}";
}

throw new NotSupportedException();
}

private static string ExtractOrder(MethodCallExpression mce
, Func<Expression, Func<string, object?, string>, string> mainConverter
, Func<string, object?, string> addParameter
, Expression argument)
{
var functionName = "order by";

var ue = (UnaryExpression)argument;
var expression = (LambdaExpression)ue.Operand;

if (expression.Body is ConstantExpression) return string.Empty;

if (expression.Body is NewExpression ne && ne.Members != null)
{
var cnt = ne.Members.Count();
var args = new List<string>() { Capacity = cnt };

// If an alias is specified, it is determined to be in "descending order".
for (var i = 0; i < cnt; i++)
{
var alias = ne.Members[i].Name;
var val = ne.Arguments[i].ToValue(addParameter);
if (ValueParser.Parse(val).GetDefaultName() == alias)
{
args.Add(val);
}
else
{
args.Add($"{val} desc");
}
}
var value = string.Join(",", args);
return $"{functionName} {value}";
}

throw new NotSupportedException();
}

private static string ToExistsClause(MethodCallExpression mce)
Expand All @@ -247,7 +310,7 @@ private static string ToExistsClause(MethodCallExpression mce)
var fsql = new FluentSelectQuery();
var (f, x) = fsql.From(clause).As(expression.Parameters[0].Name!);
var prmManager = new ParameterManager(fsql.GetParameters(), fsql.AddParameter);
var value = fsql.ToValue(expression.Body, prmManager.AddParameter);
var value = expression.Body.ToValue(prmManager.AddParameter);
fsql.Where(value);

if (mce.Method.Name == nameof(Sql.Exists))
Expand Down
Loading

0 comments on commit 208e32b

Please sign in to comment.