From f74588e926484dbb4e697fd1cd6a664c0c0e66cc Mon Sep 17 00:00:00 2001 From: warny Date: Sun, 4 Jan 2026 18:09:30 +0100 Subject: [PATCH 1/2] Simplify select parser terminator wiring --- Utils.Data/Properties/AssemblyInfo.cs | 3 + Utils.Data/Sql/ClauseStartKeywordRegistry.cs | 98 ++ Utils.Data/Sql/DeleteStatementParser.cs | 80 ++ Utils.Data/Sql/ExpressionListReader.cs | 57 + Utils.Data/Sql/ExpressionReader.cs | 209 ++++ Utils.Data/Sql/FieldListReader.cs | 113 ++ Utils.Data/Sql/InsertStatementParser.cs | 87 ++ Utils.Data/Sql/PredicateReader.cs | 90 ++ Utils.Data/Sql/SelectStatementParser.cs | 117 ++ Utils.Data/Sql/SqlParser.cs | 445 +++----- Utils.Data/Sql/SqlQueryAnalyzer.cs | 253 +++++ Utils.Data/Sql/SqlStatementPartReaders.cs | 1000 +++++++++++++++++ Utils.Data/Sql/SqlStatementParts.cs | 252 +++++ Utils.Data/Sql/TableListReader.cs | 170 +++ Utils.Data/Sql/UpdateStatementParser.cs | 70 ++ UtilsTest/Data/ExpressionReaderTests.cs | 79 ++ .../Data/FieldAndPredicateReaderTests.cs | 74 ++ UtilsTest/Data/SqlQueryAnalyzerTests.cs | 18 + UtilsTest/Data/SqlStatementPartReaderTests.cs | 114 ++ UtilsTest/Data/SqlStatementPartTests.cs | 91 ++ UtilsTest/Data/TableListReaderTests.cs | 72 ++ 21 files changed, 3182 insertions(+), 310 deletions(-) create mode 100644 Utils.Data/Properties/AssemblyInfo.cs create mode 100644 Utils.Data/Sql/ClauseStartKeywordRegistry.cs create mode 100644 Utils.Data/Sql/DeleteStatementParser.cs create mode 100644 Utils.Data/Sql/ExpressionListReader.cs create mode 100644 Utils.Data/Sql/ExpressionReader.cs create mode 100644 Utils.Data/Sql/FieldListReader.cs create mode 100644 Utils.Data/Sql/InsertStatementParser.cs create mode 100644 Utils.Data/Sql/PredicateReader.cs create mode 100644 Utils.Data/Sql/SelectStatementParser.cs create mode 100644 Utils.Data/Sql/SqlStatementPartReaders.cs create mode 100644 Utils.Data/Sql/SqlStatementParts.cs create mode 100644 Utils.Data/Sql/TableListReader.cs create mode 100644 Utils.Data/Sql/UpdateStatementParser.cs create mode 100644 UtilsTest/Data/ExpressionReaderTests.cs create mode 100644 UtilsTest/Data/FieldAndPredicateReaderTests.cs create mode 100644 UtilsTest/Data/SqlStatementPartReaderTests.cs create mode 100644 UtilsTest/Data/SqlStatementPartTests.cs create mode 100644 UtilsTest/Data/TableListReaderTests.cs diff --git a/Utils.Data/Properties/AssemblyInfo.cs b/Utils.Data/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..9c44e60 --- /dev/null +++ b/Utils.Data/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("UtilsTest")] diff --git a/Utils.Data/Sql/ClauseStartKeywordRegistry.cs b/Utils.Data/Sql/ClauseStartKeywordRegistry.cs new file mode 100644 index 0000000..843366f --- /dev/null +++ b/Utils.Data/Sql/ClauseStartKeywordRegistry.cs @@ -0,0 +1,98 @@ +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +/// +/// Provides keyword metadata for clause boundaries so can detect clause transitions +/// without hardcoded keyword checks. +/// +internal static class ClauseStartKeywordRegistry +{ + /// + /// Gets the default mapping between clause identifiers and the keyword sequences that start the clause. + /// + public static IReadOnlyDictionary>> KnownClauseKeywords { get; } = + BuildKnownClauseKeywords(); + + /// + /// Attempts to retrieve the keyword sequences that mark the beginning of the specified clause. + /// + /// The clause identifier. + /// The keyword sequences associated with the clause. + /// true when the clause metadata is available; otherwise, false. + public static bool TryGetClauseKeywords( + ClauseStart clauseStart, + out IReadOnlyList> keywordSequences) + { + return KnownClauseKeywords.TryGetValue(clauseStart, out keywordSequences!); + } + + private static IReadOnlyDictionary>> BuildKnownClauseKeywords() + { + var definitions = new List + { + SelectPartReader.KeywordDefinition, + FromPartReader.KeywordDefinition, + IntoPartReader.KeywordDefinition, + WherePartReader.KeywordDefinition, + GroupByPartReader.KeywordDefinition, + HavingPartReader.KeywordDefinition, + OrderByPartReader.KeywordDefinition, + LimitPartReader.KeywordDefinition, + OffsetPartReader.KeywordDefinition, + ValuesPartReader.KeywordDefinition, + OutputPartReader.KeywordDefinition, + ReturningPartReader.KeywordDefinition, + SetOperatorPartReader.KeywordDefinition, + ClauseKeywordDefinition.FromKeywords(ClauseStart.Using, new[] { "USING" }), + }; + + var map = new Dictionary>>(); + foreach (var definition in definitions) + { + map[definition.ClauseKeyword] = definition.KeywordSequences; + } + + return map; + } +} + +/// +/// Represents the keyword sequences that start a specific SQL clause. +/// +internal sealed class ClauseKeywordDefinition +{ + /// + /// Initializes a new instance of the class. + /// + /// The clause identifier. + /// The keyword sequences that open the clause. + public ClauseKeywordDefinition( + ClauseStart clauseKeyword, + IReadOnlyList> keywordSequences) + { + ClauseKeyword = clauseKeyword; + KeywordSequences = keywordSequences; + } + + /// + /// Gets the clause identifier associated with the keywords. + /// + public ClauseStart ClauseKeyword { get; } + + /// + /// Gets the keyword sequences that mark the start of the clause. + /// + public IReadOnlyList> KeywordSequences { get; } + + /// + /// Creates a clause keyword definition from the specified keyword sequences. + /// + /// The clause identifier. + /// The keyword sequences that start the clause. + /// The created . + public static ClauseKeywordDefinition FromKeywords(ClauseStart clauseKeyword, params string[][] keywordSequences) + { + return new ClauseKeywordDefinition(clauseKeyword, keywordSequences); + } +} diff --git a/Utils.Data/Sql/DeleteStatementParser.cs b/Utils.Data/Sql/DeleteStatementParser.cs new file mode 100644 index 0000000..6a0de8e --- /dev/null +++ b/Utils.Data/Sql/DeleteStatementParser.cs @@ -0,0 +1,80 @@ +using System; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Parses DELETE statements relying on a shared instance. +/// +internal sealed class DeleteStatementParser +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying parser providing token access. + public DeleteStatementParser(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Parses a DELETE statement. + /// + /// The WITH clause bound to the statement, if present. + /// The parsed . + public SqlDeleteStatement Parse(WithClause? withClause) + { + parser.ExpectKeyword("DELETE"); + var deleteTargetReader = new DeletePartReader(parser); + var fromReader = new FromPartReader(parser); + var whereReader = new WherePartReader(parser); + var outputReader = new OutputPartReader(parser); + var returningReader = new ReturningPartReader(parser); + + var targetSegment = deleteTargetReader.TryReadDeleteTarget(); + + var fromSegment = fromReader.TryReadFromPart( + outputReader.ClauseKeyword, + ClauseStart.Using, + whereReader.ClauseKeyword, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + if (fromSegment == null) + { + throw new SqlParseException("Expected FROM clause in DELETE statement."); + } + + SqlSegment? usingSegment = null; + SqlSegment? whereSegment = null; + SqlSegment? outputSegment = null; + SqlSegment? returningSegment = null; + + if (parser.TryConsumeKeyword("OUTPUT")) + { + outputSegment = outputReader.ReadOutputPart( + "Output", + ClauseStart.Using, + whereReader.ClauseKeyword, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + } + + if (parser.TryConsumeKeyword("USING")) + { + var usingTokens = parser.ReadSectionTokens(ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); + usingSegment = parser.BuildSegment("Using", usingTokens); + } + + whereSegment = whereReader.TryReadWherePart(returningReader.ClauseKeyword, ClauseStart.StatementEnd); + + if (parser.TryConsumeKeyword("RETURNING")) + { + returningSegment = returningReader.ReadReturningPart(ClauseStart.StatementEnd); + } + + return new SqlDeleteStatement(targetSegment, fromSegment, usingSegment, whereSegment, outputSegment, returningSegment, withClause); + } +} diff --git a/Utils.Data/Sql/ExpressionListReader.cs b/Utils.Data/Sql/ExpressionListReader.cs new file mode 100644 index 0000000..ed2a92f --- /dev/null +++ b/Utils.Data/Sql/ExpressionListReader.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Reads comma-separated SQL expressions using an underlying . +/// +internal sealed class ExpressionListReader +{ + private readonly SqlParser parser; + private readonly ExpressionReader expressionReader; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public ExpressionListReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + expressionReader = new ExpressionReader(this.parser); + } + + /// + /// Reads a sequence of expressions separated by commas until a clause boundary is reached. + /// + /// Prefix used to name expression segments. + /// Indicates whether expressions may declare aliases. + /// Clause boundaries that stop the list. + /// The parsed expressions with their aliases. + public IReadOnlyList ReadExpressions(string segmentNamePrefix, bool allowAliases, params ClauseStart[] clauseTerminators) + { + if (string.IsNullOrWhiteSpace(segmentNamePrefix)) + { + throw new ArgumentException("Segment name prefix cannot be null or whitespace.", nameof(segmentNamePrefix)); + } + + var results = new List(); + int index = 1; + while (true) + { + results.Add(expressionReader.ReadExpression($"{segmentNamePrefix}{index}", allowAliases, clauseTerminators)); + index++; + + if (parser.IsAtEnd || parser.Peek().Text != ",") + { + break; + } + + parser.Read(); + } + + return results; + } +} diff --git a/Utils.Data/Sql/ExpressionReader.cs b/Utils.Data/Sql/ExpressionReader.cs new file mode 100644 index 0000000..ef9545f --- /dev/null +++ b/Utils.Data/Sql/ExpressionReader.cs @@ -0,0 +1,209 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Reads SQL expressions that may optionally include aliases from a shared context. +/// +internal sealed class ExpressionReader +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The parser providing token access. + public ExpressionReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads a single expression up to the next comma or specified clause boundary. + /// + /// The segment name assigned to the parsed expression. + /// Indicates whether an alias is allowed after the expression. + /// Clause boundaries that stop the expression. + /// The parsed expression along with its optional alias. + /// Thrown when no expression could be read or an alias is malformed. + public ExpressionReadResult ReadExpression(string segmentName, bool allowAlias, params ClauseStart[] clauseTerminators) + { + if (string.IsNullOrWhiteSpace(segmentName)) + { + throw new ArgumentException("Segment name cannot be null or whitespace.", nameof(segmentName)); + } + + var tokens = new List(); + int parenthesisDepth = 0; + int caseDepth = 0; + while (!parser.IsAtEnd) + { + var current = parser.Peek(); + if (IsTopLevel(parenthesisDepth, caseDepth)) + { + if (current.Text == ",") + { + break; + } + + if (clauseTerminators.Length > 0 && parser.IsClauseStart(clauseTerminators)) + { + break; + } + } + + tokens.Add(parser.Read()); + UpdateDepths(current, ref parenthesisDepth, ref caseDepth); + } + + if (tokens.Count == 0) + { + throw new SqlParseException("Expected expression but none was found."); + } + + var fullTokens = new List(tokens); + string? alias = null; + if (allowAlias) + { + alias = ExtractAlias(tokens); + } + + if (tokens.Count == 0) + { + throw new SqlParseException("Expression cannot be reduced to an alias only."); + } + + return new ExpressionReadResult(parser.BuildSegment(segmentName, tokens), alias, fullTokens); + } + + /// + /// Extracts an alias from the end of the provided token list when present. + /// + /// Tokens composing the expression and potential alias. + /// The alias text when found; otherwise, null. + /// Thrown when an alias indicator is not followed by a valid identifier. + private static string? ExtractAlias(List tokens) + { + if (tokens.Count >= 2) + { + var last = tokens[^1]; + var beforeLast = tokens[^2]; + if (beforeLast.Normalized == "AS") + { + if (!last.IsIdentifier) + { + throw new SqlParseException($"Expected identifier after AS but found '{last.Text}'."); + } + + tokens.RemoveRange(tokens.Count - 2, 2); + return last.Text; + } + } + + if (tokens.Count >= 2) + { + var last = tokens[^1]; + var beforeLast = tokens[^2]; + if (last.IsIdentifier && !last.IsKeyword && !IsAliasSeparator(beforeLast)) + { + tokens.RemoveAt(tokens.Count - 1); + return last.Text; + } + } + + return null; + } + + /// + /// Updates the current parenthesis depth based on the provided token. + /// + /// The token being processed. + /// The tracked depth value. + private static void UpdateDepths(SqlToken token, ref int parenthesisDepth, ref int caseDepth) + { + if (token.Text == "(") + { + parenthesisDepth++; + } + else if (token.Text == ")" && parenthesisDepth > 0) + { + parenthesisDepth--; + } + + if (token.Normalized == "CASE") + { + caseDepth++; + } + else if (token.Normalized == "END" && caseDepth > 0) + { + caseDepth--; + } + } + + /// + /// Determines whether parsing is currently at the top-level expression scope. + /// + /// The tracked parenthesis depth. + /// The tracked CASE expression depth. + /// true when both depths indicate the outermost scope. + private static bool IsTopLevel(int parenthesisDepth, int caseDepth) + { + return parenthesisDepth == 0 && caseDepth == 0; + } + + /// + /// Determines whether the provided token prevents alias extraction because it is attached to an identifier. + /// + /// The token immediately preceding an alias candidate. + /// true when the token is considered part of the identifier chain. + private static bool IsAliasSeparator(SqlToken token) + { + return token.Text is "." or "::"; + } +} + +/// +/// Represents an expression read from SQL along with its optional alias. +/// +internal sealed class ExpressionReadResult +{ + /// + /// Initializes a new instance of the class. + /// + /// The parsed expression segment. + /// The optional alias associated with the expression. + public ExpressionReadResult(SqlSegment expression, string? alias, IReadOnlyList tokens) + { + Expression = expression ?? throw new ArgumentNullException(nameof(expression)); + Alias = alias; + Tokens = tokens ?? throw new ArgumentNullException(nameof(tokens)); + } + + /// + /// Gets the parsed expression segment. + /// + public SqlSegment Expression { get; } + + /// + /// Gets the optional alias associated with the expression. + /// + public string? Alias { get; } + + /// + /// Gets the tokens that compose the expression including any alias tokens. + /// + public IReadOnlyList Tokens { get; } + + /// + /// Builds the SQL snippet represented by the expression and its alias when present. + /// + /// The SQL text for the expression. + public string ToSql() + { + string expressionText = Expression.ToSql(); + return string.IsNullOrWhiteSpace(Alias) ? expressionText : $"{expressionText} {Alias}"; + } +} diff --git a/Utils.Data/Sql/FieldListReader.cs b/Utils.Data/Sql/FieldListReader.cs new file mode 100644 index 0000000..a3bfb57 --- /dev/null +++ b/Utils.Data/Sql/FieldListReader.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Reads simple comma-separated field names from a shared context. +/// +internal sealed class FieldListReader +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public FieldListReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads a sequence of fields separated by commas until a clause boundary is reached. + /// + /// Prefix used to name field segments. + /// Clause boundaries that stop the list. + /// The parsed field segments. + /// Thrown when is null or whitespace. + public IReadOnlyList ReadFields(string segmentNamePrefix, params ClauseStart[] clauseTerminators) + { + if (string.IsNullOrWhiteSpace(segmentNamePrefix)) + { + throw new ArgumentException("Segment name prefix cannot be null or whitespace.", nameof(segmentNamePrefix)); + } + + var results = new List(); + int index = 1; + while (true) + { + results.Add(ReadField($"{segmentNamePrefix}{index}", clauseTerminators)); + index++; + + if (parser.IsAtEnd || parser.IsClauseStart(clauseTerminators)) + { + break; + } + + var next = parser.Peek(); + if (next.Text == ",") + { + parser.Read(); + continue; + } + + if (next.Text == ")") + { + break; + } + + throw new SqlParseException($"Unexpected token '{next.Text}' while reading field list."); + } + + return results; + } + + /// + /// Reads a single field, stopping before the next comma, closing parenthesis, or clause boundary. + /// + /// The segment name assigned to the parsed field. + /// Clause boundaries that stop the field. + /// The parsed field segment. + /// Thrown when no field tokens are found. + private SqlSegment ReadField(string segmentName, params ClauseStart[] clauseTerminators) + { + var tokens = new List(); + int depth = 0; + while (!parser.IsAtEnd) + { + var current = parser.Peek(); + if (depth == 0) + { + if (current.Text == "," || current.Text == ")") + { + break; + } + + if (clauseTerminators.Length > 0 && parser.IsClauseStart(clauseTerminators)) + { + break; + } + } + + tokens.Add(parser.Read()); + if (current.Text == "(") + { + depth++; + } + else if (current.Text == ")" && depth > 0) + { + depth--; + } + } + + if (tokens.Count == 0) + { + throw new SqlParseException("Expected field but none was found."); + } + + return parser.BuildSegment(segmentName, tokens); + } +} diff --git a/Utils.Data/Sql/InsertStatementParser.cs b/Utils.Data/Sql/InsertStatementParser.cs new file mode 100644 index 0000000..720b21b --- /dev/null +++ b/Utils.Data/Sql/InsertStatementParser.cs @@ -0,0 +1,87 @@ +using System; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Parses INSERT statements using a shared helper context. +/// +internal sealed class InsertStatementParser +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying parser providing token utilities. + public InsertStatementParser(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Parses an INSERT statement. + /// + /// The optional WITH clause attached to the statement. + /// The created . + public SqlInsertStatement Parse(WithClause? withClause) + { + parser.ExpectKeyword("INSERT"); + if (parser.TryConsumeKeyword("INTO") == false) + { + parser.ExpectKeyword("INTO"); + } + + var intoReader = new IntoPartReader(parser); + var valuesReader = new ValuesPartReader(parser); + var outputReader = new OutputPartReader(parser); + var returningReader = new ReturningPartReader(parser); + + var targetSegment = intoReader.ReadIntoTarget( + outputReader.ClauseKeyword, + valuesReader.ClauseKeyword, + ClauseStart.Select, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + SqlSegment? valuesSegment = null; + SqlStatement? sourceQuery = null; + SqlSegment? outputSegment = null; + SqlSegment? returningSegment = null; + + int returningIndex = parser.FindClauseIndex("RETURNING"); + + if (parser.TryConsumeKeyword("OUTPUT")) + { + outputSegment = outputReader.ReadOutputPart( + "Output", + valuesReader.ClauseKeyword, + ClauseStart.Select, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + } + + if (parser.CheckKeyword("VALUES")) + { + parser.ExpectKeyword("VALUES"); + valuesSegment = valuesReader.ReadValuesPart(returningReader.ClauseKeyword, ClauseStart.StatementEnd); + } + else if (parser.CheckKeyword("SELECT") || parser.CheckKeyword("WITH")) + { + int end = returningIndex >= 0 ? returningIndex : parser.Tokens.Count; + var sourceTokens = parser.Tokens.GetRange(parser.Position, end - parser.Position); + var subParser = new SqlParser(sourceTokens, parser.SyntaxOptions); + sourceQuery = subParser.ParseStatementWithOptionalCte(); + subParser.ConsumeOptionalTerminator(); + subParser.EnsureEndOfInput(); + parser.Position = end; + } + + if (parser.TryConsumeKeyword("RETURNING")) + { + returningSegment = returningReader.ReadReturningPart(ClauseStart.StatementEnd); + } + + return new SqlInsertStatement(targetSegment, valuesSegment, sourceQuery, outputSegment, returningSegment, withClause); + } +} diff --git a/Utils.Data/Sql/PredicateReader.cs b/Utils.Data/Sql/PredicateReader.cs new file mode 100644 index 0000000..c58004d --- /dev/null +++ b/Utils.Data/Sql/PredicateReader.cs @@ -0,0 +1,90 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// + /// Reads predicates appearing in clauses such as WHERE, HAVING, or JOIN conditions. +/// +internal sealed class PredicateReader +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public PredicateReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads a predicate that may include comparisons, logical operators, IN lists, and nested subqueries. + /// + /// The segment name assigned to the parsed predicate. + /// Clause boundaries that stop the predicate. + /// The parsed predicate segment. + /// Thrown when is null or whitespace. + /// Thrown when no predicate tokens are found. + public SqlSegment ReadPredicate(string segmentName, params ClauseStart[] clauseTerminators) + { + if (string.IsNullOrWhiteSpace(segmentName)) + { + throw new ArgumentException("Segment name cannot be null or whitespace.", nameof(segmentName)); + } + + var tokens = new List(); + int depth = 0; + while (!parser.IsAtEnd) + { + var current = parser.Peek(); + if (depth == 0) + { + if (current.Text == ")") + { + break; + } + + if (current.Text == ";") + { + break; + } + + if (clauseTerminators.Length > 0 && parser.IsClauseStart(clauseTerminators)) + { + break; + } + } + + tokens.Add(parser.Read()); + UpdateDepth(current, ref depth); + } + + if (tokens.Count == 0) + { + throw new SqlParseException("Expected predicate but none was found."); + } + + return parser.BuildSegment(segmentName, tokens); + } + + /// + /// Updates the tracked parenthesis depth for the provided token. + /// + /// The token to evaluate. + /// The tracked depth to update. + private static void UpdateDepth(SqlToken token, ref int depth) + { + if (token.Text == "(") + { + depth++; + } + else if (token.Text == ")" && depth > 0) + { + depth--; + } + } +} diff --git a/Utils.Data/Sql/SelectStatementParser.cs b/Utils.Data/Sql/SelectStatementParser.cs new file mode 100644 index 0000000..b15fddb --- /dev/null +++ b/Utils.Data/Sql/SelectStatementParser.cs @@ -0,0 +1,117 @@ +using System; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Parses SELECT statements by delegating token consumption to a shared context. +/// +internal sealed class SelectStatementParser +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying parser managing token access and helper utilities. + public SelectStatementParser(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Parses a SELECT statement. + /// + /// The WITH clause associated with the SELECT statement, if any. + /// The built . + public SqlSelectStatement Parse(WithClause? withClause) + { + parser.ExpectKeyword("SELECT"); + bool isDistinct = parser.TryConsumeKeyword("DISTINCT"); + var selectReader = new SelectPartReader(parser); + var fromReader = new FromPartReader(parser); + var whereReader = new WherePartReader(parser); + var groupByReader = new GroupByPartReader(parser); + var havingReader = new HavingPartReader(parser); + var orderByReader = new OrderByPartReader(parser); + var limitReader = new LimitPartReader(parser); + var offsetReader = new OffsetPartReader(parser); + var returningReader = new ReturningPartReader(parser); + var setOperatorReader = new SetOperatorPartReader(parser); + + var clauseTerminators = BuildClauseTerminators( + fromReader.ClauseKeyword, + whereReader.ClauseKeyword, + groupByReader.ClauseKeyword, + havingReader.ClauseKeyword, + orderByReader.ClauseKeyword, + limitReader.ClauseKeyword, + offsetReader.ClauseKeyword, + returningReader.ClauseKeyword, + setOperatorReader.ClauseKeyword); + + var selectPart = selectReader.ReadSelectPart(clauseTerminators); + + var fromPart = fromReader.TryReadFromPart(GetRemainingTerminators(clauseTerminators, 1)); + + var wherePart = whereReader.TryReadWherePart(GetRemainingTerminators(clauseTerminators, 2)); + + var groupByPart = groupByReader.TryReadGroupByPart(GetRemainingTerminators(clauseTerminators, 3)); + + var havingPart = havingReader.TryReadHavingPart(GetRemainingTerminators(clauseTerminators, 4)); + + var orderByPart = orderByReader.TryReadOrderByPart(GetRemainingTerminators(clauseTerminators, 5)); + + var limitPart = limitReader.TryReadLimitPart(GetRemainingTerminators(clauseTerminators, 6)); + + var offsetPart = offsetReader.TryReadOffsetPart(GetRemainingTerminators(clauseTerminators, 7)); + + var tailPart = setOperatorReader.TryReadTailPart(); + + return new SqlSelectStatement( + selectPart, + fromPart, + wherePart, + groupByPart, + havingPart, + orderByPart, + limitPart, + offsetPart, + tailPart, + withClause, + isDistinct); + } + + /// + /// Builds the ordered list of clause terminators using the provided clause keywords and the statement end marker. + /// + /// The clause keywords that may terminate the current part. + /// An ordered array of clause terminators. + private static ClauseStart[] BuildClauseTerminators(params ClauseStart[] clauseKeywords) + { + var terminators = new ClauseStart[clauseKeywords.Length + 1]; + Array.Copy(clauseKeywords, terminators, clauseKeywords.Length); + terminators[^1] = ClauseStart.StatementEnd; + return terminators; + } + + /// + /// Retrieves a subset of clause terminators beginning at the specified index. + /// + /// The ordered terminators. + /// The index to start from. + /// The remaining terminators starting at . + private static ClauseStart[] GetRemainingTerminators(ClauseStart[] terminators, int startIndex) + { + var remainingLength = terminators.Length - startIndex; + if (remainingLength <= 0) + { + return Array.Empty(); + } + + var remaining = new ClauseStart[remainingLength]; + Array.Copy(terminators, startIndex, remaining, 0, remainingLength); + return remaining; + } +} diff --git a/Utils.Data/Sql/SqlParser.cs b/Utils.Data/Sql/SqlParser.cs index 7268a02..29664b2 100644 --- a/Utils.Data/Sql/SqlParser.cs +++ b/Utils.Data/Sql/SqlParser.cs @@ -14,35 +14,47 @@ internal sealed class SqlParser private static readonly IReadOnlyDictionary> StatementParsers = new Dictionary>(StringComparer.OrdinalIgnoreCase) { - { "SELECT", (parser, withClause) => parser.ParseSelect(withClause) }, - { "INSERT", (parser, withClause) => parser.ParseInsert(withClause) }, - { "UPDATE", (parser, withClause) => parser.ParseUpdate(withClause) }, - { "DELETE", (parser, withClause) => parser.ParseDelete(withClause) }, + { "SELECT", (parser, withClause) => new SelectStatementParser(parser).Parse(withClause) }, + { "INSERT", (parser, withClause) => new InsertStatementParser(parser).Parse(withClause) }, + { "UPDATE", (parser, withClause) => new UpdateStatementParser(parser).Parse(withClause) }, + { "DELETE", (parser, withClause) => new DeleteStatementParser(parser).Parse(withClause) }, }.ToImmutableDictionary(); - private static readonly IReadOnlyList<(string Keyword, ClauseStart Clause, bool IncludeKeyword)> Segments = - new List<(string, ClauseStart, bool)> - { - ("FROM", ClauseStart.From, false), - ("WHERE", ClauseStart.Where, false), - ("GROUP BY", ClauseStart.GroupBy, false), - ("HAVING", ClauseStart.Having, false), - ("ORDER BY", ClauseStart.OrderBy, false), - ("LIMIT", ClauseStart.Limit, false), - ("OFFSET", ClauseStart.Offset, false), - ("RETURNING", ClauseStart.Returning, false), - ("USING", ClauseStart.Using, false), - ("UNION", ClauseStart.SetOperator, true), - ("EXCEPT", ClauseStart.SetOperator, true), - ("INTERSECT", ClauseStart.SetOperator, true), - }.ToImmutableList(); - private readonly List tokens; private int position; private readonly SqlSyntaxOptions syntaxOptions; - private SqlParser(IEnumerable tokens, SqlSyntaxOptions syntaxOptions) + /// + /// Gets the tokens being parsed. + /// + internal List Tokens => tokens; + + /// + /// Gets or sets the current parsing position within . + /// + internal int Position + { + get => position; + set => position = value; + } + + /// + /// Gets the syntax options currently applied to parsing. + /// + internal SqlSyntaxOptions SyntaxOptions => syntaxOptions; + + /// + /// Gets a value indicating whether the parser has consumed all tokens. + /// + internal bool IsAtEnd => position >= tokens.Count; + + /// + /// Initializes a new instance of the class with pre-tokenized input. + /// + /// The tokens representing the SQL statement. + /// The syntax options guiding token interpretation. + internal SqlParser(IEnumerable tokens, SqlSyntaxOptions syntaxOptions) { this.syntaxOptions = syntaxOptions ?? throw new ArgumentNullException(nameof(syntaxOptions)); this.tokens = tokens.ToList(); @@ -150,234 +162,13 @@ private IReadOnlyList ParseColumnList() /// /// The WITH clause associated with the SELECT, if any. /// The parsed instance. - private SqlSelectStatement ParseSelect(WithClause? withClause) - { - ExpectKeyword("SELECT"); - bool isDistinct = TryConsumeKeyword("DISTINCT"); - var selectTokens = ReadSectionTokens(ClauseStart.From, ClauseStart.Where, ClauseStart.GroupBy, ClauseStart.Having, ClauseStart.OrderBy, ClauseStart.Limit, ClauseStart.Offset, ClauseStart.Returning, ClauseStart.SetOperator, ClauseStart.StatementEnd); - var selectSegment = BuildSegment("Select", selectTokens); - - ClauseStart[] clauses = - [ - ClauseStart.From, - ClauseStart.Where, - ClauseStart.GroupBy, - ClauseStart.Having, - ClauseStart.OrderBy, - ClauseStart.Limit, - ClauseStart.Offset, - ClauseStart.Returning, - ClauseStart.Using, - ClauseStart.SetOperator, - ClauseStart.StatementEnd, - ]; - - Dictionary segments = new Dictionary(); - - foreach (var segment in Segments) - { - var tokensAfter = clauses.SkipWhile(c => c != segment.Clause).Skip(1).ToArray(); - if (TryConsumeSegmentKeyword(segment.Keyword, out var consumedTokens)) - { - var segmentTokens = ReadSectionTokens(tokensAfter); - if (segment.IncludeKeyword) - { - segmentTokens.InsertRange(0, consumedTokens); - } - - segments[segment.Clause] = BuildSegment(segment.Clause.ToString(), segmentTokens); - } - else if (!segments.ContainsKey(segment.Clause)) - { - segments[segment.Clause] = null; - } - } - - return new SqlSelectStatement( - selectSegment, - segments[ClauseStart.From], - segments[ClauseStart.Where], - segments[ClauseStart.GroupBy], - segments[ClauseStart.Having], - segments[ClauseStart.OrderBy], - segments[ClauseStart.Limit], - segments[ClauseStart.Offset], - segments[ClauseStart.SetOperator], - withClause, - isDistinct); - } - - private SqlInsertStatement ParseInsert(WithClause? withClause) - { - ExpectKeyword("INSERT"); - if (TryConsumeKeyword("INTO") == false) - { - ExpectKeyword("INTO"); - } - - var targetTokens = new List(); - int returningIndex = FindClauseIndex("RETURNING"); - while (!IsAtEnd) - { - if (CheckKeyword("VALUES") || CheckKeyword("SELECT") || CheckKeyword("WITH") || CheckKeyword("RETURNING") || CheckKeyword("OUTPUT")) - { - break; - } - - if (Peek().Text == ";") - { - break; - } - - targetTokens.Add(Read()); - } - - var targetSegment = BuildSegment("Target", targetTokens); - SqlSegment? valuesSegment = null; - SqlStatement? sourceQuery = null; - SqlSegment? outputSegment = null; - SqlSegment? returningSegment = null; - - if (TryConsumeKeyword("OUTPUT")) - { - var outputTokens = ReadSectionTokens(ClauseStart.Values, ClauseStart.Select, ClauseStart.Returning, ClauseStart.StatementEnd); - outputSegment = BuildSegment("Output", outputTokens); - } - - if (CheckKeyword("VALUES")) - { - ExpectKeyword("VALUES"); - var valuesTokens = new List(); - while (!IsAtEnd && (returningIndex < 0 || position < returningIndex)) - { - if (Peek().Text == ";") - { - break; - } - - valuesTokens.Add(Read()); - } - - valuesSegment = BuildSegment("Values", valuesTokens); - } - else if (CheckKeyword("SELECT") || CheckKeyword("WITH")) - { - int end = returningIndex >= 0 ? returningIndex : tokens.Count; - var sourceTokens = tokens.GetRange(position, end - position); - var subParser = new SqlParser(sourceTokens, syntaxOptions); - sourceQuery = subParser.ParseStatementWithOptionalCte(); - subParser.ConsumeOptionalTerminator(); - subParser.EnsureEndOfInput(); - position = end; - } - - if (TryConsumeKeyword("RETURNING")) - { - var returningTokens = ReadSectionTokens(ClauseStart.StatementEnd); - returningSegment = BuildSegment("Returning", returningTokens); - } - - return new SqlInsertStatement(targetSegment, valuesSegment, sourceQuery, outputSegment, returningSegment, withClause); - } - - private SqlUpdateStatement ParseUpdate(WithClause? withClause) - { - ExpectKeyword("UPDATE"); - var targetTokens = new List(); - while (!IsAtEnd && !CheckKeyword("SET") && Peek().Text != ";") - { - targetTokens.Add(Read()); - } - - var targetSegment = BuildSegment("Target", targetTokens); - ExpectKeyword("SET"); - var setTokens = ReadSectionTokens(ClauseStart.Output, ClauseStart.From, ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - var setSegment = BuildSegment("Set", setTokens); - - SqlSegment? fromSegment = null; - SqlSegment? whereSegment = null; - SqlSegment? outputSegment = null; - SqlSegment? returningSegment = null; - - if (TryConsumeKeyword("OUTPUT")) - { - var outputTokens = ReadSectionTokens(ClauseStart.From, ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - outputSegment = BuildSegment("Output", outputTokens); - } - - if (TryConsumeKeyword("FROM")) - { - var fromTokens = ReadSectionTokens(ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - fromSegment = BuildSegment("From", fromTokens); - } - - if (TryConsumeKeyword("WHERE")) - { - var whereTokens = ReadSectionTokens(ClauseStart.Returning, ClauseStart.StatementEnd); - whereSegment = BuildSegment("Where", whereTokens); - } - - if (TryConsumeKeyword("RETURNING")) - { - var returningTokens = ReadSectionTokens(ClauseStart.StatementEnd); - returningSegment = BuildSegment("Returning", returningTokens); - } - - return new SqlUpdateStatement(targetSegment, setSegment, fromSegment, whereSegment, outputSegment, returningSegment, withClause); - } - - private SqlDeleteStatement ParseDelete(WithClause? withClause) - { - ExpectKeyword("DELETE"); - SqlSegment? targetSegment = null; - if (!CheckKeyword("FROM")) - { - var targetTokens = new List(); - while (!IsAtEnd && !CheckKeyword("FROM") && Peek().Text != ";") - { - targetTokens.Add(Read()); - } - - targetSegment = BuildSegment("Target", targetTokens); - } - - ExpectKeyword("FROM"); - var fromTokens = ReadSectionTokens(ClauseStart.Output, ClauseStart.Using, ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - var fromSegment = BuildSegment("From", fromTokens); - - SqlSegment? usingSegment = null; - SqlSegment? whereSegment = null; - SqlSegment? outputSegment = null; - SqlSegment? returningSegment = null; - - if (TryConsumeKeyword("OUTPUT")) - { - var outputTokens = ReadSectionTokens(ClauseStart.Using, ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - outputSegment = BuildSegment("Output", outputTokens); - } - - if (TryConsumeKeyword("USING")) - { - var usingTokens = ReadSectionTokens(ClauseStart.Where, ClauseStart.Returning, ClauseStart.StatementEnd); - usingSegment = BuildSegment("Using", usingTokens); - } - - if (TryConsumeKeyword("WHERE")) - { - var whereTokens = ReadSectionTokens(ClauseStart.Returning, ClauseStart.StatementEnd); - whereSegment = BuildSegment("Where", whereTokens); - } - - if (TryConsumeKeyword("RETURNING")) - { - var returningTokens = ReadSectionTokens(ClauseStart.StatementEnd); - returningSegment = BuildSegment("Returning", returningTokens); - } - - return new SqlDeleteStatement(targetSegment, fromSegment, usingSegment, whereSegment, outputSegment, returningSegment, withClause); - } - - private SqlSegment BuildSegment(string name, List tokens) + /// + /// Builds a from the provided tokens. + /// + /// The logical name of the segment. + /// The tokens that compose the segment. + /// A representing the parsed section. + internal SqlSegment BuildSegment(string name, List tokens) { return new SqlSegment(name, BuildSegmentParts(tokens, syntaxOptions), syntaxOptions); } @@ -488,7 +279,12 @@ private List ReadTokensUntilMatchingParenthesis() throw new SqlParseException("Unterminated parenthesis in WITH clause definition."); } - private List ReadSectionTokens(params ClauseStart[] terminators) + /// + /// Reads tokens until one of the specified clause starts is encountered at depth zero. + /// + /// Clause boundaries that end the current section. + /// The tokens collected for the current section. + internal List ReadSectionTokens(params ClauseStart[] terminators) { var tokens = new List(); int depth = 0; @@ -524,6 +320,16 @@ private List ReadSectionTokens(params ClauseStart[] terminators) return tokens; } + /// + /// Determines whether the current token marks the start of one of the specified clauses. + /// + /// Clause starts that should stop parsing. + /// true when the current position matches a clause start. + internal bool IsClauseStart(params ClauseStart[] terminators) + { + return CheckClauseStart(terminators); + } + private bool CheckClauseStart(params ClauseStart[] terminators) { foreach (var terminator in terminators) @@ -534,65 +340,33 @@ private bool CheckClauseStart(params ClauseStart[] terminators) { return true; } + + continue; } - else if (terminator == ClauseStart.Where && CheckKeyword("WHERE")) - { - return true; - } - else if (terminator == ClauseStart.From && CheckKeyword("FROM")) - { - return true; - } - else if (terminator == ClauseStart.GroupBy && CheckKeywordSequence("GROUP", "BY")) - { - return true; - } - else if (terminator == ClauseStart.Having && CheckKeyword("HAVING")) - { - return true; - } - else if (terminator == ClauseStart.OrderBy && CheckKeywordSequence("ORDER", "BY")) - { - return true; - } - else if (terminator == ClauseStart.Limit && CheckKeyword("LIMIT")) - { - return true; - } - else if (terminator == ClauseStart.Offset && CheckKeyword("OFFSET")) - { - return true; - } - else if (terminator == ClauseStart.Output && CheckKeyword("OUTPUT")) - { - return true; - } - else if (terminator == ClauseStart.Returning && CheckKeyword("RETURNING")) - { - return true; - } - else if (terminator == ClauseStart.Values && CheckKeyword("VALUES")) - { - return true; - } - else if (terminator == ClauseStart.Select && (CheckKeyword("SELECT") || CheckKeyword("WITH"))) - { - return true; - } - else if (terminator == ClauseStart.Using && CheckKeyword("USING")) + + if (!ClauseStartKeywordRegistry.TryGetClauseKeywords(terminator, out var keywordSequences)) { - return true; + continue; } - else if (terminator == ClauseStart.SetOperator && (CheckKeyword("UNION") || CheckKeyword("EXCEPT") || CheckKeyword("INTERSECT"))) + + foreach (var keywordSequence in keywordSequences) { - return true; + if (CheckKeywordSequence(keywordSequence)) + { + return true; + } } } return false; } - private int FindClauseIndex(string keyword) + /// + /// Finds the index of the next keyword, ignoring nested parentheses. + /// + /// The keyword to look for. + /// The token index if found; otherwise, -1. + internal int FindClauseIndex(string keyword) { int depth = 0; for (int i = position; i < tokens.Count; i++) @@ -619,7 +393,13 @@ private int FindClauseIndex(string keyword) return -1; } - private SqlToken ExpectKeyword(string keyword) + /// + /// Consumes the specified keyword or throws an exception when the stream does not match. + /// + /// The keyword to consume. + /// The consumed . + /// Thrown when the expected keyword is missing. + internal SqlToken ExpectKeyword(string keyword) { if (IsAtEnd || !CheckKeyword(keyword)) { @@ -629,6 +409,11 @@ private SqlToken ExpectKeyword(string keyword) return Read(); } + /// + /// Consumes an identifier at the current position or throws when the token is not an identifier. + /// + /// The identifier text. + /// Thrown when the token is not an identifier. private string ExpectIdentifier() { if (IsAtEnd) @@ -646,6 +431,12 @@ private string ExpectIdentifier() return token.Text; } + /// + /// Consumes the expected symbol or throws when a different token is found. + /// + /// The symbol to consume. + /// The consumed token. + /// Thrown when the symbol does not match. private SqlToken ExpectSymbol(string symbol) { if (IsAtEnd || Peek().Text != symbol) @@ -662,7 +453,7 @@ private SqlToken ExpectSymbol(string symbol) /// The keyword text to match. /// The tokens consumed when the keyword is matched. /// true when the keyword is consumed; otherwise, false. - private bool TryConsumeSegmentKeyword(string keyword, out List consumedTokens) + internal bool TryConsumeSegmentKeyword(string keyword, out List consumedTokens) { consumedTokens = new List(); if (keyword == ";") @@ -698,7 +489,7 @@ private bool TryConsumeSegmentKeyword(string keyword, out List consume return false; } - private bool TryConsumeKeyword(string keyword) + internal bool TryConsumeKeyword(string keyword) { if (CheckKeyword(keyword)) { @@ -709,6 +500,11 @@ private bool TryConsumeKeyword(string keyword) return false; } + /// + /// Attempts to consume the specified symbol at the current position. + /// + /// The symbol expected. + /// true when the symbol was consumed; otherwise, false. private bool TryConsumeSymbol(string symbol) { if (!IsAtEnd && Peek().Text == symbol) @@ -720,17 +516,41 @@ private bool TryConsumeSymbol(string symbol) return false; } - private bool CheckKeyword(string keyword) + internal bool CheckKeyword(string keyword) { return !IsAtEnd && Peek().Normalized == keyword; } - private bool CheckKeywordSequence(string first, string second) + internal bool CheckKeywordSequence(string first, string second) { - return !IsAtEnd && Peek().Normalized == first && PeekOptional(1)?.Normalized == second; + return CheckKeywordSequence(new[] { first, second }); } - private SqlToken Peek(int offset = 0) + /// + /// Checks whether the upcoming tokens match the provided keyword sequence. + /// + /// The ordered keywords to verify. + /// true when the upcoming tokens match the full sequence; otherwise, false. + internal bool CheckKeywordSequence(IReadOnlyList keywordSequence) + { + if (keywordSequence.Count == 0) + { + return false; + } + + for (int i = 0; i < keywordSequence.Count; i++) + { + var token = PeekOptional(i); + if (token is null || token.Normalized != keywordSequence[i]) + { + return false; + } + } + + return true; + } + + internal SqlToken Peek(int offset = 0) { return tokens[position + offset]; } @@ -746,16 +566,15 @@ private SqlToken Peek(int offset = 0) return tokens[index]; } - private SqlToken Read() + internal SqlToken Read() { return tokens[position++]; } - - private bool IsAtEnd => position >= tokens.Count; } internal enum ClauseStart { + Into, From, Where, GroupBy, @@ -807,10 +626,15 @@ internal sealed class SqlTokenizer "UNION", "ALL", "DISTINCT", + "CASE", "INSERT", "INTO", "VALUES", "RETURNING", + "WHEN", + "THEN", + "ELSE", + "END", "OUTPUT", "UPDATE", "SET", @@ -818,6 +642,7 @@ internal sealed class SqlTokenizer "WITH", "RECURSIVE", "AS", + "IF", "ON", "JOIN", "INNER", diff --git a/Utils.Data/Sql/SqlQueryAnalyzer.cs b/Utils.Data/Sql/SqlQueryAnalyzer.cs index 919a095..51d60ac 100644 --- a/Utils.Data/Sql/SqlQueryAnalyzer.cs +++ b/Utils.Data/Sql/SqlQueryAnalyzer.cs @@ -238,6 +238,57 @@ protected void ReplaceSegment(SqlSegment? previous, SqlSegment? replacement) } } +/// +/// Provides a contract for binding a parsed segment name to a typed SQL statement part. +/// +internal interface IPartReferenceBinding +{ + /// + /// Attempts to bind the provided segment to a typed part when the binding name matches. + /// + /// The name of the segment that was created. + /// The segment instance associated with the name. + /// true when the binding matched the name; otherwise, false. + bool TryBind(string name, SqlSegment segment); +} + +/// +/// Associates a part reader-provided name and factory with a callback that stores the typed part. +/// +/// The part type produced by the binding. +internal sealed class PartReferenceBinding : IPartReferenceBinding + where TPart : SqlStatementPart +{ + private readonly string partName; + private readonly Func partFactory; + private readonly Action onBind; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the segment produced by the associated part reader. + /// Factory used to create the typed part. + /// Callback invoked when the binding is matched. + public PartReferenceBinding(string partName, Func partFactory, Action onBind) + { + this.partName = partName ?? throw new ArgumentNullException(nameof(partName)); + this.partFactory = partFactory ?? throw new ArgumentNullException(nameof(partFactory)); + this.onBind = onBind ?? throw new ArgumentNullException(nameof(onBind)); + } + + /// + public bool TryBind(string name, SqlSegment segment) + { + if (!string.Equals(name, partName, StringComparison.Ordinal)) + { + return false; + } + + onBind(partFactory(segment ?? throw new ArgumentNullException(nameof(segment)))); + return true; + } +} + /// /// Represents a parsed SELECT statement. /// @@ -251,6 +302,16 @@ public sealed class SqlSelectStatement : SqlStatement private SqlSegment? limit; private SqlSegment? offset; private SqlSegment? tail; + private readonly SelectPart selectPart; + private readonly IReadOnlyList partReferenceBindings; + private FromPart? fromPart; + private WherePart? wherePart; + private GroupByPart? groupByPart; + private HavingPart? havingPart; + private OrderByPart? orderByPart; + private LimitPart? limitPart; + private OffsetPart? offsetPart; + private TailPart? tailPart; /// /// Initializes a new instance of the class. @@ -301,6 +362,26 @@ public SqlSelectStatement( this.offset = offset; this.tail = tail; IsDistinct = isDistinct; + selectPart = new SelectPart(Select); + fromPart = from == null ? null : new FromPart(from); + wherePart = where == null ? null : new WherePart(where); + groupByPart = groupBy == null ? null : new GroupByPart(groupBy); + havingPart = having == null ? null : new HavingPart(having); + orderByPart = orderBy == null ? null : new OrderByPart(orderBy); + limitPart = limit == null ? null : new LimitPart(limit); + offsetPart = offset == null ? null : new OffsetPart(offset); + tailPart = tail == null ? null : new TailPart(tail); + partReferenceBindings = new IPartReferenceBinding[] + { + new PartReferenceBinding(FromPartReader.PartName, FromPartReader.PartFactory, part => fromPart ??= part), + new PartReferenceBinding(WherePartReader.PartName, WherePartReader.PartFactory, part => wherePart ??= part), + new PartReferenceBinding(GroupByPartReader.PartName, GroupByPartReader.PartFactory, part => groupByPart ??= part), + new PartReferenceBinding(HavingPartReader.PartName, HavingPartReader.PartFactory, part => havingPart ??= part), + new PartReferenceBinding(OrderByPartReader.PartName, OrderByPartReader.PartFactory, part => orderByPart ??= part), + new PartReferenceBinding(LimitPartReader.PartName, LimitPartReader.PartFactory, part => limitPart ??= part), + new PartReferenceBinding(OffsetPartReader.PartName, OffsetPartReader.PartFactory, part => offsetPart ??= part), + new PartReferenceBinding(SetOperatorPartReader.PartName, SetOperatorPartReader.PartFactory, part => tailPart ??= part), + }; } /// @@ -308,46 +389,91 @@ public SqlSelectStatement( /// public SqlSegment Select { get; } + /// + /// Gets the SELECT clause represented as a typed part. + /// + public SelectPart SelectPart => selectPart; + /// /// Gets the FROM segment describing the data sources. /// public SqlSegment? From => from; + /// + /// Gets the FROM clause represented as a typed part. + /// + public FromPart? FromPart => fromPart; + /// /// Gets the WHERE segment containing the filtering conditions. /// public SqlSegment? Where => where; + /// + /// Gets the WHERE clause represented as a typed part. + /// + public WherePart? WherePart => wherePart; + /// /// Gets the GROUP BY segment. /// public SqlSegment? GroupBy => groupBy; + /// + /// Gets the GROUP BY clause represented as a typed part. + /// + public GroupByPart? GroupByPart => groupByPart; + /// /// Gets the HAVING segment. /// public SqlSegment? Having => having; + /// + /// Gets the HAVING clause represented as a typed part. + /// + public HavingPart? HavingPart => havingPart; + /// /// Gets the ORDER BY segment. /// public SqlSegment? OrderBy => orderBy; + /// + /// Gets the ORDER BY clause represented as a typed part. + /// + public OrderByPart? OrderByPart => orderByPart; + /// /// Gets the LIMIT segment. /// public SqlSegment? Limit => limit; + /// + /// Gets the LIMIT clause represented as a typed part. + /// + public LimitPart? LimitPart => limitPart; + /// /// Gets the OFFSET segment. /// public SqlSegment? Offset => offset; + /// + /// Gets the OFFSET clause represented as a typed part. + /// + public OffsetPart? OffsetPart => offsetPart; + /// /// Gets any trailing segments such as UNION clauses. /// public SqlSegment? Tail => tail; + /// + /// Gets the trailing set operator clause represented as a typed part. + /// + public TailPart? TailPart => tailPart; + /// /// Gets a value indicating whether the SELECT statement includes DISTINCT. /// @@ -507,10 +633,27 @@ private SqlSegment EnsureSegment(ref SqlSegment? segment, string name) { segment = SqlSegment.CreateEmpty(name, SyntaxOptions); AttachSegment(segment); + UpdatePartReference(name, segment); } return segment; } + + /// + /// Ensures that part references remain aligned with their segments. + /// + /// Name of the segment being created. + /// The segment instance that was created. + private void UpdatePartReference(string name, SqlSegment segment) + { + foreach (var binding in partReferenceBindings) + { + if (binding.TryBind(name, segment)) + { + break; + } + } + } } /// @@ -521,6 +664,9 @@ public sealed class SqlInsertStatement : SqlStatement private SqlSegment? values; private SqlSegment? output; private SqlSegment? returning; + private readonly InsertPart insertPart; + private readonly IntoPart intoPart; + private ValuesPart? valuesPart; /// /// Initializes a new instance of the class. @@ -539,6 +685,9 @@ public SqlInsertStatement(SqlSegment target, SqlSegment? values, SqlStatement? s SourceQuery = sourceQuery; this.output = output; this.returning = returning; + insertPart = new InsertPart(Target); + intoPart = new IntoPart(Target); + valuesPart = values == null ? null : new ValuesPart(values); } /// @@ -546,11 +695,26 @@ public SqlInsertStatement(SqlSegment target, SqlSegment? values, SqlStatement? s /// public SqlSegment Target { get; } + /// + /// Gets the INSERT clause represented as a typed part. + /// + public InsertPart InsertPart => insertPart; + + /// + /// Gets the INTO clause represented as a typed part. + /// + public IntoPart IntoPart => intoPart; + /// /// Gets the VALUES segment when the statement inserts literal values. /// public SqlSegment? Values => values; + /// + /// Gets the VALUES clause represented as a typed part when present. + /// + public ValuesPart? ValuesPart => valuesPart; + /// /// Gets the source statement when the insert pulls data from a query. /// @@ -582,6 +746,7 @@ public SqlSegment EnsureValuesSegment() { values = SqlSegment.CreateEmpty("Values", SyntaxOptions); AttachSegment(values); + valuesPart = new ValuesPart(values); } return values; @@ -711,6 +876,10 @@ public sealed class SqlUpdateStatement : SqlStatement private SqlSegment? where; private SqlSegment? output; private SqlSegment? returning; + private readonly UpdatePart updatePart; + private readonly IReadOnlyList partReferenceBindings; + private FromPart? fromPart; + private WherePart? wherePart; /// /// Initializes a new instance of the class. @@ -731,6 +900,14 @@ public SqlUpdateStatement(SqlSegment target, SqlSegment set, SqlSegment? from, S this.where = where; this.output = output; this.returning = returning; + updatePart = new UpdatePart(Target); + fromPart = from == null ? null : new FromPart(from); + wherePart = where == null ? null : new WherePart(where); + partReferenceBindings = new IPartReferenceBinding[] + { + new PartReferenceBinding(FromPartReader.PartName, FromPartReader.PartFactory, part => fromPart ??= part), + new PartReferenceBinding(WherePartReader.PartName, WherePartReader.PartFactory, part => wherePart ??= part), + }; } /// @@ -738,6 +915,11 @@ public SqlUpdateStatement(SqlSegment target, SqlSegment set, SqlSegment? from, S /// public SqlSegment Target { get; } + /// + /// Gets the UPDATE clause represented as a typed part. + /// + public UpdatePart UpdatePart => updatePart; + /// /// Gets the SET segment describing the assignments. /// @@ -748,11 +930,21 @@ public SqlUpdateStatement(SqlSegment target, SqlSegment set, SqlSegment? from, S /// public SqlSegment? From => from; + /// + /// Gets the FROM clause represented as a typed part when present. + /// + public FromPart? FromPart => fromPart; + /// /// Gets the WHERE segment when present. /// public SqlSegment? Where => where; + /// + /// Gets the WHERE clause represented as a typed part when present. + /// + public WherePart? WherePart => wherePart; + /// /// Gets the OUTPUT segment when present. /// @@ -853,10 +1045,27 @@ private SqlSegment EnsureOptionalSegment(ref SqlSegment? segment, string name) { segment = SqlSegment.CreateEmpty(name, SyntaxOptions); AttachSegment(segment); + UpdatePartReference(name, segment); } return segment; } + + /// + /// Updates part references to align with newly created segments. + /// + /// The name of the created segment. + /// The created segment. + private void UpdatePartReference(string name, SqlSegment segment) + { + foreach (var binding in partReferenceBindings) + { + if (binding.TryBind(name, segment)) + { + break; + } + } + } } /// @@ -869,6 +1078,10 @@ public sealed class SqlDeleteStatement : SqlStatement private SqlSegment? where; private SqlSegment? output; private SqlSegment? returning; + private DeletePart? deletePart; + private readonly FromPart fromPart; + private readonly IReadOnlyList partReferenceBindings; + private WherePart? wherePart; /// /// Initializes a new instance of the class. @@ -889,6 +1102,14 @@ public SqlDeleteStatement(SqlSegment? target, SqlSegment from, SqlSegment? @usin this.where = where; this.output = output; this.returning = returning; + deletePart = target == null ? null : new DeletePart(target); + fromPart = new FromPart(From); + wherePart = where == null ? null : new WherePart(where); + partReferenceBindings = new IPartReferenceBinding[] + { + new PartReferenceBinding(DeletePartReader.PartName, DeletePartReader.PartFactory, part => deletePart ??= part), + new PartReferenceBinding(WherePartReader.PartName, WherePartReader.PartFactory, part => wherePart ??= part), + }; } /// @@ -896,11 +1117,21 @@ public SqlDeleteStatement(SqlSegment? target, SqlSegment from, SqlSegment? @usin /// public SqlSegment? Target => target; + /// + /// Gets the DELETE clause represented as a typed part when present. + /// + public DeletePart? DeletePart => deletePart; + /// /// Gets the FROM segment. /// public SqlSegment From { get; } + /// + /// Gets the FROM clause represented as a typed part. + /// + public FromPart FromPart => fromPart; + /// /// Gets the USING segment when present. /// @@ -911,6 +1142,11 @@ public SqlDeleteStatement(SqlSegment? target, SqlSegment from, SqlSegment? @usin /// public SqlSegment? Where => where; + /// + /// Gets the WHERE clause represented as a typed part when present. + /// + public WherePart? WherePart => wherePart; + /// /// Gets the OUTPUT segment when present. /// @@ -1025,10 +1261,27 @@ private SqlSegment EnsureOptionalSegment(ref SqlSegment? segment, string name) { segment = SqlSegment.CreateEmpty(name, SyntaxOptions); AttachSegment(segment); + UpdatePartReference(name, segment); } return segment; } + + /// + /// Updates typed part references when new optional segments are created. + /// + /// The name of the created segment. + /// The created segment. + private void UpdatePartReference(string name, SqlSegment segment) + { + foreach (var binding in partReferenceBindings) + { + if (binding.TryBind(name, segment)) + { + break; + } + } + } } /// diff --git a/Utils.Data/Sql/SqlStatementPartReaders.cs b/Utils.Data/Sql/SqlStatementPartReaders.cs new file mode 100644 index 0000000..6acc80c --- /dev/null +++ b/Utils.Data/Sql/SqlStatementPartReaders.cs @@ -0,0 +1,1000 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Reads the SELECT clause expressions for a statement. +/// +internal sealed class SelectPartReader +{ + private readonly SqlParser parser; + private readonly ExpressionListReader expressionListReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the SELECT clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = ClauseKeywordDefinition.FromKeywords( + ClauseStart.Select, + new[] { "SELECT" }, + new[] { "WITH" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public SelectPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + expressionListReader = new ExpressionListReader(this.parser); + } + + /// + /// Reads the SELECT clause up to the next statement section. + /// + /// The parsed SELECT segment. + public SqlSegment ReadSelectPart(params ClauseStart[] clauseTerminators) + { + var expressions = expressionListReader.ReadExpressions( + "SelectExpr", + true, + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.From, + ClauseStart.Where, + ClauseStart.GroupBy, + ClauseStart.Having, + ClauseStart.OrderBy, + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + + return BuildExpressionListSegment("Select", expressions); + } + + private SqlSegment BuildExpressionListSegment(string name, IReadOnlyList expressions) + { + var tokens = new List(); + for (int i = 0; i < expressions.Count; i++) + { + if (i > 0) + { + tokens.Add(new SqlToken(",", ",", false, false)); + } + + tokens.AddRange(expressions[i].Tokens); + } + + return parser.BuildSegment(name, tokens); + } +} + +/// +/// Reads table sources for a FROM clause. +/// +internal sealed class FromPartReader +{ + private readonly SqlParser parser; + private readonly TableListReader tableListReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the FROM clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.From, new[] { "FROM" }); + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "From"; + + /// + /// Gets the factory that creates a typed part from the parsed FROM segment. + /// + public static Func PartFactory { get; } = segment => new FromPart(segment); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public FromPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + tableListReader = new TableListReader(this.parser); + } + + /// + /// Attempts to read a FROM clause when present. + /// + /// The parsed FROM segment when found; otherwise, null. + public SqlSegment? TryReadFromPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeKeyword("FROM")) + { + return null; + } + + var tables = tableListReader.ReadTables( + "FromTable", + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Where, + ClauseStart.GroupBy, + ClauseStart.Having, + ClauseStart.OrderBy, + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Output, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + + return BuildDelimitedSegment("From", tables); + } + + private SqlSegment BuildDelimitedSegment(string name, IReadOnlyList segments) + { + var parts = new List(); + for (int i = 0; i < segments.Count; i++) + { + if (i > 0) + { + parts.Add(new SqlTokenPart(",")); + } + + foreach (var part in segments[i].Parts) + { + parts.Add(part); + } + } + + return new SqlSegment(name, parts, parser.SyntaxOptions); + } +} + +/// +/// Reads an INTO clause target for SELECT or INSERT statements. +/// +internal sealed class IntoPartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the keyword metadata describing how to detect the start of the INTO clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Into, new[] { "INTO" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public IntoPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads an INTO target when the current token matches the INTO keyword. + /// + /// The parsed INTO/target segment. + public SqlSegment ReadIntoTarget(params ClauseStart[] clauseTerminators) + { + var tokens = parser.ReadSectionTokens( + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Output, + ClauseStart.Values, + ClauseStart.Select, + ClauseStart.Returning, + ClauseStart.StatementEnd, + }); + return parser.BuildSegment("Target", tokens); + } +} + +/// +/// Reads a WHERE predicate. +/// +internal sealed class WherePartReader +{ + private readonly SqlParser parser; + private readonly PredicateReader predicateReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the WHERE clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Where, new[] { "WHERE" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Where"; + + /// + /// Gets the factory that creates a typed part from the parsed WHERE segment. + /// + public static Func PartFactory { get; } = segment => new WherePart(segment); + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public WherePartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + predicateReader = new PredicateReader(this.parser); + } + + /// + /// Attempts to read a WHERE clause when present. + /// + /// The parsed WHERE predicate when found; otherwise, null. + public SqlSegment? TryReadWherePart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeKeyword("WHERE")) + { + return null; + } + + return predicateReader.ReadPredicate( + "Where", + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.GroupBy, + ClauseStart.Having, + ClauseStart.OrderBy, + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + } +} + +/// +/// Reads GROUP BY expressions. +/// +internal sealed class GroupByPartReader +{ + private readonly SqlParser parser; + private readonly ExpressionListReader expressionListReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the GROUP BY clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.GroupBy, new[] { "GROUP", "BY" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "GroupBy"; + + /// + /// Gets the factory that creates a typed part from the parsed GROUP BY segment. + /// + public static Func PartFactory { get; } = segment => new GroupByPart(segment); + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public GroupByPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + expressionListReader = new ExpressionListReader(this.parser); + } + + /// + /// Attempts to read a GROUP BY clause when present. + /// + /// The parsed GROUP BY segment when found; otherwise, null. + public SqlSegment? TryReadGroupByPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeSegmentKeyword("GROUP BY", out _)) + { + return null; + } + + var expressions = expressionListReader.ReadExpressions( + "GroupByExpr", + false, + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Having, + ClauseStart.OrderBy, + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + + return BuildExpressionListSegment("GroupBy", expressions); + } + + private SqlSegment BuildExpressionListSegment(string name, IReadOnlyList expressions) + { + var tokens = new List(); + for (int i = 0; i < expressions.Count; i++) + { + if (i > 0) + { + tokens.Add(new SqlToken(",", ",", false, false)); + } + + tokens.AddRange(expressions[i].Tokens); + } + + return parser.BuildSegment(name, tokens); + } +} + +/// +/// Reads a HAVING predicate. +/// +internal sealed class HavingPartReader +{ + private readonly SqlParser parser; + private readonly PredicateReader predicateReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the HAVING clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Having, new[] { "HAVING" }); + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Having"; + + /// + /// Gets the factory that creates a typed part from the parsed HAVING segment. + /// + public static Func PartFactory { get; } = segment => new HavingPart(segment); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public HavingPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + predicateReader = new PredicateReader(this.parser); + } + + /// + /// Attempts to read a HAVING clause when present. + /// + /// The parsed HAVING segment when found; otherwise, null. + public SqlSegment? TryReadHavingPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeKeyword("HAVING")) + { + return null; + } + + return predicateReader.ReadPredicate( + "Having", + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.OrderBy, + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + } +} + +/// +/// Reads ORDER BY expressions. +/// +internal sealed class OrderByPartReader +{ + private readonly SqlParser parser; + private readonly ExpressionListReader expressionListReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the ORDER BY clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.OrderBy, new[] { "ORDER", "BY" }); + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "OrderBy"; + + /// + /// Gets the factory that creates a typed part from the parsed ORDER BY segment. + /// + public static Func PartFactory { get; } = segment => new OrderByPart(segment); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public OrderByPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + expressionListReader = new ExpressionListReader(this.parser); + } + + /// + /// Attempts to read an ORDER BY clause when present. + /// + /// The parsed ORDER BY segment when found; otherwise, null. + public SqlSegment? TryReadOrderByPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeSegmentKeyword("ORDER BY", out _)) + { + return null; + } + + var expressions = expressionListReader.ReadExpressions( + "OrderByExpr", + false, + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Limit, + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + + return BuildExpressionListSegment("OrderBy", expressions); + } + + private SqlSegment BuildExpressionListSegment(string name, IReadOnlyList expressions) + { + var tokens = new List(); + for (int i = 0; i < expressions.Count; i++) + { + if (i > 0) + { + tokens.Add(new SqlToken(",", ",", false, false)); + } + + tokens.AddRange(expressions[i].Tokens); + } + + return parser.BuildSegment(name, tokens); + } +} + +/// +/// Reads a LIMIT clause when present. +/// +internal sealed class LimitPartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the keyword metadata describing how to detect the start of the LIMIT clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Limit, new[] { "LIMIT" }); + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Limit"; + + /// + /// Gets the factory that creates a typed part from the parsed LIMIT segment. + /// + public static Func PartFactory { get; } = segment => new LimitPart(segment); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public LimitPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Attempts to read a LIMIT clause when present. + /// + /// The parsed LIMIT segment when found; otherwise, null. + public SqlSegment? TryReadLimitPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeKeyword("LIMIT")) + { + return null; + } + + var tokens = parser.ReadSectionTokens( + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Offset, + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + return parser.BuildSegment("Limit", tokens); + } +} + +/// +/// Reads an OFFSET clause when present. +/// +internal sealed class OffsetPartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the keyword metadata describing how to detect the start of the OFFSET clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Offset, new[] { "OFFSET" }); + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Offset"; + + /// + /// Gets the factory that creates a typed part from the parsed OFFSET segment. + /// + public static Func PartFactory { get; } = segment => new OffsetPart(segment); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public OffsetPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Attempts to read an OFFSET clause when present. + /// + /// The parsed OFFSET segment when found; otherwise, null. + public SqlSegment? TryReadOffsetPart(params ClauseStart[] clauseTerminators) + { + if (!parser.TryConsumeKeyword("OFFSET")) + { + return null; + } + + var tokens = parser.ReadSectionTokens( + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Returning, + ClauseStart.Using, + ClauseStart.SetOperator, + ClauseStart.StatementEnd, + }); + return parser.BuildSegment("Offset", tokens); + } +} + +/// +/// Reads VALUES content for INSERT statements. +/// +internal sealed class ValuesPartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the keyword metadata describing how to detect the start of the VALUES clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Values, new[] { "VALUES" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public ValuesPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads the VALUES clause when the VALUES keyword has already been consumed. + /// + /// The parsed VALUES segment. + public SqlSegment ReadValuesPart(params ClauseStart[] clauseTerminators) + { + var tokens = parser.ReadSectionTokens( + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Returning, + ClauseStart.StatementEnd, + }); + return parser.BuildSegment("Values", tokens); + } +} + +/// +/// Reads OUTPUT clauses for DML statements. +/// +internal sealed class OutputPartReader +{ + private readonly SqlParser parser; + private readonly ExpressionListReader expressionListReader; + + /// + /// Gets the keyword metadata describing how to detect the start of the OUTPUT clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Output, new[] { "OUTPUT" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public OutputPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + expressionListReader = new ExpressionListReader(this.parser); + } + + /// + /// Reads an OUTPUT clause after the OUTPUT keyword has been consumed. + /// + /// The logical name to give the resulting segment. + /// Clause boundaries that end the OUTPUT clause. + /// The parsed OUTPUT segment. + public SqlSegment ReadOutputPart(string segmentName, params ClauseStart[] clauseTerminators) + { + var expressions = expressionListReader.ReadExpressions( + segmentName + "Expr", + true, + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.Returning, + ClauseStart.StatementEnd, + }); + return BuildExpressionListSegment(segmentName, expressions); + } + + private SqlSegment BuildExpressionListSegment(string name, IReadOnlyList expressions) + { + var tokens = new List(); + for (int i = 0; i < expressions.Count; i++) + { + if (i > 0) + { + tokens.Add(new SqlToken(",", ",", false, false)); + } + + tokens.AddRange(expressions[i].Tokens); + } + + return parser.BuildSegment(name, tokens); + } +} + +/// +/// Reads RETURNING clauses. +/// +internal sealed class ReturningPartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the keyword metadata describing how to detect the start of the RETURNING clause. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = + ClauseKeywordDefinition.FromKeywords(ClauseStart.Returning, new[] { "RETURNING" }); + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public ReturningPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads a RETURNING clause after the keyword has been consumed. + /// + /// The parsed RETURNING segment. + public SqlSegment ReadReturningPart(params ClauseStart[] clauseTerminators) + { + var tokens = parser.ReadSectionTokens( + clauseTerminators.Length > 0 + ? clauseTerminators + : new[] + { + ClauseStart.StatementEnd, + }); + return parser.BuildSegment("Returning", tokens); + } +} + +/// +/// Reads trailing set operator clauses such as UNION. +/// +internal sealed class SetOperatorPartReader +{ + private static readonly string[] SetOperators = + { + "UNION", + "EXCEPT", + "INTERSECT", + }; + + /// + /// Gets the keyword metadata describing how to detect the start of set operator clauses. + /// + public static ClauseKeywordDefinition KeywordDefinition { get; } = ClauseKeywordDefinition.FromKeywords( + ClauseStart.SetOperator, + new[] { "UNION" }, + new[] { "EXCEPT" }, + new[] { "INTERSECT" }); + + private readonly SqlParser parser; + + /// + /// Gets the clause-start keyword that activates the reader. + /// + public ClauseStart ClauseKeyword => KeywordDefinition.ClauseKeyword; + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Tail"; + + /// + /// Gets the factory that creates a typed part from the parsed tail segment. + /// + public static Func PartFactory { get; } = segment => new TailPart(segment); + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public SetOperatorPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Attempts to read a trailing set operator clause when present. + /// + /// The parsed set operator segment when found; otherwise, null. + public SqlSegment? TryReadTailPart() + { + foreach (string keyword in SetOperators) + { + if (parser.TryConsumeSegmentKeyword(keyword, out var consumedTokens)) + { + var tokens = new List(consumedTokens); + tokens.AddRange(parser.ReadSectionTokens(ClauseStart.StatementEnd)); + return parser.BuildSegment("Tail", tokens); + } + } + + return null; + } +} + +/// +/// Reads the optional DELETE target preceding FROM. +/// +internal sealed class DeletePartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Target"; + + /// + /// Gets the factory that creates a typed part from the parsed DELETE target segment. + /// + public static Func PartFactory { get; } = segment => new DeletePart(segment); + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public DeletePartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Attempts to read a DELETE target when the FROM keyword has not yet been encountered. + /// + /// The parsed DELETE target when present; otherwise, null. + public SqlSegment? TryReadDeleteTarget() + { + if (parser.CheckKeyword("FROM")) + { + return null; + } + + var tokens = new List(); + while (!parser.IsAtEnd && !parser.CheckKeyword("FROM") && parser.Peek().Text != ";") + { + tokens.Add(parser.Read()); + } + + return tokens.Count == 0 ? null : parser.BuildSegment("Target", tokens); + } +} + +/// +/// Reads the UPDATE target prior to the SET keyword. +/// +internal sealed class UpdatePartReader +{ + private readonly SqlParser parser; + + /// + /// Gets the name of the part produced by the reader. + /// + public static string PartName => "Target"; + + /// + /// Gets the factory that creates a typed part from the parsed UPDATE target segment. + /// + public static Func PartFactory { get; } = segment => new UpdatePart(segment); + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public UpdatePartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads the UPDATE target until the SET keyword is encountered. + /// + /// The parsed UPDATE target segment. + public SqlSegment ReadUpdateTarget() + { + var tokens = new List(); + while (!parser.IsAtEnd && !parser.CheckKeyword("SET") && parser.Peek().Text != ";") + { + tokens.Add(parser.Read()); + } + + return parser.BuildSegment("Target", tokens); + } +} + +/// +/// Reads the SET clause content for UPDATE statements. +/// +internal sealed class SetPartReader +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public SetPartReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads the SET clause after the SET keyword has been consumed. + /// + /// The parsed SET segment. + public SqlSegment ReadSetPart() + { + var tokens = parser.ReadSectionTokens( + ClauseStart.Output, + ClauseStart.From, + ClauseStart.Where, + ClauseStart.Returning, + ClauseStart.StatementEnd); + return parser.BuildSegment("Set", tokens); + } +} diff --git a/Utils.Data/Sql/SqlStatementParts.cs b/Utils.Data/Sql/SqlStatementParts.cs new file mode 100644 index 0000000..a2b5dba --- /dev/null +++ b/Utils.Data/Sql/SqlStatementParts.cs @@ -0,0 +1,252 @@ +using System; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Represents a typed part of a SQL statement built from a parsed segment. +/// +public abstract class SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// Name of the part for identification. + /// The underlying segment represented by the part. + /// Thrown when or is null. + protected SqlStatementPart(string name, SqlSegment segment) + { + Name = name ?? throw new ArgumentNullException(nameof(name)); + Segment = segment ?? throw new ArgumentNullException(nameof(segment)); + } + + /// + /// Gets the display name of the part. + /// + public string Name { get; } + + /// + /// Gets the that stores the parsed tokens. + /// + public SqlSegment Segment { get; } + + /// + /// Builds the SQL text represented by the part. + /// + /// The SQL string rendered from the underlying segment. + public string ToSql() + { + return Segment.ToSql(); + } +} + +/// +/// Represents the SELECT clause of a SQL statement. +/// +public sealed class SelectPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the SELECT clause. + public SelectPart(SqlSegment segment) + : base("Select", segment) + { + } +} + +/// +/// Represents the FROM clause of a SQL statement. +/// +public sealed class FromPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the FROM clause. + public FromPart(SqlSegment segment) + : base("From", segment) + { + } +} + +/// +/// Represents the INTO clause of a SQL statement. +/// +public sealed class IntoPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the INTO clause. + public IntoPart(SqlSegment segment) + : base("Into", segment) + { + } +} + +/// +/// Represents the WHERE clause of a SQL statement. +/// +public sealed class WherePart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the WHERE clause. + public WherePart(SqlSegment segment) + : base("Where", segment) + { + } +} + +/// +/// Represents the GROUP BY clause of a SQL statement. +/// +public sealed class GroupByPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the GROUP BY clause. + public GroupByPart(SqlSegment segment) + : base("GroupBy", segment) + { + } +} + +/// +/// Represents the HAVING clause of a SQL statement. +/// +public sealed class HavingPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the HAVING clause. + public HavingPart(SqlSegment segment) + : base("Having", segment) + { + } +} + +/// +/// Represents the ORDER BY clause of a SQL statement. +/// +public sealed class OrderByPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the ORDER BY clause. + public OrderByPart(SqlSegment segment) + : base("OrderBy", segment) + { + } +} + +/// +/// Represents the LIMIT clause of a SQL statement. +/// +public sealed class LimitPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the LIMIT clause. + public LimitPart(SqlSegment segment) + : base("Limit", segment) + { + } +} + +/// +/// Represents the OFFSET clause of a SQL statement. +/// +public sealed class OffsetPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the OFFSET clause. + public OffsetPart(SqlSegment segment) + : base("Offset", segment) + { + } +} + +/// +/// Represents trailing set operator content such as UNION clauses. +/// +public sealed class TailPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the trailing set operator clause. + public TailPart(SqlSegment segment) + : base("Tail", segment) + { + } +} + +/// +/// Represents the VALUES clause of a SQL statement. +/// +public sealed class ValuesPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the VALUES clause. + public ValuesPart(SqlSegment segment) + : base("Values", segment) + { + } +} + +/// +/// Represents a DELETE clause referencing the target to remove rows from. +/// +public sealed class DeletePart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the DELETE clause. + public DeletePart(SqlSegment segment) + : base("Delete", segment) + { + } +} + +/// +/// Represents an UPDATE clause referencing the target to modify rows in. +/// +public sealed class UpdatePart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the UPDATE clause. + public UpdatePart(SqlSegment segment) + : base("Update", segment) + { + } +} + +/// +/// Represents an INSERT clause identifying the target of the operation. +/// +public sealed class InsertPart : SqlStatementPart +{ + /// + /// Initializes a new instance of the class. + /// + /// The segment describing the INSERT clause. + public InsertPart(SqlSegment segment) + : base("Insert", segment) + { + } +} diff --git a/Utils.Data/Sql/TableListReader.cs b/Utils.Data/Sql/TableListReader.cs new file mode 100644 index 0000000..3d525cc --- /dev/null +++ b/Utils.Data/Sql/TableListReader.cs @@ -0,0 +1,170 @@ +using System; +using System.Collections.Generic; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Reads comma- or join-separated table sources from a shared context. +/// +internal sealed class TableListReader +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The parser supplying token access. + public TableListReader(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Reads a sequence of table sources until a clause boundary is reached. + /// + /// Prefix used to name table segments. + /// Clause boundaries that stop the list. + /// The parsed table segments. + public IReadOnlyList ReadTables(string segmentNamePrefix, params ClauseStart[] clauseTerminators) + { + if (string.IsNullOrWhiteSpace(segmentNamePrefix)) + { + throw new ArgumentException("Segment name prefix cannot be null or whitespace.", nameof(segmentNamePrefix)); + } + + var results = new List(); + int index = 1; + while (true) + { + results.Add(ReadTable($"{segmentNamePrefix}{index}", clauseTerminators)); + index++; + + if (parser.IsAtEnd || parser.IsClauseStart(clauseTerminators)) + { + break; + } + + var next = parser.Peek(); + if (next.Text == ",") + { + parser.Read(); + continue; + } + + if (next.Text == ")") + { + break; + } + + throw new SqlParseException($"Unexpected token '{next.Text}' while reading table list."); + } + + return results; + } + + /// + /// Reads a single table source, including any associated JOIN or APPLY constructs. + /// + /// The segment name assigned to the parsed table. + /// Clause boundaries that stop the table. + /// The parsed table segment. + /// Thrown when no table tokens are found or JOIN clauses are incomplete. + private SqlSegment ReadTable(string segmentName, params ClauseStart[] clauseTerminators) + { + var tokens = new List(); + int depth = 0; + int joinCount = 0; + int onCount = 0; + while (!parser.IsAtEnd) + { + var current = parser.Peek(); + bool joinSatisfied = onCount >= joinCount; + + if (depth == 0 && joinSatisfied) + { + if (current.Text == "," || current.Text == ")") + { + break; + } + + if (clauseTerminators.Length > 0 && parser.IsClauseStart(clauseTerminators)) + { + break; + } + } + + tokens.Add(parser.Read()); + UpdateDepth(current, ref depth); + + if (depth == 0) + { + if (IsJoinRequiringOn(tokens, current)) + { + joinCount++; + } + else if (IsOnKeyword(current)) + { + onCount++; + } + } + } + + if (tokens.Count == 0) + { + throw new SqlParseException("Expected table but none was found."); + } + + if (onCount < joinCount) + { + throw new SqlParseException("Missing ON clause for one or more JOIN operations."); + } + + return parser.BuildSegment(segmentName, tokens); + } + + /// + /// Determines whether the provided token represents a JOIN that requires an ON clause. + /// + /// Tokens collected so far for the current table. + /// The token to evaluate. + /// true when the token is a JOIN keyword not preceded by CROSS. + private static bool IsJoinRequiringOn(IReadOnlyList tokens, SqlToken token) + { + if (!token.IsKeyword || !string.Equals(token.Normalized, "JOIN", StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var previousToken = tokens.Count >= 2 ? tokens[^2] : null; + return previousToken == null || !string.Equals(previousToken.Normalized, "CROSS", StringComparison.OrdinalIgnoreCase); + } + + /// + /// Determines whether the token represents an ON keyword at the current depth. + /// + /// The token to evaluate. + /// true when the token is ON. + private static bool IsOnKeyword(SqlToken token) + { + return token.IsKeyword && string.Equals(token.Normalized, "ON", StringComparison.OrdinalIgnoreCase); + } + + /// + /// Updates the tracked parenthesis depth for the provided token. + /// + /// The token to evaluate. + /// The tracked depth to update. + private static void UpdateDepth(SqlToken token, ref int depth) + { + if (token.Text == "(") + { + depth++; + } + else if (token.Text == ")" && depth > 0) + { + depth--; + } + } +} diff --git a/Utils.Data/Sql/UpdateStatementParser.cs b/Utils.Data/Sql/UpdateStatementParser.cs new file mode 100644 index 0000000..772cb57 --- /dev/null +++ b/Utils.Data/Sql/UpdateStatementParser.cs @@ -0,0 +1,70 @@ +using System; + +namespace Utils.Data.Sql; + +#nullable enable + +/// +/// Parses UPDATE statements using the shared context. +/// +internal sealed class UpdateStatementParser +{ + private readonly SqlParser parser; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying parser for token management. + public UpdateStatementParser(SqlParser parser) + { + this.parser = parser ?? throw new ArgumentNullException(nameof(parser)); + } + + /// + /// Parses an UPDATE statement. + /// + /// The optional WITH clause attached to the statement. + /// The parsed . + public SqlUpdateStatement Parse(WithClause? withClause) + { + parser.ExpectKeyword("UPDATE"); + var updateTargetReader = new UpdatePartReader(parser); + var outputReader = new OutputPartReader(parser); + var fromReader = new FromPartReader(parser); + var whereReader = new WherePartReader(parser); + var returningReader = new ReturningPartReader(parser); + + var targetSegment = updateTargetReader.ReadUpdateTarget(); + parser.ExpectKeyword("SET"); + var setSegment = new SetPartReader(parser).ReadSetPart(); + + SqlSegment? outputSegment = null; + SqlSegment? fromSegment = null; + SqlSegment? whereSegment = null; + SqlSegment? returningSegment = null; + + if (parser.TryConsumeKeyword("OUTPUT")) + { + outputSegment = outputReader.ReadOutputPart( + "Output", + fromReader.ClauseKeyword, + whereReader.ClauseKeyword, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + } + + fromSegment = fromReader.TryReadFromPart( + whereReader.ClauseKeyword, + returningReader.ClauseKeyword, + ClauseStart.StatementEnd); + + whereSegment = whereReader.TryReadWherePart(returningReader.ClauseKeyword, ClauseStart.StatementEnd); + + if (parser.TryConsumeKeyword("RETURNING")) + { + returningSegment = returningReader.ReadReturningPart(ClauseStart.StatementEnd); + } + + return new SqlUpdateStatement(targetSegment, setSegment, fromSegment, whereSegment, outputSegment, returningSegment, withClause); + } +} diff --git a/UtilsTest/Data/ExpressionReaderTests.cs b/UtilsTest/Data/ExpressionReaderTests.cs new file mode 100644 index 0000000..780f473 --- /dev/null +++ b/UtilsTest/Data/ExpressionReaderTests.cs @@ -0,0 +1,79 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Utils.Data.Sql; + +namespace UtilsTests.Data; + +/// +/// Tests for and . +/// +[TestClass] +public sealed class ExpressionReaderTests +{ + [TestMethod] + public void ExpressionReaderReadsExpressionWithExplicitAlias() + { + var parser = SqlParser.Create("amount * 1.2 AS gross, discount"); + var reader = new ExpressionReader(parser); + + var result = reader.ReadExpression("SelectExpr", true, ClauseStart.StatementEnd); + + Assert.AreEqual("amount * 1.2", result.Expression.ToSql()); + Assert.AreEqual("gross", result.Alias); + Assert.AreEqual(",", parser.Peek().Text); + } + + [TestMethod] + public void ExpressionListReaderStopsAtClauseAndHandlesImplicitAlias() + { + var parser = SqlParser.Create("orders.total, SUM(quantity) qty FROM sales"); + var listReader = new ExpressionListReader(parser); + + var expressions = listReader.ReadExpressions("SelectExpr", true, ClauseStart.From); + + Assert.AreEqual(2, expressions.Count); + Assert.AreEqual("orders.total", expressions[0].Expression.ToSql()); + Assert.IsNull(expressions[0].Alias); + Assert.AreEqual("SUM(quantity)", expressions[1].Expression.ToSql()); + Assert.AreEqual("qty", expressions[1].Alias); + Assert.AreEqual("FROM", parser.Peek().Normalized); + } + + [TestMethod] + public void ExpressionReaderHandlesIfFunctionWithAlias() + { + var parser = SqlParser.Create("IF(total > 0, total, 0) AS total_value FROM orders"); + var reader = new ExpressionReader(parser); + + var result = reader.ReadExpression("SelectExpr", true, ClauseStart.From); + + Assert.AreEqual("IF(total > 0, total, 0)", result.Expression.ToSql()); + Assert.AreEqual("total_value", result.Alias); + Assert.AreEqual("FROM", parser.Peek().Normalized); + } + + [TestMethod] + public void ExpressionReaderHandlesSearchedCaseExpression() + { + var parser = SqlParser.Create("CASE WHEN qty > 0 THEN price * qty WHEN qty = 0 THEN 0 ELSE NULL END revenue, tax"); + var reader = new ExpressionReader(parser); + + var result = reader.ReadExpression("SelectExpr", true, ClauseStart.StatementEnd); + + Assert.AreEqual("CASE WHEN qty > 0 THEN price * qty WHEN qty = 0 THEN 0 ELSE NULL END", result.Expression.ToSql()); + Assert.AreEqual("revenue", result.Alias); + Assert.AreEqual(",", parser.Peek().Text); + } + + [TestMethod] + public void ExpressionReaderHandlesSimpleCaseExpression() + { + var parser = SqlParser.Create("CASE status WHEN 'NEW' THEN 1 WHEN 'OLD' THEN 2 ELSE 0 END AS status_code FROM orders"); + var reader = new ExpressionReader(parser); + + var result = reader.ReadExpression("SelectExpr", true, ClauseStart.From); + + Assert.AreEqual("CASE status WHEN 'NEW' THEN 1 WHEN 'OLD' THEN 2 ELSE 0 END", result.Expression.ToSql()); + Assert.AreEqual("status_code", result.Alias); + Assert.AreEqual("FROM", parser.Peek().Normalized); + } +} diff --git a/UtilsTest/Data/FieldAndPredicateReaderTests.cs b/UtilsTest/Data/FieldAndPredicateReaderTests.cs new file mode 100644 index 0000000..390235d --- /dev/null +++ b/UtilsTest/Data/FieldAndPredicateReaderTests.cs @@ -0,0 +1,74 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Utils.Data.Sql; + +namespace UtilsTests.Data; + +/// +/// Tests for and . +/// +[TestClass] +public sealed class FieldAndPredicateReaderTests +{ + /// + /// Ensures the field list reader collects identifiers until a clause boundary is reached. + /// + [TestMethod] + public void FieldListReaderStopsAtClauseBoundary() + { + var parser = SqlParser.Create("id, name, created_at FROM accounts"); + var reader = new FieldListReader(parser); + + var fields = reader.ReadFields("Field", ClauseStart.From); + + Assert.AreEqual(3, fields.Count); + Assert.AreEqual("id", fields[0].ToSql()); + Assert.AreEqual("name", fields[1].ToSql()); + Assert.AreEqual("created_at", fields[2].ToSql()); + Assert.AreEqual("FROM", parser.Peek().Normalized); + } + + /// + /// Ensures predicate reading stops before the next clause and preserves logical operators. + /// + [TestMethod] + public void PredicateReaderReadsComplexPredicate() + { + var parser = SqlParser.Create("price > 100 AND (status = 'A' OR status = 'B') GROUP BY region"); + var reader = new PredicateReader(parser); + + var predicate = reader.ReadPredicate("Where", ClauseStart.GroupBy); + + Assert.AreEqual("price > 100 AND(status = 'A' OR status = 'B')", predicate.ToSql()); + Assert.AreEqual("GROUP", parser.Peek().Normalized); + } + + /// + /// Ensures predicates support IN value lists without consuming the following clause. + /// + [TestMethod] + public void PredicateReaderHandlesInValueList() + { + var parser = SqlParser.Create("country IN ('FR', 'US', 'DE') ORDER BY country"); + var reader = new PredicateReader(parser); + + var predicate = reader.ReadPredicate("Where", ClauseStart.OrderBy); + + Assert.AreEqual("country IN ('FR', 'US', 'DE')", predicate.ToSql()); + Assert.AreEqual("ORDER", parser.Peek().Normalized); + } + + /// + /// Ensures predicates support Oracle-style row value lists and IN subqueries. + /// + [TestMethod] + public void PredicateReaderHandlesRowValueListAndSubquery() + { + var parser = SqlParser.Create("(col1, col2) IN ((1, 2), (3, 4)) OR (col1, col2) IN (SELECT a, b FROM dual)"); + var reader = new PredicateReader(parser); + + var predicate = reader.ReadPredicate("Where", ClauseStart.StatementEnd); + + Assert.AreEqual("(col1, col2) IN ((1, 2), (3, 4)) OR(col1, col2) IN (SELECT a, b FROM dual)", predicate.ToSql()); + Assert.IsTrue(parser.IsAtEnd); + } +} diff --git a/UtilsTest/Data/SqlQueryAnalyzerTests.cs b/UtilsTest/Data/SqlQueryAnalyzerTests.cs index f99f3d7..e795aaa 100644 --- a/UtilsTest/Data/SqlQueryAnalyzerTests.cs +++ b/UtilsTest/Data/SqlQueryAnalyzerTests.cs @@ -82,6 +82,24 @@ public void ParseInsertWithOutputClause() Assert.AreEqual("INSERT INTO audit_log(user_id) OUTPUT inserted.id VALUES (@userId)", query.ToSql()); } + [TestMethod] + public void ParseInsertWithCteAsSource() + { + const string sql = @"WITH source_data AS (SELECT 1 AS id) +INSERT INTO destination(id) +SELECT id FROM source_data;"; + + SqlQuery query = SqlQueryAnalyzer.Parse(sql); + + var insert = (SqlInsertStatement)query.RootStatement; + Assert.IsNotNull(insert.WithClause); + Assert.AreEqual(1, insert.WithClause!.Definitions.Count); + Assert.AreEqual("source_data", insert.WithClause.Definitions[0].Name); + Assert.IsNull(insert.Values); + Assert.IsNotNull(insert.SourceQuery); + Assert.AreEqual("WITH source_data AS (SELECT 1 AS id) INSERT INTO destination(id) SELECT id FROM source_data", query.ToSql()); + } + [TestMethod] public void ParseUpdateWithSubquery() { diff --git a/UtilsTest/Data/SqlStatementPartReaderTests.cs b/UtilsTest/Data/SqlStatementPartReaderTests.cs new file mode 100644 index 0000000..0ec8970 --- /dev/null +++ b/UtilsTest/Data/SqlStatementPartReaderTests.cs @@ -0,0 +1,114 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using Utils.Data.Sql; + +namespace UtilsTests.Data; + +/// +/// Tests for SQL statement part readers such as and . +/// +[TestClass] +public sealed class SqlStatementPartReaderTests +{ + /// + /// Ensures the select part reader preserves aliases and stops before the FROM clause. + /// + [TestMethod] + public void SelectPartReaderStopsBeforeFromClause() + { + var parser = SqlParser.Create("amount AS total, tax t FROM sales"); + var reader = new SelectPartReader(parser); + var fromReader = new FromPartReader(parser); + + var selectSegment = reader.ReadSelectPart(fromReader.ClauseKeyword); + + Assert.AreEqual("amount AS total, tax t", selectSegment.ToSql()); + Assert.AreEqual("FROM", parser.Peek().Normalized); + } + + /// + /// Ensures the from part reader gathers comma-separated tables and respects where boundaries. + /// + [TestMethod] + public void FromPartReaderReadsTables() + { + var parser = SqlParser.Create("FROM accounts a, orders o WHERE o.account_id = a.id"); + var reader = new FromPartReader(parser); + + var whereReader = new WherePartReader(parser); + var fromSegment = reader.TryReadFromPart(whereReader.ClauseKeyword); + + Assert.IsNotNull(fromSegment); + Assert.AreEqual("accounts a, orders o", fromSegment!.ToSql()); + Assert.AreEqual("WHERE", parser.Peek().Normalized); + } + + /// + /// Ensures the where part reader reads predicates up to set operators or statement boundaries. + /// + [TestMethod] + public void WherePartReaderStopsAtClause() + { + var parser = SqlParser.Create("WHERE price > 100 UNION SELECT 1"); + var reader = new WherePartReader(parser); + + var whereSegment = reader.TryReadWherePart(); + + Assert.IsNotNull(whereSegment); + Assert.AreEqual("price > 100", whereSegment!.ToSql()); + Assert.AreEqual("UNION", parser.Peek().Normalized); + } + + /// + /// Ensures clause keywords are exposed for coordinating terminators between part readers. + /// + [TestMethod] + public void ClauseKeywordsAreExposed() + { + var parser = SqlParser.Create("LIMIT 5"); + var limitReader = new LimitPartReader(parser); + var offsetReader = new OffsetPartReader(parser); + + Assert.AreEqual(ClauseStart.Limit, limitReader.ClauseKeyword); + Assert.AreEqual(ClauseStart.Offset, offsetReader.ClauseKeyword); + } + + /// + /// Ensures clause keyword metadata is available through the registry for clause start detection. + /// + [TestMethod] + public void ClauseKeywordRegistryExposesPartReaderKeywords() + { + var orderByKeywords = ClauseStartKeywordRegistry.KnownClauseKeywords[ClauseStart.OrderBy]; + + Assert.AreEqual(1, orderByKeywords.Count); + CollectionAssert.AreEqual(new[] { "ORDER", "BY" }, orderByKeywords.Single().ToArray()); + } + + /// + /// Ensures part references are created using the metadata exposed by part readers instead of hardcoded switches. + /// + [TestMethod] + public void PartReferencesUseReaderMetadata() + { + var select = new SqlSelectStatement( + SqlSegment.CreateEmpty("Select", SqlSyntaxOptions.Default), + null, + null, + null, + null, + null, + null, + null, + null, + null, + false); + + var whereSegment = select.EnsureWhereSegment(); + + Assert.IsNotNull(whereSegment); + Assert.IsNotNull(select.WherePart); + Assert.AreSame(whereSegment, select.WherePart!.Segment); + Assert.AreEqual(WherePartReader.PartName, select.WherePart.Name); + } +} diff --git a/UtilsTest/Data/SqlStatementPartTests.cs b/UtilsTest/Data/SqlStatementPartTests.cs new file mode 100644 index 0000000..26ad77b --- /dev/null +++ b/UtilsTest/Data/SqlStatementPartTests.cs @@ -0,0 +1,91 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Utils.Data.Sql; + +namespace UtilsTests.Data; + +/// +/// Tests for typed SQL statement parts exposed by parsed statements. +/// +[TestClass] +public sealed class SqlStatementPartTests +{ + /// + /// Ensures SELECT statements expose typed parts aligned with their segments. + /// + [TestMethod] + public void SelectStatementExposesTypedParts() + { + var parser = SqlParser.Create("SELECT id, name FROM accounts WHERE active = 1"); + + var statement = (SqlSelectStatement)parser.ParseStatementWithOptionalCte(); + parser.ConsumeOptionalTerminator(); + parser.EnsureEndOfInput(); + + Assert.AreSame(statement.Select, statement.SelectPart.Segment); + Assert.IsNotNull(statement.FromPart); + Assert.AreSame(statement.From, statement.FromPart!.Segment); + Assert.IsNotNull(statement.WherePart); + Assert.AreEqual(statement.Where?.ToSql(), statement.WherePart!.ToSql()); + } + + /// + /// Ensures INSERT statements expose typed parts for the target and values clauses. + /// + [TestMethod] + public void InsertStatementExposesTypedParts() + { + var parser = SqlParser.Create("INSERT INTO accounts(id) VALUES (1)"); + + var statement = (SqlInsertStatement)parser.ParseStatementWithOptionalCte(); + parser.ConsumeOptionalTerminator(); + parser.EnsureEndOfInput(); + + Assert.AreSame(statement.Target, statement.InsertPart.Segment); + Assert.AreSame(statement.Target, statement.IntoPart.Segment); + Assert.IsNotNull(statement.ValuesPart); + Assert.AreSame(statement.Values, statement.ValuesPart!.Segment); + } + + /// + /// Ensures UPDATE statements create typed parts when optional clauses are added. + /// + [TestMethod] + public void UpdateStatementCreatesTypedPartsWhenEnsured() + { + var parser = SqlParser.Create("UPDATE accounts SET name = 'x'"); + + var statement = (SqlUpdateStatement)parser.ParseStatementWithOptionalCte(); + parser.ConsumeOptionalTerminator(); + parser.EnsureEndOfInput(); + + Assert.AreSame(statement.Target, statement.UpdatePart.Segment); + Assert.IsNull(statement.FromPart); + + var fromSegment = statement.EnsureFromSegment(); + + Assert.IsNotNull(statement.FromPart); + Assert.AreSame(fromSegment, statement.FromPart!.Segment); + } + + /// + /// Ensures DELETE statements expose typed parts and create them when optional sections are added. + /// + [TestMethod] + public void DeleteStatementCreatesTypedParts() + { + var parser = SqlParser.Create("DELETE FROM accounts"); + + var statement = (SqlDeleteStatement)parser.ParseStatementWithOptionalCte(); + parser.ConsumeOptionalTerminator(); + parser.EnsureEndOfInput(); + + Assert.IsNull(statement.DeletePart); + Assert.IsNotNull(statement.FromPart); + Assert.AreSame(statement.From, statement.FromPart.Segment); + + var targetSegment = statement.EnsureTargetSegment(); + + Assert.IsNotNull(statement.DeletePart); + Assert.AreSame(targetSegment, statement.DeletePart!.Segment); + } +} diff --git a/UtilsTest/Data/TableListReaderTests.cs b/UtilsTest/Data/TableListReaderTests.cs new file mode 100644 index 0000000..5031cba --- /dev/null +++ b/UtilsTest/Data/TableListReaderTests.cs @@ -0,0 +1,72 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Utils.Data.Sql; + +namespace UtilsTests.Data; + +/// +/// Tests for . +/// +[TestClass] +public sealed class TableListReaderTests +{ + /// + /// Ensures comma-separated tables stop at the next clause boundary. + /// + [TestMethod] + public void TableListReaderReadsTablesUntilClause() + { + var parser = SqlParser.Create("users u, orders o WHERE status = 'A'"); + var reader = new TableListReader(parser); + + var tables = reader.ReadTables("Table", ClauseStart.Where); + + Assert.AreEqual(2, tables.Count); + Assert.AreEqual("users u", tables[0].ToSql()); + Assert.AreEqual("orders o", tables[1].ToSql()); + Assert.AreEqual("WHERE", parser.Peek().Normalized); + } + + /// + /// Ensures join chains remain intact until each JOIN has an ON clause. + /// + [TestMethod] + public void TableListReaderKeepsJoinedTablesTogether() + { + var parser = SqlParser.Create("customers c INNER JOIN orders o ON c.id = o.customer_id LEFT OUTER JOIN items i ON o.item_id = i.id GROUP BY c.id"); + var reader = new TableListReader(parser); + + var tables = reader.ReadTables("Table", ClauseStart.GroupBy); + + Assert.AreEqual(1, tables.Count); + Assert.AreEqual("customers c INNER JOIN orders o ON c.id = o.customer_id LEFT OUTER JOIN items i ON o.item_id = i.id", tables[0].ToSql()); + Assert.AreEqual("GROUP", parser.Peek().Normalized); + } + + /// + /// Ensures cross-apply subqueries are included in the table source without requiring ON clauses. + /// + [TestMethod] + public void TableListReaderHandlesCrossApplyWithSubquery() + { + var parser = SqlParser.Create("accounts a CROSS APPLY (SELECT * FROM transactions t WHERE t.account_id = a.id) tx WHERE a.active = 1"); + var reader = new TableListReader(parser); + + var tables = reader.ReadTables("Table", ClauseStart.Where); + + Assert.AreEqual(1, tables.Count); + Assert.AreEqual("accounts a CROSS APPLY(SELECT * FROM transactions t WHERE t.account_id = a.id) tx", tables[0].ToSql()); + Assert.AreEqual("WHERE", parser.Peek().Normalized); + } + + /// + /// Ensures missing ON clauses for JOIN operations trigger a parsing exception. + /// + [TestMethod] + public void TableListReaderThrowsWhenJoinIsIncomplete() + { + var parser = SqlParser.Create("users u INNER JOIN orders o WHERE 1 = 1"); + var reader = new TableListReader(parser); + + Assert.ThrowsException(() => reader.ReadTables("Table", ClauseStart.Where)); + } +} From a110287452da6459d6243adee06090225d85bfe2 Mon Sep 17 00:00:00 2001 From: warny Date: Sun, 4 Jan 2026 18:28:51 +0100 Subject: [PATCH 2/2] Run CI on all branches --- .github/workflows/dotnetcore.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/dotnetcore.yml b/.github/workflows/dotnetcore.yml index fcb61a6..dad43a2 100644 --- a/.github/workflows/dotnetcore.yml +++ b/.github/workflows/dotnetcore.yml @@ -3,8 +3,10 @@ name: Utils on: push: branches: - - master - - release + - '**' + pull_request: + branches: + - '**' jobs: build: