Skip to content

Commit

Permalink
Merge pull request #420 from mk3008/419-experiment-typesafe-query-bui…
Browse files Browse the repository at this point in the history
…lder-cte

419 experiment typesafe query builder cte
  • Loading branch information
mk3008 authored May 27, 2024
2 parents 034c227 + 06362bd commit 9181547
Show file tree
Hide file tree
Showing 25 changed files with 1,007 additions and 139 deletions.
36 changes: 36 additions & 0 deletions src/Carbunql.TypeSafe/CTEDataSet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Carbunql.Analysis.Parser;
using Carbunql.Building;

namespace Carbunql.TypeSafe;

public class CTEDataSet(string name, SelectQuery query) : IDataSet
{
public string Name { get; set; } = name;

public Materialized Materialized { get; set; } = Materialized.Undefined;

public SelectQuery Query { get; init; } = query;

public List<string> Columns { get; init; } = query.GetColumnNames().ToList();

public SelectQuery BuildFromClause(SelectQuery query, string alias)
{
var cte = query.With(Query).As(Name);
cte.Materialized = Materialized;
query.From(new CTETable(cte).ToSelectable()).As(alias);
return query;
}

public SelectQuery BuildJoinClause(SelectQuery query, string join, string alias, string condition)
{
var cte = query.With(Query).As(Name);
cte.Materialized = Materialized;

var r = query.FromClause!.Join(cte, join).As(alias);
if (!string.IsNullOrEmpty(condition))
{
r.On(_ => ValueParser.Parse(condition));
}
return query;
}
}
10 changes: 10 additions & 0 deletions src/Carbunql.TypeSafe/CTEDefinition.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Carbunql.TypeSafe;

public struct CTEDefinition
{
public string Name { get; set; }

public Type RowType { get; set; }

public string Query { get; set; }
}
42 changes: 42 additions & 0 deletions src/Carbunql.TypeSafe/CTETable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Carbunql.Clauses;
using Carbunql.Tables;

namespace Carbunql.TypeSafe;

