diff --git a/src/QueryBuilderByLinq/SelectQueryBuilder.cs b/src/QueryBuilderByLinq/SelectQueryBuilder.cs index 86eb234b..d6a3faee 100644 --- a/src/QueryBuilderByLinq/SelectQueryBuilder.cs +++ b/src/QueryBuilderByLinq/SelectQueryBuilder.cs @@ -24,7 +24,15 @@ public SelectQuery Build(MethodCallExpression expression) else if (expression.Arguments[0] is MethodCallExpression mce) { var sq = Build(mce); - sq = BuildNestedQuery(expression, sq); + + if (sq.FromClause == null) + { + sq = BuildRootOrNestedQuery(expression, sq); + } + else + { + sq = BuildNestedQuery(expression, sq); + } return sq; } @@ -200,78 +208,92 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, ConstantExpr return null; } - private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq, LambdaExpression? condition) + private SelectQuery BuildRootOrNestedQuery(MethodCallExpression expression, SelectQuery sq) { - var exp = (MethodCallExpression)expression.Arguments[0]; - sq.SetRootQuery(exp); + if (sq.FromClause != null) throw new Exception(); - var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); - if (condition != null) sq.Where(condition.ToValue(tables)); - return sq; - } - - private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq) - { var condition = expression.GetConditionLambdaFromArguments(); var method = expression.GetMethodLambdaFromArguments(); var select = expression.GetSelectLambdaFromArguments(); - var joinAlias = GetJoinParameter(select, condition); - if (sq.FromClause == null) + if (method == null) { - if (method == null) - { - return BuildRootQuery(expression, sq, condition); - } - if (expression.Arguments.Count == 3 && method != null && joinAlias != null) - { - var mc = method.GetBody()!; + var exp = (MethodCallExpression)expression.Arguments[0]; + sq.AddRootQuery(exp); - if (mc.Method.Name == nameof(Sql.InnerJoinTable) || mc.Method.Name == nameof(Sql.LeftJoinTable) || mc.Method.Name == nameof(Sql.CrossJoinTable)) - { - var exp = (MethodCallExpression)expression.Arguments[0]; + var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + if (condition != null) sq.Where(condition.ToValue(tables)); - // CTE - from - relation pattern + return sq; + } - sq.SetRootQuery(exp); + var mc = method.GetBody()!; - var text = sq.ToCommand().CommandText; + if (mc.Method.Name == nameof(Sql.FromTable)) + { + return sq; + } + else if (mc.Method.Name == nameof(Sql.InnerJoinTable) || mc.Method.Name == nameof(Sql.LeftJoinTable) || mc.Method.Name == nameof(Sql.CrossJoinTable)) + { + var joinParam = GetJoinParameter(select, condition); + if (joinParam == null) throw new Exception(); - var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList(); - ts.Add(joinAlias.Name!); - sq.AddJoinClause(method, ts, joinAlias); + var exp = (MethodCallExpression)expression.Arguments[0]; - if (condition != null) sq.Where(condition.ToValue(ts)); + // CTE - from - relation pattern - return sq; - } - else if (mc.Method.Name == nameof(Sql.FromTable)) - { - return sq; - } - } - } + sq.AddRootQuery(exp); - if (sq.FromClause == null) throw new NotSupportedException(); + var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + ts.Add(joinParam.Name!); + sq.AddJoinClause(method, ts, joinParam); + + if (condition != null) sq.Where(condition.ToValue(ts)); + + return sq; + } + else if (select != null && mc.Method.Name == nameof(Sql.CommonTable)) { - var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + // add CommonTable + var alias = select.Parameters.Last(); - if (method != null && joinAlias != null) + var body = method.GetBody()!; + if (Queryable.TryParse(body, out var cte)) { - tables.Add(joinAlias.Name!); - sq.AddJoinClause(method, tables, joinAlias); + sq.With(cte.ToQueryAsPostgres()).As(alias!.Name!); + return sq; } + } - if (condition != null) sq.Where(condition.ToValue(tables)); + throw new Exception(); + } - //refresh select clause - if (select != null) - { - sq.SelectClause = null; - sq.AddSelectClause(select, condition, tables); - } - return sq; + private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq) + { + if (sq.FromClause == null) throw new NotSupportedException(); + + var condition = expression.GetConditionLambdaFromArguments(); + var method = expression.GetMethodLambdaFromArguments(); + var select = expression.GetSelectLambdaFromArguments(); + var joinAlias = GetJoinParameter(select, condition); + + var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + + if (method != null && joinAlias != null) + { + tables.Add(joinAlias.Name!); + sq.AddJoinClause(method, tables, joinAlias); } + + if (condition != null) sq.Where(condition.ToValue(tables)); + + //refresh select clause + if (select != null) + { + sq.SelectClause = null; + sq.AddSelectClause(select, condition, tables); + } + return sq; } private string GetTableNameOrDefault(UnaryExpression ue) @@ -350,7 +372,59 @@ internal static IEnumerable GetArguments(this MethodCallExpression? expres return lambdas.Where(x => x.ReturnType == typeof(bool)).FirstOrDefault(); } - internal static SelectQuery SetRootQuery(this SelectQuery sq, MethodCallExpression expression) + internal static SelectQuery AddRootQuery(this SelectQuery sq, MethodCallExpression expression) + { + var select = expression.GetSelectLambdaFromArguments(); + var method = expression.GetMethodLambdaFromArguments(); + + if (select == null || select.Parameters.Count != 2) throw new NotSupportedException(); + + var table = select.Parameters[0]; + var alias = select.Parameters[1]; + var tableName = table.Name; + + if (string.IsNullOrEmpty(table?.Name)) throw new NotSupportedException(); + if (string.IsNullOrEmpty(alias?.Name)) throw new NotSupportedException(); + + var tables = new List { alias.Name! }; + + var v = (ValueCollection)select.Body.ToValue(tables); + + var columns = (ValueCollection)v[0]; + var columnnames = columns.Select(x => ((ColumnValue)x).Column).ToList(); + + if (method != null) + { + var t = method.GetBody().GetArgument(index: 0)!.Value!.ToString(); + + // select CTE + if (tableName != t) + { + tableName = t; + var w = sq.WithClause!.GetCommonTables().Where(x => x.Alias == tableName).First(); + columnnames = w.GetColumnNames().ToList(); + } + } + + if (string.IsNullOrEmpty(tableName)) throw new NotSupportedException(); + + var pt = new PhysicalTable() + { + ColumnNames = columnnames, + Table = tableName + }; + + sq.From(pt.ToSelectable()).As(alias.Name); + + foreach (var column in columnnames) + { + sq.Select(alias!.Name, column); + } + + return sq; + } + + internal static SelectQuery AddCommonTable(this SelectQuery sq, MethodCallExpression expression) { var select = expression.GetSelectLambdaFromArguments(); var method = expression.GetMethodLambdaFromArguments(); diff --git a/test/QueryBuilderByLinq.Test/CommonTableTest.cs b/test/QueryBuilderByLinq.Test/CommonTableTest.cs index a57ada89..15437902 100644 --- a/test/QueryBuilderByLinq.Test/CommonTableTest.cs +++ b/test/QueryBuilderByLinq.Test/CommonTableTest.cs @@ -138,5 +138,68 @@ cte1 AS b Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText()); } + [Fact] + public void CTEsTest3() + { + var sub_a1 = from a in FromTable() select new { a.a_id, a.text }; + var sub_a2 = from a in FromTable() select new { a.a_id, a.value }; + var sub_a3 = from a in FromTable() select a; + + var query = from cte1 in CommonTable(sub_a1) + from cte2 in CommonTable(sub_a2) + from cte3 in CommonTable(sub_a3) + from b in FromTable(nameof(cte1)) + from c in InnerJoinTable(nameof(cte2), x => b.a_id == x.a_id) + from d in InnerJoinTable(nameof(cte3), x => b.a_id == x.a_id) + where b.a_id == 1 + select new { b, c, d }; + + var sq = query.ToQueryAsPostgres(); + + Monitor.Log(sq); + + var sql = @" +WITH + cte1 AS ( + SELECT + a.a_id, + a.text + FROM + table_a AS a + ), + cte2 AS ( + SELECT + a.a_id, + a.value + FROM + table_a AS a + ), + cte3 AS ( + SELECT + a.a_id, + a.text, + a.value + FROM + table_a AS a + ) +SELECT + b.a_id, + b.text, + c.a_id, + c.value, + d.a_id, + d.text, + d.value +FROM + cte1 AS b + INNER JOIN cte2 AS c ON b.a_id = c.a_id + INNER JOIN cte3 AS d ON b.a_id = d.a_id +WHERE + b.a_id = 1"; + + Assert.Equal(117, sq.GetTokens().ToList().Count); + Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText()); + } + public record struct table_a(int a_id, string text, int value); } \ No newline at end of file