Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions internal/diff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package diff

import (
"fmt"
"regexp"
"sort"
"strings"

Expand Down Expand Up @@ -1104,9 +1103,7 @@ func buildColumnClauses(column *ir.Column, isPartOfAnyPK bool, tableSchema strin
defaultValue = strings.ReplaceAll(defaultValue, schemaPrefix, "")
}

// Strip type qualifiers from default values
defaultValue = stripTypeQualifiers(defaultValue)

// Type casts are now preserved (from pg_query.Deparse) for canonical representation
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
}

Expand Down Expand Up @@ -1213,18 +1210,6 @@ func formatColumnDataTypeForCreate(column *ir.Column) string {
return dataType
}

// typeQualifierSuffixRe matches a trailing PostgreSQL cast suffix ("::typename",
// optionally followed by "[]") anchored at the end of the input, capturing
// everything that precedes it. It is compiled once at package scope so that
// stripTypeQualifiers does not recompile the pattern on every call.
var typeQualifierSuffixRe = regexp.MustCompile(`(.*)::[a-zA-Z_][a-zA-Z0-9_\s]*(\[\])?$`)

// stripTypeQualifiers removes a trailing PostgreSQL type qualifier from a default value.
//
// Only a cast at the very end of defaultValue is removed, and only the last one:
// the leading capture is greedy, so 'a'::varchar::text becomes 'a'::varchar.
// The type name may contain letters, digits, underscores and whitespace
// (e.g. "character varying") and may end in "[]". Schema-qualified names
// (public.status) and parameterized types (numeric(10,2)) are not matched, and
// such values are returned unchanged.
//
// NOTE(review): the cast is stripped regardless of what precedes it, so an
// expression such as now()::date loses its cast as well; confirm callers only
// pass values where that is acceptable.
func stripTypeQualifiers(defaultValue string) string {
	if m := typeQualifierSuffixRe.FindStringSubmatch(defaultValue); m != nil {
		// m[1] is the text before the final "::typename" suffix.
		return m[1]
	}
	return defaultValue
}

