Skip to content

Commit

Permalink
#253 Added support for parsing when specifying three or more CTEs
Browse files Browse the repository at this point in the history
  • Loading branch information
mk3008 committed Oct 31, 2023
1 parent 6bdb7a8 commit 7ca1cb6
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 52 deletions.
178 changes: 126 additions & 52 deletions src/QueryBuilderByLinq/SelectQueryBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ public SelectQuery Build(MethodCallExpression expression)
else if (expression.Arguments[0] is MethodCallExpression mce)
{
var sq = Build(mce);
sq = BuildNestedQuery(expression, sq);

if (sq.FromClause == null)
{
sq = BuildRootOrNestedQuery(expression, sq);
}
else
{
sq = BuildNestedQuery(expression, sq);
}
return sq;
}

Expand Down Expand Up @@ -200,78 +208,92 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, ConstantExpr
return null;
}

private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq, LambdaExpression? condition)
private SelectQuery BuildRootOrNestedQuery(MethodCallExpression expression, SelectQuery sq)
{
var exp = (MethodCallExpression)expression.Arguments[0];
sq.SetRootQuery(exp);
if (sq.FromClause != null) throw new Exception();

var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();
if (condition != null) sq.Where(condition.ToValue(tables));
return sq;
}

private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq)
{
var condition = expression.GetConditionLambdaFromArguments();
var method = expression.GetMethodLambdaFromArguments();
var select = expression.GetSelectLambdaFromArguments();
var joinAlias = GetJoinParameter(select, condition);

if (sq.FromClause == null)
if (method == null)
{
if (method == null)
{
return BuildRootQuery(expression, sq, condition);
}
if (expression.Arguments.Count == 3 && method != null && joinAlias != null)
{
var mc = method.GetBody<MethodCallExpression>()!;
var exp = (MethodCallExpression)expression.Arguments[0];
sq.AddRootQuery(exp);

if (mc.Method.Name == nameof(Sql.InnerJoinTable) || mc.Method.Name == nameof(Sql.LeftJoinTable) || mc.Method.Name == nameof(Sql.CrossJoinTable))
{
var exp = (MethodCallExpression)expression.Arguments[0];
var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();
if (condition != null) sq.Where(condition.ToValue(tables));

// CTE - from - relation pattern
return sq;
}

sq.SetRootQuery(exp);
var mc = method.GetBody<MethodCallExpression>()!;

var text = sq.ToCommand().CommandText;
if (mc.Method.Name == nameof(Sql.FromTable))
{
return sq;
}
else if (mc.Method.Name == nameof(Sql.InnerJoinTable) || mc.Method.Name == nameof(Sql.LeftJoinTable) || mc.Method.Name == nameof(Sql.CrossJoinTable))
{
var joinParam = GetJoinParameter(select, condition);
if (joinParam == null) throw new Exception();

var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList();
ts.Add(joinAlias.Name!);
sq.AddJoinClause(method, ts, joinAlias);
var exp = (MethodCallExpression)expression.Arguments[0];

if (condition != null) sq.Where(condition.ToValue(ts));
// CTE - from - relation pattern

return sq;
}
else if (mc.Method.Name == nameof(Sql.FromTable))
{
return sq;
}
}
}
sq.AddRootQuery(exp);

if (sq.FromClause == null) throw new NotSupportedException();
var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList();
ts.Add(joinParam.Name!);
sq.AddJoinClause(method, ts, joinParam);

if (condition != null) sq.Where(condition.ToValue(ts));

return sq;
}
else if (select != null && mc.Method.Name == nameof(Sql.CommonTable))
{
var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();
// add CommonTable
var alias = select.Parameters.Last();

if (method != null && joinAlias != null)
var body = method.GetBody<MethodCallExpression>()!;
if (Queryable.TryParse(body, out var cte))
{
tables.Add(joinAlias.Name!);
sq.AddJoinClause(method, tables, joinAlias);
sq.With(cte.ToQueryAsPostgres()).As(alias!.Name!);
return sq;
}
}

if (condition != null) sq.Where(condition.ToValue(tables));
throw new Exception();
}

