diff --git a/src/QueryBuilderByLinq/MemberExpressionExtension.cs b/src/QueryBuilderByLinq/MemberExpressionExtension.cs index bd43742f..c819c8d2 100644 --- a/src/QueryBuilderByLinq/MemberExpressionExtension.cs +++ b/src/QueryBuilderByLinq/MemberExpressionExtension.cs @@ -42,6 +42,7 @@ internal static ValueBase ToValue(this MemberExpression exp, List tables if (exp.Expression is MemberExpression mem) { var table = tables.Where(x => x == mem.Member.Name).FirstOrDefault(); + if (mem.Member.Name.StartsWith("<>h__TransparentIdentifier")) table = mem.Member.Name; if (!string.IsNullOrEmpty(table)) { diff --git a/src/QueryBuilderByLinq/SelectQueryBuilder.cs b/src/QueryBuilderByLinq/SelectQueryBuilder.cs index 5b2928e4..dabd2f27 100644 --- a/src/QueryBuilderByLinq/SelectQueryBuilder.cs +++ b/src/QueryBuilderByLinq/SelectQueryBuilder.cs @@ -233,33 +233,59 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, ConstantExpr private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq) { var where = GetWhereExpression(expression); - - if (sq.FromClause == null && (expression.Arguments.Count == 2)) - { - var exp = (MethodCallExpression)expression.Arguments[0]; - return BuildRootQuery(exp, sq, where); - } - var join = GetJoinExpression(expression); var select = GetSelectExpression(expression); var joinAlias = GetJoinAlias(select, where); - var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); - - if (join != null && joinAlias != null) + if (sq.FromClause == null) { - tables.Add(joinAlias.Name!); - sq.AddJoinClause(join, tables, joinAlias); + if (expression.Arguments.Count == 2) + { + var exp = (MethodCallExpression)expression.Arguments[0]; + sq = BuildRootQuery(exp, sq); + + var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + if (where != null) sq.Where(where.ToValue(ts)); + return sq; + } + if (expression.Arguments.Count == 3 && join != null && joinAlias != null) + { + var exp = (MethodCallExpression)expression.Arguments[0]; + sq = BuildRootQuery(exp, sq); + + var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList(); + ts.Add(joinAlias.Name!); + sq.AddJoinClause(join, ts, joinAlias); + + if (where != null) sq.Where(where.ToValue(ts)); + + return sq; + } } - if (where != null) sq.Where(where.ToValue(tables)); + if (sq.FromClause == null) throw new NotSupportedException(); + { + var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList(); - //refresh select clause - sq.SelectClause = null; - return sq.AddSelectClause(select, where, tables); + if (join != null && joinAlias != null) + { + tables.Add(joinAlias.Name!); + sq.AddJoinClause(join, tables, joinAlias); + } + + if (where != null) sq.Where(where.ToValue(tables)); + + //refresh select clause + if (select != null) + { + sq.SelectClause = null; + sq.AddSelectClause(select, where, tables); + } + return sq; + } } - private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq, LambdaExpression? where) + private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq) { var select = GetSelectExpression(expression); @@ -272,7 +298,7 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery if (string.IsNullOrEmpty(alias?.Name)) throw new NotSupportedException(); var tables = new List { alias.Name! }; - if (where != null) sq.Where(where.ToValue(tables)); + //if (where != null) sq.Where(where.ToValue(tables)); var v = (ValueCollection)select.Body.ToValue(tables); diff --git a/src/QueryBuilderByLinq/SelectQueryExtension.cs b/src/QueryBuilderByLinq/SelectQueryExtension.cs index d7b01b64..5dbccc77 100644 --- a/src/QueryBuilderByLinq/SelectQueryExtension.cs +++ b/src/QueryBuilderByLinq/SelectQueryExtension.cs @@ -4,7 +4,6 @@ using Carbunql.Tables; using Carbunql.Values; using System.Linq.Expressions; -using System.Xml.Linq; namespace QueryBuilderByLinq; @@ -91,6 +90,10 @@ public static SelectQuery AddJoinClause(this SelectQuery sq, LambdaExpression jo return sq; } + if (sq.FromClause == null) + { + throw new InvalidProgramException(); + } var f = sq.FromClause!; if (me.Method.Name == nameof(Sql.InnerJoinTable)) @@ -113,7 +116,20 @@ public static SelectQuery AddJoinClause(this SelectQuery sq, LambdaExpression jo } else if (me.GetJoinTableName(out var name)) { - f.InnerJoin(name).As(joinAlias.Name!).On((_) => condition); + var cte = sq.WithClause?.Where(x => x.Alias == name).FirstOrDefault(); + if (cte != null && !string.IsNullOrEmpty(cte.Alias)) + { + var t = new PhysicalTable() + { + Table = cte.Alias, + ColumnNames = cte.GetColumnNames().ToList() + }; + f.InnerJoin(t.ToSelectable()).As(joinAlias.Name!).On((_) => condition); + } + else + { + f.InnerJoin(name).As(joinAlias.Name!).On((_) => condition); + } } else { diff --git a/test/QueryBuilderByLinq.Test/CommonTableTest.cs b/test/QueryBuilderByLinq.Test/CommonTableTest.cs index cdec98e0..e3985b89 100644 --- a/test/QueryBuilderByLinq.Test/CommonTableTest.cs +++ b/test/QueryBuilderByLinq.Test/CommonTableTest.cs @@ -51,5 +51,44 @@ cte AS b Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText()); } + [Fact] + public void RelationTest() + { + var subq = from a in FromTable() select new { ID = a.a_id, Text = a.text }; + + var query = from cte in CommonTable(subq) + from b in FromTable(nameof(cte)) + from c in InnerJoinTable(nameof(cte), x => b.a_id == x.a_id) + where b.a_id == 1 + select new { b, c }; + + var sq = query.ToQueryAsPostgres(); + + Monitor.Log(sq); + + var sql = @" +WITH + cte AS ( + SELECT + a.a_id AS ID, + a.text AS Text + FROM + table_a AS a + ) +SELECT + b.ID, + b.Text, + c.ID, + c.Text +FROM + cte AS b + INNER JOIN cte AS c ON b.a_id = c.a_id +WHERE + b.a_id = 1"; + + Assert.Equal(59, sq.GetTokens().ToList().Count); + Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText()); + } + public record struct table_a(int a_id, string text, int value); } \ No newline at end of file