public class CTETable(CommonTable ct) : TableBase
{
public CommonTable CommonTable { get; init; } = ct;

/// <inheritdoc/>
public override IEnumerable<Token> GetTokens(Token? parent)
{
yield return new Token(this, parent, CommonTable.Alias);
}

/// <inheritdoc/>
public override IEnumerable<QueryParameter> GetParameters()
{
yield break;
}

/// <inheritdoc/>
public override IList<string> GetColumnNames() => CommonTable.GetColumnNames().ToList();

/// <inheritdoc/>
public override IEnumerable<PhysicalTable> GetPhysicalTables()
{
yield break;
}

/// <inheritdoc/>
public override IEnumerable<CommonTable> GetCommonTables()
{
yield break;
}

/// <inheritdoc/>
public override IEnumerable<SelectQuery> GetInternalQueries()
{
yield break;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal static string ToValue(this MemberExpression mem
// ex. Sql.Now, Sql.CurrentTimestamp
return CreateSqlCommand(mem);
}
if (mem.Expression is MemberExpression && typeof(ITableRowDefinition).IsAssignableFrom(tp))
if (mem.Expression is MemberExpression && typeof(IDataRow).IsAssignableFrom(tp))
{
//column
var table = ((MemberExpression)mem.Expression).Member.Name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private static string CreateSqlCommand(this MethodCallExpression mce
}
throw new ArgumentException("Invalid arguments count for RowNumber.");

case nameof(Sql.RowNumberOrderbyBy):
case nameof(Sql.RowNumberOrderBy):
if (mce.Arguments.First() is NewExpression argOrderbyBy)
{
var argOrderbyByText = string.Join(",", argOrderbyBy.Arguments.Select(x => mainConverter(x, addParameter)));
Expand Down
148 changes: 134 additions & 14 deletions src/Carbunql.TypeSafe/FluentSelectQuery.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using Carbunql.Analysis.Parser;
using Carbunql.Annotations;
using Carbunql.Building;
using Carbunql.Clauses;
using Carbunql.Definitions;
using Carbunql.Tables;
using Carbunql.TypeSafe.Extensions;
using Carbunql.Values;
using System.Linq.Expressions;
Expand Down Expand Up @@ -32,7 +36,7 @@ public FluentSelectQuery Select<T>(Expression<Func<T>> expression) where T : cla
return this;
}

public FluentSelectQuery InnerJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : ITableRowDefinition
public FluentSelectQuery InnerJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : IDataRow
{
#if DEBUG
var analyzed = ExpressionReader.Analyze(conditionExpression);
Expand All @@ -44,17 +48,15 @@ public FluentSelectQuery InnerJoin<T>(Expression<Func<T>> tableExpression, Expre
var compiledExpression = tableExpression.Compile();
var table = compiledExpression();


var prmManager = new ParameterManager(GetParameters(), AddParameter);

var condition = ToValue(conditionExpression.Body, prmManager.AddParaemter);

this.FromClause!.InnerJoin(table.TableDefinition).As(tableAlias).On(_ => ValueParser.Parse(condition));
table.DataSet.BuildJoinClause(this, "inner join", tableAlias, condition);

return this;
}

public FluentSelectQuery LeftJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : ITableRowDefinition
public FluentSelectQuery LeftJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : IDataRow
{
#if DEBUG
var analyzed = ExpressionReader.Analyze(conditionExpression);
Expand All @@ -66,17 +68,15 @@ public FluentSelectQuery LeftJoin<T>(Expression<Func<T>> tableExpression, Expres
var compiledExpression = tableExpression.Compile();
var table = compiledExpression();


var prmManager = new ParameterManager(GetParameters(), AddParameter);

var condition = ToValue(conditionExpression.Body, prmManager.AddParaemter);

this.FromClause!.LeftJoin(table.TableDefinition).As(tableAlias).On(_ => ValueParser.Parse(condition));
table.DataSet.BuildJoinClause(this, "left join", tableAlias, condition);

return this;
}

public FluentSelectQuery RightJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : ITableRowDefinition
public FluentSelectQuery RightJoin<T>(Expression<Func<T>> tableExpression, Expression<Func<bool>> conditionExpression) where T : IDataRow
{
#if DEBUG
var analyzed = ExpressionReader.Analyze(conditionExpression);
Expand All @@ -88,17 +88,15 @@ public FluentSelectQuery RightJoin<T>(Expression<Func<T>> tableExpression, Expre
var compiledExpression = tableExpression.Compile();
var table = compiledExpression();


var prmManager = new ParameterManager(GetParameters(), AddParameter);

var condition = ToValue(conditionExpression.Body, prmManager.AddParaemter);

this.FromClause!.RightJoin(table.TableDefinition).As(tableAlias).On(_ => ValueParser.Parse(condition));
table.DataSet.BuildJoinClause(this, "right join", tableAlias, condition);

return this;
}

public FluentSelectQuery CrossJoin<T>(Expression<Func<T>> tableExpression) where T : ITableRowDefinition
public FluentSelectQuery CrossJoin<T>(Expression<Func<T>> tableExpression) where T : IDataRow
{

var tableAlias = ((MemberExpression)tableExpression.Body).Member.Name;
Expand All @@ -107,7 +105,7 @@ public FluentSelectQuery CrossJoin<T>(Expression<Func<T>> tableExpression) where
var compiledExpression = tableExpression.Compile();
var table = compiledExpression();

this.FromClause!.CrossJoin(table.TableDefinition).As(tableAlias);
table.DataSet.BuildJoinClause(this, "cross join", tableAlias);

return this;
}
Expand Down Expand Up @@ -189,4 +187,126 @@ private string RemoveRootBracketOrDefault(string value)
}
return value;
}

/// <summary>
/// Compiles a FluentSelectQuery for a specified table row definition type.
/// </summary>
/// <typeparam name="T">The type of the table row definition.</typeparam>
/// <returns>A compiled FluentSelectQuery of type T.</returns>
/// <exception cref="InvalidProgramException">
/// Thrown when the select clause does not include all required columns of the table row definition type.
/// </exception>
public FluentSelectQuery<T> Compile<T>(bool force = false) where T : IDataRow, new()
{
var q = new FluentSelectQuery<T>();

// Copy clauses and parameters to the new query object
q.WithClause = WithClause;
q.SelectClause = SelectClause;
q.FromClause = FromClause;
q.WhereClause = WhereClause;
q.GroupClause = GroupClause;
q.HavingClause = HavingClause;
q.WindowClause = WindowClause;
q.OperatableQueries = OperatableQueries;
q.OrderClause = OrderClause;
q.LimitClause = LimitClause;
q.Parameters = Parameters;

var clause = TableDefinitionClauseFactory.Create<T>();

if (force)
{
CorrectSelectClause(q, clause);
}

TypeValidate<T>(q, clause);

if (SelectClause == null)
{
foreach (var item in clause.OfType<ColumnDefinition>())
{
q.Select(q.FromClause!.Root, item.ColumnName);
}
};

return q;
}

private static void CorrectSelectClause(SelectQuery q, TableDefinitionClause clause)
{
if (q.SelectClause == null)
{
// End without making corrections
return;
}

// Check if all properties of T are specified in the select clause
var aliases = q.GetSelectableItems().Select(x => x.Alias).ToHashSet();
var missingColumns = clause.OfType<ColumnDefinition>().Where(x => !aliases.Contains(x.ColumnName));

// Automatically add missing columns
foreach (var item in missingColumns)
{
q.Select($"cast(null as {item.ColumnType.ToText()})").As(item.ColumnName);
}
return;
}

private static void TypeValidate<T>(SelectQuery q, TableDefinitionClause clause)
{
if (q.SelectClause != null)
{
// Check if all properties of T are specified in the select clause
var aliases = q.GetSelectableItems().Select(x => x.Alias).ToHashSet();
var missingColumns = clause.ColumnNames.Where(item => !aliases.Contains(item)).ToList();

if (missingColumns.Any())
{
// If there are missing columns, include all of them in the error message
throw new InvalidProgramException($"The select query is not compatible with '{typeof(T).Name}'. The following columns are missing: {string.Join(", ", missingColumns)}");
}
return;
}
else if (q.FromClause != null)
{
var actual = q.FromClause.Root.Table.GetTableFullName();
var expect = clause.GetTableFullName();

if (q.FromClause.Root.Table is VirtualTable v && v.Query is SelectQuery vq)
{
TypeValidate<T>(vq, clause);
}
else if (q.FromClause.Root.Table is CTETable ct)
{
// Check if all properties of T are specified in the select clause
var aliases = ct.GetColumnNames().ToHashSet();
var missingColumns = clause.ColumnNames.Where(item => !aliases.Contains(item)).ToList();

if (missingColumns.Any())
{
// If there are missing columns, include all of them in the error message
throw new InvalidProgramException($"The select query is not compatible with '{typeof(T).Name}'. The following columns are missing: {string.Join(", ", missingColumns)}");
}
return;
}
else if (!actual.Equals(expect))
{
throw new InvalidProgramException($"The select query is not compatible with '{typeof(T).Name}'. Expect: {expect}, Actual: {actual}");
}
return;
}
else
{
throw new InvalidProgramException($"The select query is not compatible with '{typeof(T).Name}'. FromClause is null.");
}
}
}

public class FluentSelectQuery<T> : FluentSelectQuery where T : IDataRow, new()
{
public T ToTable()
{
throw new InvalidOperationException();
}
}
9 changes: 9 additions & 0 deletions src/Carbunql.TypeSafe/IDataRow.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Carbunql.Annotations;

namespace Carbunql.TypeSafe;

public interface IDataRow
{
[IgnoreMapping]
IDataSet DataSet { get; set; }
}
10 changes: 10 additions & 0 deletions src/Carbunql.TypeSafe/IDataSet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Carbunql.TypeSafe;

public interface IDataSet
{
public List<string> Columns { get; }

SelectQuery BuildFromClause(SelectQuery query, string alias);

SelectQuery BuildJoinClause(SelectQuery query, string join, string alias, string condition = "");
}
10 changes: 0 additions & 10 deletions src/Carbunql.TypeSafe/ITableRowDefinition.cs

This file was deleted.

30 changes: 30 additions & 0 deletions src/Carbunql.TypeSafe/PhysicalTableDataSet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Carbunql.Analysis.Parser;
using Carbunql.Building;
using Carbunql.Clauses;
using Carbunql.Definitions;
using Carbunql.Tables;

namespace Carbunql.TypeSafe;

public class PhysicalTableDataSet(ITable tb, IEnumerable<string> columns) : IDataSet
{
public SelectableTable Table { get; set; } = new PhysicalTable(tb).ToSelectable();

public List<string> Columns { get; init; } = columns.ToList();

public SelectQuery BuildFromClause(SelectQuery query, string alias)
{
query.From(Table).As(alias);
return query;
}

public SelectQuery BuildJoinClause(SelectQuery query, string join, string alias, string condition)
{
var r = query.FromClause!.Join(Table, join).As(alias);
if (!string.IsNullOrEmpty(condition))
{
r.On(_ => ValueParser.Parse(condition));
}
return query;
}
}
Loading

0 comments on commit 9181547

Please sign in to comment.