Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 56 additions & 83 deletions internal/diff/identifier_quote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,88 +7,6 @@ import (
"github.com/pgschema/pgschema/internal/util"
)

func TestNeedsQuoting(t *testing.T) {
tests := []struct {
name string
identifier string
want bool
}{
{"empty string", "", false},
{"simple lowercase", "tablename", false},
{"with underscore", "table_name", false},
{"reserved word user", "user", true},
{"reserved word USER", "USER", true},
{"reserved word Order", "Order", true},
{"camelCase", "userId", true},
{"PascalCase", "CreatedAt", true},
{"starts with number", "1table", true},
{"contains special char", "table-name", true},
{"all lowercase", "createdat", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := util.NeedsQuoting(tt.identifier); got != tt.want {
t.Errorf("util.NeedsQuoting(%q) = %v, want %v", tt.identifier, got, tt.want)
}
})
}
}

func TestQuoteIdentifier(t *testing.T) {
tests := []struct {
name string
identifier string
want string
}{
{"simple lowercase", "tablename", "tablename"},
{"reserved word", "user", `"user"`},
{"camelCase", "userId", `"userId"`},
{"already quoted", `"userId"`, `"userId"`}, // Should not double-quote
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Handle already quoted identifiers
if len(tt.identifier) > 2 && tt.identifier[0] == '"' && tt.identifier[len(tt.identifier)-1] == '"' {
// If already quoted, should return as-is
if got := tt.identifier; got != tt.want {
t.Errorf("already quoted identifier %q should remain %q, got %q", tt.identifier, tt.want, got)
}
} else {
if got := util.QuoteIdentifier(tt.identifier); got != tt.want {
t.Errorf("util.QuoteIdentifier(%q) = %q, want %q", tt.identifier, got, tt.want)
}
}
})
}
}

func TestQualifyEntityNameWithQuotes(t *testing.T) {
tests := []struct {
name string
entitySchema string
entityName string
targetSchema string
want string
}{
{"same schema lowercase", "public", "users", "public", "users"},
{"same schema camelCase", "public", "userId", "public", `"userId"`},
{"different schema", "auth", "users", "public", "auth.users"},
{"different schema camelCase", "auth", "userId", "public", `auth."userId"`},
{"reserved word", "public", "user", "public", `"user"`},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := util.QualifyEntityNameWithQuotes(tt.entitySchema, tt.entityName, tt.targetSchema); got != tt.want {
t.Errorf("util.QualifyEntityNameWithQuotes(%q, %q, %q) = %q, want %q",
tt.entitySchema, tt.entityName, tt.targetSchema, got, tt.want)
}
})
}
}

func TestGenerateConstraintSQL_WithQuoting(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -159,6 +77,61 @@ func TestGenerateConstraintSQL_WithQuoting(t *testing.T) {
}
}

