Skip to content

Commit

Permalink
#253 Support for table join queries using CTE
Browse files Browse the repository at this point in the history
  • Loading branch information
mk3008 committed Oct 24, 2023
1 parent 55c51b8 commit 3535083
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/QueryBuilderByLinq/MemberExpressionExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ internal static ValueBase ToValue(this MemberExpression exp, List<string> tables
if (exp.Expression is MemberExpression mem)
{
var table = tables.Where(x => x == mem.Member.Name).FirstOrDefault();
if (mem.Member.Name.StartsWith("<>h__TransparentIdentifier")) table = mem.Member.Name;

if (!string.IsNullOrEmpty(table))
{
Expand Down
62 changes: 44 additions & 18 deletions src/QueryBuilderByLinq/SelectQueryBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,33 +233,59 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, ConstantExpr
private SelectQuery BuildNestedQuery(MethodCallExpression expression, SelectQuery sq)
{
var where = GetWhereExpression(expression);

if (sq.FromClause == null && (expression.Arguments.Count == 2))
{
var exp = (MethodCallExpression)expression.Arguments[0];
return BuildRootQuery(exp, sq, where);
}

var join = GetJoinExpression(expression);
var select = GetSelectExpression(expression);
var joinAlias = GetJoinAlias(select, where);

var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();

if (join != null && joinAlias != null)
if (sq.FromClause == null)
{
tables.Add(joinAlias.Name!);
sq.AddJoinClause(join, tables, joinAlias);
if (expression.Arguments.Count == 2)
{
var exp = (MethodCallExpression)expression.Arguments[0];
sq = BuildRootQuery(exp, sq);

var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList();
if (where != null) sq.Where(where.ToValue(ts));
return sq;
}
if (expression.Arguments.Count == 3 && join != null && joinAlias != null)
{
var exp = (MethodCallExpression)expression.Arguments[0];
sq = BuildRootQuery(exp, sq);

var ts = sq.GetSelectableTables().Select(x => x.Alias).ToList();
ts.Add(joinAlias.Name!);
sq.AddJoinClause(join, ts, joinAlias);

if (where != null) sq.Where(where.ToValue(ts));

return sq;
}
}

if (where != null) sq.Where(where.ToValue(tables));
if (sq.FromClause == null) throw new NotSupportedException();
{
var tables = sq.GetSelectableTables().Select(x => x.Alias).ToList();

//refresh select clause
sq.SelectClause = null;
return sq.AddSelectClause(select, where, tables);
if (join != null && joinAlias != null)
{
tables.Add(joinAlias.Name!);
sq.AddJoinClause(join, tables, joinAlias);
}

if (where != null) sq.Where(where.ToValue(tables));

//refresh select clause
if (select != null)
{
sq.SelectClause = null;
sq.AddSelectClause(select, where, tables);
}
return sq;
}
}

private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq, LambdaExpression? where)
private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery sq)
{
var select = GetSelectExpression(expression);

Expand All @@ -272,7 +298,7 @@ private SelectQuery BuildRootQuery(MethodCallExpression expression, SelectQuery
if (string.IsNullOrEmpty(alias?.Name)) throw new NotSupportedException();

var tables = new List<string> { alias.Name! };
if (where != null) sq.Where(where.ToValue(tables));
//if (where != null) sq.Where(where.ToValue(tables));

var v = (ValueCollection)select.Body.ToValue(tables);

Expand Down
20 changes: 18 additions & 2 deletions src/QueryBuilderByLinq/SelectQueryExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using Carbunql.Tables;
using Carbunql.Values;
using System.Linq.Expressions;
using System.Xml.Linq;

namespace QueryBuilderByLinq;

Expand Down Expand Up @@ -91,6 +90,10 @@ public static SelectQuery AddJoinClause(this SelectQuery sq, LambdaExpression jo
return sq;
}

if (sq.FromClause == null)
{
throw new InvalidProgramException();
}
var f = sq.FromClause!;

if (me.Method.Name == nameof(Sql.InnerJoinTable))
Expand All @@ -113,7 +116,20 @@ public static SelectQuery AddJoinClause(this SelectQuery sq, LambdaExpression jo
}
else if (me.GetJoinTableName(out var name))
{
f.InnerJoin(name).As(joinAlias.Name!).On((_) => condition);
var cte = sq.WithClause?.Where(x => x.Alias == name).FirstOrDefault();
if (cte != null && !string.IsNullOrEmpty(cte.Alias))
{
var t = new PhysicalTable()
{
Table = cte.Alias,
ColumnNames = cte.GetColumnNames().ToList()
};
f.InnerJoin(t.ToSelectable()).As(joinAlias.Name!).On((_) => condition);
}
else
{
f.InnerJoin(name).As(joinAlias.Name!).On((_) => condition);
}
}
else
{
Expand Down
39 changes: 39 additions & 0 deletions test/QueryBuilderByLinq.Test/CommonTableTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,44 @@ cte AS b
Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText());
}

[Fact]
public void RelationTest()
{
var subq = from a in FromTable<table_a>() select new { ID = a.a_id, Text = a.text };

var query = from cte in CommonTable(subq)
from b in FromTable<table_a>(nameof(cte))
from c in InnerJoinTable<table_a>(nameof(cte), x => b.a_id == x.a_id)
where b.a_id == 1
select new { b, c };

var sq = query.ToQueryAsPostgres();

Monitor.Log(sq);

var sql = @"
WITH
cte AS (
SELECT
a.a_id AS ID,
a.text AS Text
FROM
table_a AS a
)
SELECT
b.ID,
b.Text,
c.ID,
c.Text
FROM
cte AS b
INNER JOIN cte AS c ON b.a_id = c.a_id
WHERE
b.a_id = 1";

Assert.Equal(59, sq.GetTokens().ToList().Count);
Assert.Equal(sql.ToValidateText(), sq.ToText().ToValidateText());
}

public record struct table_a(int a_id, string text, int value);
}

0 comments on commit 3535083

Please sign in to comment.