// indexesStructurallyEqual compares two indexes for structural equality
// excluding comments and other metadata that don't require index recreation
func indexesStructurallyEqual(oldIndex, newIndex *ir.Index) bool {
Expand Down
139 changes: 112 additions & 27 deletions ir/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,15 @@ func (f *postgreSQLFormatter) formatBoolExpr(boolExpr *pg_query.BoolExpr) {

// formatTypeCast formats a type cast expression
func (f *postgreSQLFormatter) formatTypeCast(typeCast *pg_query.TypeCast) {
// Check if this is a redundant cast that should be stripped for cleaner output
if f.isRedundantTypeCast(typeCast) {
// Just format the argument without the cast
if typeCast.Arg != nil {
f.formatExpression(typeCast.Arg)
}
return
}

// Special handling for INTERVAL type casts
if typeCast.TypeName != nil && len(typeCast.TypeName.Names) > 0 {
// Get the type name (last element in the names array)
Expand Down Expand Up @@ -498,6 +507,101 @@ func (f *postgreSQLFormatter) formatTypeCast(typeCast *pg_query.TypeCast) {
}
}

// isRedundantTypeCast reports whether a type cast can be dropped when
// formatting, so that only the cast's argument is emitted.
//
// A cast is treated as redundant when:
//   - its argument is a string constant (possibly wrapped in further casts,
//     e.g. 'value'::varchar::text) and the target type is text-like;
//   - its argument is a NULL constant (possibly wrapped in further casts),
//     whatever the target type (e.g. NULL::numeric);
//   - its argument is a direct column reference and the target type is
//     text-like.
//
// The target type is only recognized when it is unqualified or qualified with
// pg_catalog; any other qualification is never considered text-like.
// String constants cast to non-text types (e.g. '2020-01-01'::date) are kept.
//
// NOTE(review): the column-reference rule does not look at the column's actual
// type, so col::text is stripped even when col is not a text column; confirm
// that this is acceptable for every caller.
func (f *postgreSQLFormatter) isRedundantTypeCast(typeCast *pg_query.TypeCast) bool {
	if typeCast.Arg == nil || typeCast.TypeName == nil {
		return false
	}

	if base := baseConstantOfCastArg(typeCast.Arg); base != nil {
		if base.GetSval() != nil {
			// String literal: only text-like target types are redundant.
			return isTextLikeCastType(castTargetTypeName(typeCast.TypeName))
		}
		if base.Isnull {
			return true
		}
		// Other constants (numbers, booleans, ...) fall through to the
		// remaining checks, which do not apply to them.
	}

	if typeCast.Arg.GetColumnRef() != nil {
		return isTextLikeCastType(castTargetTypeName(typeCast.TypeName))
	}

	return false
}

// baseConstantOfCastArg walks through nested type casts starting at node and
// returns the constant at the bottom, or nil if the chain ends in anything else.
func baseConstantOfCastArg(node *pg_query.Node) *pg_query.A_Const {
	for node != nil {
		if c := node.GetAConst(); c != nil {
			return c
		}
		nested := node.GetTypeCast()
		if nested == nil {
			return nil
		}
		node = nested.Arg
	}
	return nil
}

// castTargetTypeName returns the bare type name of a cast target when it is
// written as "typename" or "pg_catalog.typename", and "" in every other case.
func castTargetTypeName(typeName *pg_query.TypeName) string {
	switch len(typeName.Names) {
	case 1:
		if typ := typeName.Names[0].GetString_(); typ != nil {
			return typ.Sval
		}
	case 2:
		schema := typeName.Names[0].GetString_()
		if schema == nil || schema.Sval != "pg_catalog" {
			return ""
		}
		if typ := typeName.Names[1].GetString_(); typ != nil {
			return typ.Sval
		}
	}
	return ""
}

// isTextLikeCastType reports whether name is one of the text-like type names
// whose casts are considered redundant.
func isTextLikeCastType(name string) bool {
	switch name {
	case "text", "varchar", "character varying", "char", "character", "bpchar":
		return true
	}
	return false
}

// formatTypeName formats a type name
func (f *postgreSQLFormatter) formatTypeName(typeName *pg_query.TypeName) {
for i, nameNode := range typeName.Names {
Expand Down Expand Up @@ -558,14 +662,15 @@ func (f *postgreSQLFormatter) formatCaseExpr(caseExpr *pg_query.CaseExpr) {
f.buffer.WriteString(" WHEN ")
f.formatExpression(when.Expr)
f.buffer.WriteString(" THEN ")
f.formatExpressionStripCast(when.Result)
// Format result expression - redundant type casts will be stripped by formatTypeCast
f.formatExpression(when.Result)
}
}

// Format ELSE clause, stripping unnecessary type casts from constants/NULL
// Format ELSE clause - redundant type casts will be stripped by formatTypeCast
if caseExpr.Defresult != nil {
f.buffer.WriteString(" ELSE ")
f.formatExpressionStripCast(caseExpr.Defresult)
f.formatExpression(caseExpr.Defresult)
}

f.buffer.WriteString(" END")
Expand Down Expand Up @@ -647,26 +752,6 @@ func (f *postgreSQLFormatter) formatNullTest(nullTest *pg_query.NullTest) {
}
}

// formatExpressionStripCast writes expr while peeling off any outer type casts.
// Casts are unwrapped one level at a time: as soon as a cast's argument is a
// constant, that constant is written on its own; otherwise the innermost
// non-cast expression (or a cast with no argument) is formatted normally.
func (f *postgreSQLFormatter) formatExpressionStripCast(expr *pg_query.Node) {
	current := expr
	for {
		cast := current.GetTypeCast()
		if cast == nil || cast.Arg == nil {
			// Nothing left to unwrap.
			break
		}
		if constant := cast.Arg.GetAConst(); constant != nil {
			// Typed constant: emit only the value, without the cast.
			f.formatAConst(constant)
			return
		}
		// Non-constant argument: keep stripping casts from it.
		current = cast.Arg
	}

	f.formatExpression(current)
}

// formatAArrayExpr formats array expressions (ARRAY[...])
func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr) {
f.buffer.WriteString("ARRAY[")
Expand All @@ -682,17 +767,17 @@ func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr)
// formatArrayAsIN is a helper to format "column IN (values)" syntax
// Used by both formatAExpr and formatScalarArrayOpExpr to convert "= ANY(ARRAY[...])" to "IN (...)"
func (f *postgreSQLFormatter) formatArrayAsIN(leftExpr *pg_query.Node, arrayElements []*pg_query.Node) {
// Format left side (the column/expression)
f.formatExpressionStripCast(leftExpr)
// Format left side (the column/expression) - preserves type casts for canonical representation
f.formatExpression(leftExpr)

f.buffer.WriteString(" IN (")

// Format array elements as comma-separated list, stripping unnecessary type casts
// Format array elements as comma-separated list - preserves type casts for canonical representation
for i, elem := range arrayElements {
if i > 0 {
f.buffer.WriteString(", ")
}
f.formatExpressionStripCast(elem)
f.formatExpression(elem)
}

f.buffer.WriteString(")")
Expand Down
39 changes: 16 additions & 23 deletions ir/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,21 @@ func normalizeDefaultValue(value string) string {
}

// Handle type casting - remove explicit type casts that are semantically equivalent
// Pattern: ''::text -> ''
// Pattern: '{}'::jsonb -> '{}'
// Use regex to properly handle type casts within complex expressions
// Pattern: 'literal'::type -> 'literal' (removes redundant casts from string literals)
if strings.Contains(value, "::") {
// Find the cast and remove it for simple literal values
if strings.HasPrefix(value, "'") {
if idx := strings.Index(value, "'::"); idx != -1 {
// Find the closing quote
if closeIdx := strings.Index(value[1:], "'"); closeIdx != -1 {
literal := value[:closeIdx+2] // Include the closing quote
if literal == "''" || literal == "'{}'" {
value = literal
}
}
}
}
// Pattern: 'G'::schema.type_name -> 'G'
// Pattern: 'G'::type_name -> 'G'
if strings.Contains(value, "'::") {
if idx := strings.Index(value, "'::"); idx != -1 {
value = value[:idx+1]
}
}
// Use regex to match and remove type casts from string literals
// This handles: 'text'::text, 'utc'::text, '{}'::jsonb, '{}'::text[], etc.
// Also handles multi-word types like 'value'::character varying
// Pattern explanation:
// '([^']*)' - matches a quoted string literal (capturing the content)
// ::[a-zA-Z_][\w\s.]* - matches ::typename
// [a-zA-Z_] - type name must start with letter or underscore
// [\w\s.]* - followed by word chars, spaces, or dots (for "character varying" or "pg_catalog.text")
// (?:\[\])? - optionally followed by [] for array types (non-capturing group)
// Note: no trailing boundary is asserted (Go's RE2 engine has no lookahead), so the greedy [\w\s.]* also consumes any spaces and words that follow the type name
re := regexp.MustCompile(`'([^']*)'::(?:[a-zA-Z_][\w\s.]*)(?:\[\])?`)
value = re.ReplaceAllString(value, "'$1'")
}

return value
Expand Down Expand Up @@ -891,7 +884,7 @@ func normalizeCheckClause(checkClause string) string {
func normalizeExpressionWithPgQuery(expr string) string {
// Create a dummy SELECT statement with the expression to parse it
dummySQL := fmt.Sprintf("SELECT %s", expr)

parseResult, err := pg_query.Parse(dummySQL)
if err != nil {
// If parsing fails, return empty string to trigger fallback
Expand Down Expand Up @@ -935,7 +928,7 @@ func removeRedundantNumericCasts(expr string) string {
re := regexp.MustCompile(pattern)
result = re.ReplaceAllString(result, "$1")
}

return result
}

Expand Down
Loading