func TestCheckConstraintQuoting(t *testing.T) {
tests := []struct {
name string
constraint *ir.Constraint
want string
}{
{
name: "CHECK with camelCase column",
constraint: &ir.Constraint{
Name: "positive_followers",
Type: ir.ConstraintTypeCheck,
CheckClause: `CHECK ("followerCount" >= 0)`,
},
want: `CHECK ("followerCount" >= 0)`,
},
{
name: "CHECK with multiple camelCase columns and AND",
constraint: &ir.Constraint{
Name: "valid_counts",
Type: ir.ConstraintTypeCheck,
CheckClause: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`,
},
want: `CHECK ("likeCount" >= 0 AND "commentCount" >= 0)`,
},
{
name: "CHECK with BETWEEN",
constraint: &ir.Constraint{
Name: "stock_range",
Type: ir.ConstraintTypeCheck,
CheckClause: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`,
},
want: `CHECK ("stockLevel" BETWEEN 0 AND 1000)`,
},
{
name: "CHECK with IN clause",
constraint: &ir.Constraint{
Name: "valid_status",
Type: ir.ConstraintTypeCheck,
CheckClause: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`,
},
want: `CHECK ("orderStatus" IN ('pending', 'shipped', 'delivered'))`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// For CHECK constraints, generateConstraintSQL returns the CheckClause as-is
got := generateConstraintSQL(tt.constraint, "public")
if got != tt.want {
t.Errorf("generateConstraintSQL() for CHECK = %q, want %q", got, tt.want)
}
})
}
}

func TestAddColumnIdentifierQuoting(t *testing.T) {
tests := []struct {
name string
Expand All @@ -184,4 +157,4 @@ func TestAddColumnIdentifierQuoting(t *testing.T) {
}
})
}
}
}
49 changes: 43 additions & 6 deletions internal/diff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,11 +657,12 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
}
collector.collect(context, sql)

case ir.ConstraintTypeCheck:
// CheckClause already contains "CHECK (...)" from the constraint definition
case ir.ConstraintTypeCheck:
// Ensure CHECK clause has outer parentheses around the full expression
tableName := getTableNameWithSchema(td.Table.Schema, td.Table.Name, targetSchema)
clause := ensureCheckClauseParens(constraint.CheckClause)
canonicalSQL := fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;",
tableName, constraint.Name, constraint.CheckClause)
tableName, constraint.Name, clause)

context := &diffContext{
Type: DiffTypeTableConstraint,
Expand Down Expand Up @@ -776,10 +777,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s UNIQUE (%s);",
tableName, constraint.Name, strings.Join(columnNames, ", "))

case ir.ConstraintTypeCheck:
// Add CHECK constraint
case ir.ConstraintTypeCheck:
// Add CHECK constraint with ensured outer parentheses
addSQL = fmt.Sprintf("ALTER TABLE %s\nADD CONSTRAINT %s %s;",
tableName, constraint.Name, constraint.CheckClause)
tableName, constraint.Name, ensureCheckClauseParens(constraint.CheckClause))

case ir.ConstraintTypeForeignKey:
// Sort columns by position
Expand Down Expand Up @@ -1172,6 +1173,42 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
}
}

// ensureCheckClauseParens guarantees that a CHECK clause string contains
// exactly one pair of outer parentheses around the full boolean expression.
// It expects input in the form: "CHECK <expr>" or "CHECK(<expr>)" or "CHECK (<expr>)".
func ensureCheckClauseParens(s string) string {
t := strings.TrimSpace(s)
// Normalize leading "CHECK" token
if len(t) >= 5 && strings.EqualFold(t[:5], "check") {
t = t[5:]
}
expr := strings.TrimSpace(t)

// Check if expression is already properly wrapped in parentheses
// by counting parenthesis depth to ensure the outer pair wraps the full expression
if len(expr) >= 2 && expr[0] == '(' {
depth := 0
for i := 0; i < len(expr); i++ {
switch expr[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 {
if i == len(expr)-1 {
// The outermost paren pair wraps the full expression
return "CHECK " + expr
}
// Leading '(' closes before the end -> not fully wrapped
break
}
}
}
}

return "CHECK (" + expr + ")"
}

// writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting
// This is moved from ir/table.go to consolidate SQL generation in the diff module
func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string) {
Expand Down
82 changes: 69 additions & 13 deletions internal/ir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

pg_query "github.com/pganalyze/pg_query_go/v6"
"github.com/pgschema/pgschema/internal/util"
)

// ParsingPhase represents the current phase of SQL parsing
Expand Down Expand Up @@ -193,6 +194,9 @@ func (p *Parser) extractColumnName(node *pg_query.Node) string {
// Convert trigger pseudo-relations and domain VALUE to uppercase
if part == "new" || part == "old" || part == "value" {
part = strings.ToUpper(part)
} else {
// Quote identifier if needed
part = util.QuoteIdentifier(part)
}
parts = append(parts, part)
}
Expand Down Expand Up @@ -533,7 +537,9 @@ func (p *Parser) parseInlineCheckConstraint(constraint *pg_query.Constraint, col

// Handle check constraint expression
if constraint.RawExpr != nil {
checkConstraint.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")"
raw := p.extractExpressionText(constraint.RawExpr)
expr := p.wrapInParens(raw)
checkConstraint.CheckClause = "CHECK " + expr
}

return checkConstraint
Expand Down Expand Up @@ -792,7 +798,9 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta

// Handle check constraint expression
if constraintType == ConstraintTypeCheck && constraint.RawExpr != nil {
c.CheckClause = "CHECK (" + p.extractExpressionText(constraint.RawExpr) + ")"
raw := p.extractExpressionText(constraint.RawExpr)
expr := p.wrapInParens(raw)
c.CheckClause = "CHECK " + expr
}

// Set validation state based on what was specified in the SQL
Expand All @@ -801,6 +809,31 @@ func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, ta
return c
}

// wrapInParens ensures the expression has exactly one pair of outer parentheses
func (p *Parser) wrapInParens(s string) string {
s = strings.TrimSpace(s)
if len(s) >= 2 && s[0] == '(' {
depth := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 {
if i == len(s)-1 {
// The outermost paren pair wraps the full expression
return s
}
// Leading '(' closes before the end -> not fully wrapped
break
}
}
}
}
return "(" + s + ")"
}

// generateConstraintName generates a default constraint name
func (p *Parser) generateConstraintName(constraintType ConstraintType, tableName string, keys []*pg_query.Node) string {
var suffix string
Expand Down Expand Up @@ -908,15 +941,31 @@ func (p *Parser) parseAExpr(expr *pg_query.A_Expr) string {
return fmt.Sprintf("%s IN %s", left, right)
}

// Simplified implementation for basic expressions
if len(expr.Name) > 0 {
if str := expr.Name[0].GetString_(); str != nil {
op := str.Sval
left := p.extractExpressionText(expr.Lexpr)
right := p.extractExpressionText(expr.Rexpr)
return fmt.Sprintf("(%s %s %s)", left, op, right)
}
}
// Simplified implementation for basic expressions
if len(expr.Name) > 0 {
if str := expr.Name[0].GetString_(); str != nil {
op := str.Sval
left := p.extractExpressionText(expr.Lexpr)
// Special-case BETWEEN: right side comes as a 2-item list
if strings.EqualFold(op, "between") {
if listNode, ok := expr.Rexpr.Node.(*pg_query.Node_List); ok {
if len(listNode.List.Items) == 2 {
low := p.extractExpressionText(listNode.List.Items[0])
high := p.extractExpressionText(listNode.List.Items[1])
return fmt.Sprintf("%s BETWEEN %s AND %s", left, low, high)
}
}
}
right := p.extractExpressionText(expr.Rexpr)
// Add parentheses for comparison operators (matching PostgreSQL's internal format)
switch op {
case ">=", "<=", ">", "<", "=", "<>", "!=", "~", "~*", "!~", "!~*":
return fmt.Sprintf("(%s %s %s)", left, op, right)
default:
return fmt.Sprintf("%s %s %s", left, op, right)
}
}
}
return ""
}

Expand All @@ -938,7 +987,14 @@ func (p *Parser) parseBoolExpr(expr *pg_query.BoolExpr) string {
parts = append(parts, p.extractExpressionText(arg))
}

return "(" + strings.Join(parts, " "+op+" ") + ")"
// Only wrap in parentheses if it's a NOT expression or if there are multiple parts
if op == "NOT" {
return op + " " + strings.Join(parts, " ")
}
if len(parts) > 1 {
return "(" + strings.Join(parts, " "+op+" ") + ")"
}
return strings.Join(parts, " "+op+" ")
}

// parseList parses list expressions (e.g., for IN clauses)
Expand Down Expand Up @@ -2334,7 +2390,7 @@ func (p *Parser) parseCreateDomain(domainStmt *pg_query.CreateDomainStmt) error
constraintDef := ""
if constraint.RawExpr != nil {
exprText := p.extractExpressionText(constraint.RawExpr)
constraintDef = fmt.Sprintf("CHECK %s", exprText)
constraintDef = fmt.Sprintf("CHECK %s", p.wrapInParens(exprText))
}

if constraintDef != "" {
Expand Down