diff --git a/internal/diff/identifier_quote_test.go b/internal/diff/identifier_quote_test.go index c003b2e4..04315138 100644 --- a/internal/diff/identifier_quote_test.go +++ b/internal/diff/identifier_quote_test.go @@ -7,88 +7,6 @@ import ( "github.com/pgschema/pgschema/internal/util" ) -func TestNeedsQuoting(t *testing.T) { - tests := []struct { - name string - identifier string - want bool - }{ - {"empty string", "", false}, - {"simple lowercase", "tablename", false}, - {"with underscore", "table_name", false}, - {"reserved word user", "user", true}, - {"reserved word USER", "USER", true}, - {"reserved word Order", "Order", true}, - {"camelCase", "userId", true}, - {"PascalCase", "CreatedAt", true}, - {"starts with number", "1table", true}, - {"contains special char", "table-name", true}, - {"all lowercase", "createdat", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := util.NeedsQuoting(tt.identifier); got != tt.want { - t.Errorf("util.NeedsQuoting(%q) = %v, want %v", tt.identifier, got, tt.want) - } - }) - } -} - -func TestQuoteIdentifier(t *testing.T) { - tests := []struct { - name string - identifier string - want string - }{ - {"simple lowercase", "tablename", "tablename"}, - {"reserved word", "user", `"user"`}, - {"camelCase", "userId", `"userId"`}, - {"already quoted", `"userId"`, `"userId"`}, // Should not double-quote - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Handle already quoted identifiers - if len(tt.identifier) > 2 && tt.identifier[0] == '"' && tt.identifier[len(tt.identifier)-1] == '"' { - // If already quoted, should return as-is - if got := tt.identifier; got != tt.want { - t.Errorf("already quoted identifier %q should remain %q, got %q", tt.identifier, tt.want, got) - } - } else { - if got := util.QuoteIdentifier(tt.identifier); got != tt.want { - t.Errorf("util.QuoteIdentifier(%q) = %q, want %q", tt.identifier, got, tt.want) - } - } - }) - } -} - -func TestQualifyEntityNameWithQuotes(t *testing.T) { - tests := []struct { - name string - entitySchema string - entityName string - targetSchema string - want string - }{ - {"same schema lowercase", "public", "users", "public", "users"}, - {"same schema camelCase", "public", "userId", "public", `"userId"`}, - {"different schema", "auth", "users", "public", "auth.users"}, - {"different schema camelCase", "auth", "userId", "public", `auth."userId"`}, - {"reserved word", "public", "user", "public", `"user"`}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := util.QualifyEntityNameWithQuotes(tt.entitySchema, tt.entityName, tt.targetSchema); got != tt.want { - t.Errorf("util.QualifyEntityNameWithQuotes(%q, %q, %q) = %q, want %q", - tt.entitySchema, tt.entityName, tt.targetSchema, got, tt.want) - } - }) - } -} - func TestGenerateConstraintSQL_WithQuoting(t *testing.T) { tests := []struct { name string @@ -159,6 +77,61 @@ func TestGenerateConstraintSQL_WithQuoting(t *testing.T) { } } +func TestCheckConstraintQuoting(t *testing.T) { + tests := []struct { + name string + constraint *ir.Constraint + want string + }{ + { + name: "CHECK with camelCase column", + constraint: &ir.Constraint{ + Name: "positive_followers", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("followerCount" >= 0)`, + }, + want: `CHECK ("followerCount" >= 0)`, + }, + { + name: "CHECK with multiple camelCase columns and AND", + constraint: &ir.Constraint{ + Name: "valid_counts", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`, + }, + want: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`, + }, + { + name: "CHECK with BETWEEN", + constraint: &ir.Constraint{ + Name: "stock_range", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`, + }, + want: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`, + }, + { + name: "CHECK with IN clause", + constraint: &ir.Constraint{ + Name: "valid_status", + Type: ir.ConstraintTypeCheck, + CheckClause: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`, + }, + want: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // For CHECK constraints, generateConstraintSQL returns the CheckClause as-is + got := generateConstraintSQL(tt.constraint, "public") + if got != tt.want { + t.Errorf("generateConstraintSQL() for CHECK = %q, want %q", got, tt.want) + } + }) + } +} + func TestAddColumnIdentifierQuoting(t *testing.T) { tests := []struct { name string @@ -184,4 +157,4 @@ func TestAddColumnIdentifierQuoting(t *testing.T) { } }) } -} +} \ No newline at end of file diff --git a/internal/diff/table.go b/internal/diff/table.go index cf7b2b76..da2ea70e 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -657,11 +657,12 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } collector.collect(context, sql) - case ir.ConstraintTypeCheck: - // CheckClause already contains "CHECK (...)" from the constraint definition + case ir.ConstraintTypeCheck: + // Ensure CHECK clause has outer parentheses around the full expression tableName := getTableNameWithSchema(td.Table.Schema, td.Table.Name, targetSchema) + clause := ensureCheckClauseParens(constraint.CheckClause) canonicalSQL := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;", - tableName, constraint.Name, constraint.CheckClause) + tableName, constraint.Name, clause) context := &diffContext{ Type: DiffTypeTableConstraint, @@ -776,10 +777,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s UNIQUE (%s);", tableName, constraint.Name, strings.Join(columnNames, ", ")) - case ir.ConstraintTypeCheck: - // Add CHECK constraint + case ir.ConstraintTypeCheck: + // Add CHECK constraint with ensured outer parentheses addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;", - tableName, constraint.Name, constraint.CheckClause) + tableName, constraint.Name, ensureCheckClauseParens(constraint.CheckClause)) case ir.ConstraintTypeForeignKey: // Sort columns by position @@ -1172,6 +1173,42 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } } +// ensureCheckClauseParens guarantees that a CHECK clause string contains +// exactly one pair of outer parentheses around the full boolean expression. +// It expects input in the form: "CHECK " or "CHECK()" or "CHECK ()". +func ensureCheckClauseParens(s string) string { + t := strings.TrimSpace(s) + // Normalize leading "CHECK" token + if len(t) >= 5 && strings.EqualFold(t[:5], "check") { + t = t[5:] + } + expr := strings.TrimSpace(t) + + // Check if expression is already properly wrapped in parentheses + // by counting parenthesis depth to ensure the outer pair wraps the full expression + if len(expr) >= 2 && expr[0] == '(' { + depth := 0 + for i := 0; i < len(expr); i++ { + switch expr[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + if i == len(expr)-1 { + // The outermost paren pair wraps the full expression + return "CHECK " + expr + } + // Leading '(' closes before the end -> not fully wrapped + break + } + } + } + } + + return "CHECK (" + expr + ")" +} + // writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting // This is moved from ir/table.go to consolidate SQL generation in the diff module func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string) { diff --git a/internal/ir/parser.go b/internal/ir/parser.go index 5cc14803..abe6e41b 100644 --- a/internal/ir/parser.go +++ b/internal/ir/parser.go @@ -8,6 +8,7 @@ import ( "strings" pg_query "github.com/pganalyze/pg_query_go/v6" + "github.com/pgschema/pgschema/internal/util" ) // ParsingPhase represents the current phase of SQL parsing @@ -193,6 +194,9 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string { // Convert trigger pseudo-relations and domain VALUE to uppercase if part == "new" || part == "old" || part == "value" { part = strings.ToUpper(part) + } else { + // Quote identifier if needed + part = util.QuoteIdentifier(part) } parts = append(parts, part) } @@ -533,7 +537,9 @@ func (p *Parser) parseInlineCheckConstraint(constraint *pg_query.Constraint, col // Handle check constraint expression if constraint.RawExpr != nil { - checkConstraint.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")" + raw := p.extractExpressionText(constraint.RawExpr) + expr := p.wrapInParens(raw) + checkConstraint.CheckClause = "CHECK " + expr } return checkConstraint @@ -792,7 +798,9 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta // Handle check constraint expression if constraintType == ConstraintTypeCheck && constraint.RawExpr != nil { - c.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")" + raw := p.extractExpressionText(constraint.RawExpr) + expr := p.wrapInParens(raw) + c.CheckClause = "CHECK " + expr } // Set validation state based on what was specified in the SQL @@ -801,6 +809,31 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta return c } +// wrapInParens ensures the expression has exactly one pair of outer parentheses +func (p *Parser) wrapInParens(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '(' { + depth := 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + if i == len(s)-1 { + // The outermost paren pair wraps the full expression + return s + } + // Leading '(' closes before the end -> not fully wrapped + break + } + } + } + } + return "(" + s + ")" +} + // generateConstraintName generates a default constraint name func (p *Parser) generateConstraintName(constraintType ConstraintType, tableName string, keys []*pg_query.Node) string { var suffix string @@ -908,15 +941,31 @@ func (p *Parser) parseAExpr(expr *pg_query.A_Expr) string { return fmt.Sprintf("%s IN %s", left, right) } - // Simplified implementation for basic expressions - if len(expr.Name) > 0 { - if str := expr.Name[0].GetString_(); str != nil { - op := str.Sval - left := p.extractExpressionText(expr.Lexpr) - right := p.extractExpressionText(expr.Rexpr) - return fmt.Sprintf("(%s %s %s)", left, op, right) - } - } + // Simplified implementation for basic expressions + if len(expr.Name) > 0 { + if str := expr.Name[0].GetString_(); str != nil { + op := str.Sval + left := p.extractExpressionText(expr.Lexpr) + // Special-case BETWEEN: right side comes as a 2-item list + if strings.EqualFold(op, "between") { + if listNode, ok := expr.Rexpr.Node.(*pg_query.Node_List); ok { + if len(listNode.List.Items) == 2 { + low := p.extractExpressionText(listNode.List.Items[0]) + high := p.extractExpressionText(listNode.List.Items[1]) + return fmt.Sprintf("%s BETWEEN %s AND %s", left, low, high) + } + } + } + right := p.extractExpressionText(expr.Rexpr) + // Add parentheses for comparison operators (matching PostgreSQL's internal format) + switch op { + case ">=", "<=", ">", "<", "=", "<>", "!=", "~", "~*", "!~", "!~*": + return fmt.Sprintf("(%s %s %s)", left, op, right) + default: + return fmt.Sprintf("%s %s %s", left, op, right) + } + } + } return "" } @@ -938,7 +987,14 @@ func (p *Parser) parseBoolExpr(expr *pg_query.BoolExpr) string { parts = append(parts, p.extractExpressionText(arg)) } - return "(" + strings.Join(parts, " "+op+" ") + ")" + // Only wrap in parentheses if it's a NOT expression or if there are multiple parts + if op == "NOT" { + return op + " " + strings.Join(parts, " ") + } + if len(parts) > 1 { + return "(" + strings.Join(parts, " "+op+" ") + ")" + } + return strings.Join(parts, " "+op+" ") } // parseList parses list expressions (e.g., for IN clauses) @@ -2334,7 +2390,7 @@ func (p *Parser) parseCreateDomain(domainStmt *pg_query.CreateDomainStmt) error constraintDef := "" if constraint.RawExpr != nil { exprText := p.extractExpressionText(constraint.RawExpr) - constraintDef = fmt.Sprintf("CHECK %s", exprText) + constraintDef = fmt.Sprintf("CHECK %s", p.wrapInParens(exprText)) } if constraintDef != "" {