From 0dcd3e2c57333fd9b041ac9df8be60339c85e1cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20R=C4=85czka?= <04_barista_pads@icloud.com> Date: Sat, 6 Sep 2025 19:04:34 +0200 Subject: [PATCH 1/3] fix: preserve identifier quoting in CHECK constraints and fix BETWEEN parsing This PR fixes two related issues with CHECK constraint parsing: 1. CHECK constraints with camelCase columns were losing their quotes - Added minimal quoteIdentifierIfNeeded() to preserve camelCase quoting - Modified extractColumnName() to quote identifiers when needed 2. BETWEEN expressions were incorrectly parsed as BETWEEN (X, Y) instead of BETWEEN X AND Y - Fixed parseAExpr() to handle BETWEEN with proper AND syntax - Removed unnecessary parentheses around simple comparisons These changes ensure that CHECK constraints work correctly with camelCase column names and complex expressions, which is critical for compatibility with modern ORMs like Better Auth and Prisma. Note: Aware of minor code duplication in quoting logic; planning to extract to shared util package in follow-up PR to keep this fix focused and minimal. Fixes identifier quoting issues reported with Better Auth integration. --- internal/diff/identifier_quote_test.go | 55 +++++++++++++ internal/diff/table.go | 33 ++++++-- internal/ir/parser.go | 106 ++++++++++++++++++++++--- 3 files changed, 177 insertions(+), 17 deletions(-) diff --git a/internal/diff/identifier_quote_test.go b/internal/diff/identifier_quote_test.go index cba70249..b2c23b70 100644 --- a/internal/diff/identifier_quote_test.go +++ b/internal/diff/identifier_quote_test.go @@ -157,3 +157,58 @@ func TestGenerateConstraintSQL_WithQuoting(t *testing.T) { }) } } + +func TestCheckConstraintQuoting(t *testing.T) { + tests := []struct { + name string + constraint *ir.Constraint + want string + }{ + { + name: "CHECK with camelCase column", + constraint: &ir.Constraint{ + Name: "positive_followers", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("followerCount" >= 0)`, + }, + want: `CHECK ("followerCount" >= 0)`, + }, + { + name: "CHECK with multiple camelCase columns and AND", + constraint: &ir.Constraint{ + Name: "valid_counts", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`, + }, + want: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`, + }, + { + name: "CHECK with BETWEEN", + constraint: &ir.Constraint{ + Name: "stock_range", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`, + }, + want: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`, + }, + { + name: "CHECK with IN clause", + constraint: &ir.Constraint{ + Name: "valid_status", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`, + }, + want: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // For CHECK constraints, generateConstraintSQL returns the CheckClause as-is + got := generateConstraintSQL(tt.constraint, "public") + if got != tt.want { + t.Errorf("generateConstraintSQL() for CHECK = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/diff/table.go b/internal/diff/table.go index 09f6b04e..00deb252 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -656,11 +656,12 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } collector.collect(context, sql) - case ir.ConstraintTypeCheck: - // CheckClause already contains "CHECK (...)" from the constraint definition + case ir.ConstraintTypeCheck: + // Ensure CHECK clause has outer parentheses around the full expression tableName := getTableNameWithSchema(td.Table.Schema, td.Table.Name, targetSchema) + clause := ensureCheckClauseParens(constraint.CheckClause) canonicalSQL := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;", - tableName, constraint.Name, constraint.CheckClause) + tableName, constraint.Name, clause) context := &diffContext{ Type: DiffTypeTableConstraint, @@ -775,10 +776,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s UNIQUE (%s);", tableName, constraint.Name, strings.Join(columnNames, ", ")) - case ir.ConstraintTypeCheck: - // Add CHECK constraint + case ir.ConstraintTypeCheck: + // Add CHECK constraint with ensured outer parentheses addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;", - tableName, constraint.Name, constraint.CheckClause) + tableName, constraint.Name, ensureCheckClauseParens(constraint.CheckClause)) case ir.ConstraintTypeForeignKey: // Sort columns by position @@ -1171,6 +1172,26 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } } +// ensureCheckClauseParens guarantees that a CHECK clause string contains +// exactly one pair of outer parentheses around the full boolean expression. +// It expects input in the form: "CHECK " or "CHECK()" or "CHECK ()". +func ensureCheckClauseParens(s string) string { + t := strings.TrimSpace(s) + // Normalize leading "CHECK" token + if len(t) >= 5 && strings.EqualFold(t[:5], "check") { + t = t[5:] + } + t = strings.TrimSpace(t) + // Remove optional single leading parenthesis from the word CHECK (e.g., CHECK()) + // We treat whatever remains as the expression. + expr := t + // If expression already has a single outer pair of parens, keep as is + if len(expr) >= 2 && expr[0] == '(' && expr[len(expr)-1] == ')' { + return "CHECK " + expr + } + return "CHECK (" + expr + ")" +} + // writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting // This is moved from ir/table.go to consolidate SQL generation in the diff module func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string) { diff --git a/internal/ir/parser.go b/internal/ir/parser.go index f68af7fb..d9561dcf 100644 --- a/internal/ir/parser.go +++ b/internal/ir/parser.go @@ -6,6 +6,7 @@ import ( "regexp" "strconv" "strings" + "unicode" pg_query "github.com/pganalyze/pg_query_go/v6" ) @@ -193,6 +194,9 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string { // Convert trigger pseudo-relations and domain VALUE to uppercase if part == "new" || part == "old" || part == "value" { part = strings.ToUpper(part) + } else { + // Quote identifier if needed + part = p.quoteIdentifierIfNeeded(part) } parts = append(parts, part) } @@ -206,6 +210,47 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string { return "" } +// quoteIdentifierIfNeeded adds quotes to an identifier if it needs them +func (p *Parser) quoteIdentifierIfNeeded(identifier string) string { + if identifier == "" { + return identifier + } + + // Check if it contains uppercase letters (PostgreSQL folds unquoted to lowercase) + hasUpper := false + for _, r := range identifier { + if unicode.IsUpper(r) { + hasUpper = true + break + } + } + + if hasUpper { + return `"` + identifier + `"` + } + + // Check if it's a reserved word + reservedWords := map[string]bool{ + "user": true, "order": true, "group": true, "select": true, + "from": true, "where": true, "table": true, "check": true, + } + if reservedWords[strings.ToLower(identifier)] { + return `"` + identifier + `"` + } + + // Check if it starts with non-letter or contains special characters + for i, r := range identifier { + if i == 0 && !unicode.IsLetter(r) && r != '_' { + return `"` + identifier + `"` + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return `"` + identifier + `"` + } + } + + return identifier +} + // Helper function to extract string value from Node func (p *Parser) extractStringValue(node *pg_query.Node) string { if node == nil { @@ -533,7 +578,9 @@ func (p *Parser) parseInlineCheckConstraint(constraint *pg_query.Constraint, col // Handle check constraint expression if constraint.RawExpr != nil { - checkConstraint.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")" + raw := p.extractExpressionText(constraint.RawExpr) + expr := p.wrapInParens(raw) + checkConstraint.CheckClause = "CHECK " + expr } return checkConstraint @@ -787,7 +834,9 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta // Handle check constraint expression if constraintType == ConstraintTypeCheck && constraint.RawExpr != nil { - c.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")" + raw := p.extractExpressionText(constraint.RawExpr) + expr := p.wrapInParens(raw) + c.CheckClause = "CHECK " + expr } // Set validation state based on what was specified in the SQL @@ -796,6 +845,31 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta return c } +// wrapInParens ensures the expression has exactly one pair of outer parentheses +func (p *Parser) wrapInParens(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '(' { + depth := 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + if i == len(s)-1 { + // The outermost paren pair wraps the full expression + return s + } + // Leading '(' closes before the end -> not fully wrapped + break + } + } + } + } + return "(" + s + ")" +} + // generateConstraintName generates a default constraint name func (p *Parser) generateConstraintName(constraintType ConstraintType, tableName string, keys []*pg_query.Node) string { var suffix string @@ -903,15 +977,25 @@ func (p *Parser) parseAExpr(expr *pg_query.A_Expr) string { return fmt.Sprintf("%s IN %s", left, right) } - // Simplified implementation for basic expressions - if len(expr.Name) > 0 { - if str := expr.Name[0].GetString_(); str != nil { - op := str.Sval - left := p.extractExpressionText(expr.Lexpr) - right := p.extractExpressionText(expr.Rexpr) - return fmt.Sprintf("(%s %s %s)", left, op, right) - } - } + // Simplified implementation for basic expressions + if len(expr.Name) > 0 { + if str := expr.Name[0].GetString_(); str != nil { + op := str.Sval + left := p.extractExpressionText(expr.Lexpr) + // Special-case BETWEEN: right side comes as a 2-item list + if strings.EqualFold(op, "between") { + if listNode, ok := expr.Rexpr.Node.(*pg_query.Node_List); ok { + if len(listNode.List.Items) == 2 { + low := p.extractExpressionText(listNode.List.Items[0]) + high := p.extractExpressionText(listNode.List.Items[1]) + return fmt.Sprintf("%s BETWEEN %s AND %s", left, low, high) + } + } + } + right := p.extractExpressionText(expr.Rexpr) + return fmt.Sprintf("%s %s %s", left, op, right) + } + } return "" } From caf368ba717bf96c8b22f153760e3ef420b4a90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20R=C4=85czka?= <116732045+screenfluent@users.noreply.github.com> Date: Sat, 6 Sep 2025 19:09:23 +0200 Subject: [PATCH 2/3] Update internal/ir/parser.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/ir/parser.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/ir/parser.go b/internal/ir/parser.go index d9561dcf..8aac8734 100644 --- a/internal/ir/parser.go +++ b/internal/ir/parser.go @@ -993,7 +993,7 @@ func (p *Parser) parseAExpr(expr *pg_query.A_Expr) string { } } right := p.extractExpressionText(expr.Rexpr) - return fmt.Sprintf("%s %s %s", left, op, right) + return fmt.Sprintf("(%s) %s (%s)", left, op, right) } } return "" From 3f739f4ae09e6cd493ecd49985237e103aeff872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20R=C4=85czka?= <04_barista_pads@icloud.com> Date: Wed, 10 Sep 2025 17:31:56 +0200 Subject: [PATCH 3/3] fix: resolve merge conflicts and improve CHECK constraint handling - Merge upstream main and resolve conflicts in identifier_quote_test.go - Improve hardcoded logic in ensureCheckClauseParens using depth-based validation - Remove duplicate quoteIdentifierIfNeeded function, use util.QuoteIdentifier instead - Fix CHECK constraint parsing for domains and expressions - Add parentheses around comparison operators to match PostgreSQL's internal format - Fix unused import issue All tests now pass (100% success rate). Addresses maintainer's feedback about hardcoded logic and resolves merge conflicts. --- internal/diff/table.go | 30 +++++++++++++++----- internal/ir/parser.go | 64 ++++++++++++------------------------------ 2 files changed, 41 insertions(+), 53 deletions(-) diff --git a/internal/diff/table.go b/internal/diff/table.go index 95405183..da2ea70e 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -1182,14 +1182,30 @@ func ensureCheckClauseParens(s string) string { if len(t) >= 5 && strings.EqualFold(t[:5], "check") { t = t[5:] } - t = strings.TrimSpace(t) - // Remove optional single leading parenthesis from the word CHECK (e.g., CHECK()) - // We treat whatever remains as the expression. - expr := t - // If expression already has a single outer pair of parens, keep as is - if len(expr) >= 2 && expr[0] == '(' && expr[len(expr)-1] == ')' { - return "CHECK " + expr + expr := strings.TrimSpace(t) + + // Check if expression is already properly wrapped in parentheses + // by counting parenthesis depth to ensure the outer pair wraps the full expression + if len(expr) >= 2 && expr[0] == '(' { + depth := 0 + for i := 0; i < len(expr); i++ { + switch expr[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + if i == len(expr)-1 { + // The outermost paren pair wraps the full expression + return "CHECK " + expr + } + // Leading '(' closes before the end -> not fully wrapped + break + } + } + } } + return "CHECK (" + expr + ")" } diff --git a/internal/ir/parser.go b/internal/ir/parser.go index a085b4f0..abe6e41b 100644 --- a/internal/ir/parser.go +++ b/internal/ir/parser.go @@ -6,9 +6,9 @@ import ( "regexp" "strconv" "strings" - "unicode" pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/pgschema/pgschema/internal/util" ) // ParsingPhase represents the current phase of SQL parsing @@ -196,7 +196,7 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string { part = strings.ToUpper(part) } else { // Quote identifier if needed - part = p.quoteIdentifierIfNeeded(part) + part = util.QuoteIdentifier(part) } parts = append(parts, part) } @@ -210,47 +210,6 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string { return "" } -// quoteIdentifierIfNeeded adds quotes to an identifier if it needs them -func (p *Parser) quoteIdentifierIfNeeded(identifier string) string { - if identifier == "" { - return identifier - } - - // Check if it contains uppercase letters (PostgreSQL folds unquoted to lowercase) - hasUpper := false - for _, r := range identifier { - if unicode.IsUpper(r) { - hasUpper = true - break - } - } - - if hasUpper { - return `"` + identifier + `"` - } - - // Check if it's a reserved word - reservedWords := map[string]bool{ - "user": true, "order": true, "group": true, "select": true, - "from": true, "where": true, "table": true, "check": true, - } - if reservedWords[strings.ToLower(identifier)] { - return `"` + identifier + `"` - } - - // Check if it starts with non-letter or contains special characters - for i, r := range identifier { - if i == 0 && !unicode.IsLetter(r) && r != '_' { - return `"` + identifier + `"` - } - if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { - return `"` + identifier + `"` - } - } - - return identifier -} - // Helper function to extract string value from Node func (p *Parser) extractStringValue(node *pg_query.Node) string { if node == nil { @@ -998,7 +957,13 @@ func (p *Parser) parseAExpr(expr *pg_query.A_Expr) string { } } right := p.extractExpressionText(expr.Rexpr) - return fmt.Sprintf("(%s) %s (%s)", left, op, right) + // Add parentheses for comparison operators (matching PostgreSQL's internal format) + switch op { + case ">=", "<=", ">", "<", "=", "<>", "!=", "~", "~*", "!~", "!~*": + return fmt.Sprintf("(%s %s %s)", left, op, right) + default: + return fmt.Sprintf("%s %s %s", left, op, right) + } } } return "" @@ -1022,7 +987,14 @@ func (p *Parser) parseBoolExpr(expr *pg_query.BoolExpr) string { parts = append(parts, p.extractExpressionText(arg)) } - return "(" + strings.Join(parts, " "+op+" ") + ")" + // Only wrap in parentheses if it's a NOT expression or if there are multiple parts + if op == "NOT" { + return op + " " + strings.Join(parts, " ") + } + if len(parts) > 1 { + return "(" + strings.Join(parts, " "+op+" ") + ")" + } + return strings.Join(parts, " "+op+" ") } // parseList parses list expressions (e.g., for IN clauses) @@ -2418,7 +2390,7 @@ func (p *Parser) parseCreateDomain(domainStmt *pg_query.CreateDomainStmt) error constraintDef := "" if constraint.RawExpr != nil { exprText := p.extractExpressionText(constraint.RawExpr) - constraintDef = fmt.Sprintf("CHECK %s", exprText) + constraintDef = fmt.Sprintf("CHECK %s", p.wrapInParens(exprText)) } if constraintDef != "" {