//refresh select clause
if (select != null)
{
sq.SelectClause = null;
sq.AddSelectClause(select, condition, tables);
}
return sq;
private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq)
{
if (sq.FromClause == null) throw new NotSupportedException();

var condition = expression.GetConditionLambdaFromArguments();
var method = expression.GetMethodLambdaFromArguments();
var select = expression.GetSelectLambdaFromArguments();
var joinAlias = GetJoinParameter(select, condition);

var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();

if (method != null && joinAlias != null)
{
tables.Add(joinAlias.Name!);
sq.AddJoinClause(method, tables, joinAlias);
}

if (condition != null) sq.Where(condition.ToValue(tables));

//refresh select clause
if (select != null)
{
sq.SelectClause = null;
sq.AddSelectClause(select, condition, tables);
}
return sq;
}

private string GetTableNameOrDefault(UnaryExpression ue)
Expand Down Expand Up @@ -350,7 +372,59 @@ internal static IEnumerable<T> GetArguments<T>(this MethodCallExpression? expres
return lambdas.Where(x => x.ReturnType == typeof(bool)).FirstOrDefault();
}

internal static SelectQuery SetRootQuery(this SelectQuery sq, MethodCallExpression expression)
internal static SelectQuery AddRootQuery(this SelectQuery sq, MethodCallExpression expression)
{
var select = expression.GetSelectLambdaFromArguments();
var method = expression.GetMethodLambdaFromArguments();

if (select == null || select.Parameters.Count != 2) throw new NotSupportedException();

var table = select.Parameters[0];
var alias = select.Parameters[1];
var tableName = table.Name;

if (string.IsNullOrEmpty(table?.Name)) throw new NotSupportedException();
if (string.IsNullOrEmpty(alias?.Name)) throw new NotSupportedException();

var tables = new List<string> { alias.Name! };

var v = (ValueCollection)select.Body.ToValue(tables);

var columns = (ValueCollection)v[0];
var columnnames = columns.Select(x => ((ColumnValue)x).Column).ToList();

if (method != null)
{
var t = method.GetBody<MethodCallExpression>().GetArgument<ConstantExpression>(index: 0)!.Value!.ToString();

// select CTE
if (tableName != t)
{
tableName = t;
var w = sq.WithClause!.GetCommonTables().Where(x => x.Alias == tableName).First();
columnnames = w.GetColumnNames().ToList();
}
}

if (string.IsNullOrEmpty(tableName)) throw new NotSupportedException();

var pt = new PhysicalTable()
{
ColumnNames = columnnames,
Table = tableName
};

sq.From(pt.ToSelectable()).As(alias.Name);

foreach (var column in columnnames)
{
sq.Select(alias!.Name, column);
}

return sq;
}

internal static SelectQuery AddCommonTable(this SelectQuery sq, MethodCallExpression expression)
{
var select = expression.GetSelectLambdaFromArguments();
var method = expression.GetMethodLambdaFromArguments();
Expand Down
63 changes: 63 additions & 0 deletions test/QueryBuilderByLinq.Test/CommonTableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,68 @@ cte1 AS b
Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText());
}

[Fact]
public void CTEsTest3()
{
var sub_a1 = from a in FromTable<table_a>() select new { a.a_id, a.text };
var sub_a2 = from a in FromTable<table_a>() select new { a.a_id, a.value };
var sub_a3 = from a in FromTable<table_a>() select a;

var query = from cte1 in CommonTable(sub_a1)
from cte2 in CommonTable(sub_a2)
from cte3 in CommonTable(sub_a3)
from b in FromTable<table_a>(nameof(cte1))
from c in InnerJoinTable<table_a>(nameof(cte2), x => b.a_id == x.a_id)
from d in InnerJoinTable<table_a>(nameof(cte3), x => b.a_id == x.a_id)
where b.a_id == 1
select new { b, c, d };

var sq = query.ToQueryAsPostgres();

Monitor.Log(sq);

var sql = @"
WITH
cte1 AS (
SELECT
a.a_id,
a.text
FROM
table_a AS a
),
cte2 AS (
SELECT
a.a_id,
a.value
FROM
table_a AS a
),
cte3 AS (
SELECT
a.a_id,
a.text,
a.value
FROM
table_a AS a
)
SELECT
b.a_id,
b.text,
c.a_id,
c.value,
d.a_id,
d.text,
d.value
FROM
cte1 AS b
INNER JOIN cte2 AS c ON b.a_id = c.a_id
INNER JOIN cte3 AS d ON b.a_id = d.a_id
WHERE
b.a_id = 1";

Assert.Equal(117, sq.GetTokens().ToList().Count);
Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText());
}

public record struct table_a(int a_id, string text, int value);
}

0 comments on commit 7ca1cb6

Please sign in to comment.