Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
908 changes: 0 additions & 908 deletions ir/formatter.go

This file was deleted.

328 changes: 21 additions & 307 deletions ir/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,17 @@ import (
"sort"
"strings"
"unicode"

pg_query "github.com/pganalyze/pg_query_go/v6"
)

// normalizeIR normalizes the IR representation from the inspector
//
// Historical note: This normalization was originally needed to reconcile differences
// between parsed SQL (from parser.go) and database-inspected schema (from inspector.go).
// Since the parser was removed in favor of the embedded-postgres approach (both desired
// and current states now come from database inspection), much of this normalization is
// no longer necessary and can be simplified in a future refactor.
// normalizeIR normalizes the IR representation from the inspector.
//
// Current normalization still handles:
// - PostgreSQL version differences (PG 14 vs 17 format variations)
// - Type name mappings (internal PostgreSQL types → standard SQL types)
// - View definition formatting across different versions
// Since both desired state (from embedded postgres) and current state (from target database)
// now come from the same PostgreSQL version via database inspection, most normalizations
// are no longer needed. The remaining normalizations handle:
//
// TODO: Simplify this file to remove parser-specific normalizations
// - Type name mappings (internal PostgreSQL types → standard SQL types, e.g., int4 → integer)
// - PostgreSQL internal representations (e.g., "~~ " → "LIKE", "= ANY (ARRAY[...])" → "IN (...)")
// - Minor formatting differences in default values, policies, triggers, etc.
func normalizeIR(ir *IR) {
if ir == nil {
return
Expand Down Expand Up @@ -222,44 +215,18 @@ func normalizePolicyExpression(expr string) string {
return expr
}

// normalizeView normalizes view definition
// normalizeView normalizes view definition.
//
// Since both desired state (from embedded postgres) and current state (from target database)
// now come from the same PostgreSQL version via pg_get_viewdef(), they produce identical
// output and no normalization is needed.
func normalizeView(view *View) {
if view == nil {
return
}

view.Definition = normalizeViewDefinition(view.Definition, view.Schema)
}

// normalizeViewDefinition normalizes view SQL definition for consistent comparison
// across different PostgreSQL versions.
//
// PostgreSQL versions produce different pg_get_viewdef() output:
// - PostgreSQL 15: Includes table qualifiers → "dept_emp.emp_no, max(dept_emp.from_date)"
// - PostgreSQL 16+: Omits unnecessary qualifiers → "emp_no, max(from_date)"
//
// This function removes unnecessary table qualifiers from column references when unambiguous
// to ensure consistent comparison between Inspector (database) and Parser (SQL files).
func normalizeViewDefinition(definition string, viewSchema string) string {
if definition == "" {
return definition
}

// Parse the view definition to get AST and remove unnecessary table qualifiers
normalized, err := removeUnnecessaryTableQualifiers(definition)
if err != nil {
// If parsing fails, use the original definition
normalized = definition
}

// Apply all AST-based normalizations in one pass to avoid re-parsing
// This includes:
// 1. Converting PostgreSQL's "= ANY (ARRAY[...])" to "IN (...)"
// 2. Normalizing ORDER BY clauses to use aliases
// 3. Applying proper schema qualification rules for table references
normalized = normalizeViewWithAST(normalized, viewSchema)

return normalized
// No normalization needed - both IR forms come from database inspection
// at the same PostgreSQL version, so pg_get_viewdef() output is identical
}

// normalizeFunction normalizes function signature and definition
Expand Down Expand Up @@ -829,8 +796,11 @@ func normalizeConstraint(constraint *Constraint) {
}
}

// normalizeCheckClause converts PostgreSQL's normalized CHECK expressions to parser format
// Uses pg_query to parse and deparse for consistent normalization
// normalizeCheckClause normalizes CHECK constraint expressions.
//
// Since both desired state (from embedded postgres) and current state (from target database)
// now come from the same PostgreSQL version via pg_get_constraintdef(), they produce identical
// output. We only need basic cleanup for PostgreSQL internal representations.
func normalizeCheckClause(checkClause string) string {
// Strip " NOT VALID" suffix if present (mimicking pg_dump behavior)
// PostgreSQL's pg_get_constraintdef may include NOT VALID at the end,
Expand All @@ -855,71 +825,13 @@ func normalizeCheckClause(checkClause string) string {
}
}

// Apply legacy normalizations for PostgreSQL-specific patterns
// Apply basic normalizations for PostgreSQL internal representations
// (e.g., "~~ " to "LIKE", "= ANY (ARRAY[...])" to "IN (...)")
normalizedClause := applyLegacyCheckNormalizations(clause)

// Try to normalize using pg_query parse/deparse for consistent formatting
pgNormalizedClause := normalizeExpressionWithPgQuery(normalizedClause)
if pgNormalizedClause != "" {
return fmt.Sprintf("CHECK (%s)", pgNormalizedClause)
}

// Fallback to legacy normalization result if pg_query fails
return fmt.Sprintf("CHECK (%s)", normalizedClause)
}

// normalizeExpressionWithPgQuery normalizes an expression using PostgreSQL's parser
func normalizeExpressionWithPgQuery(expr string) string {
// Create a dummy SELECT statement with the expression to parse it
dummySQL := fmt.Sprintf("SELECT %s", expr)

parseResult, err := pg_query.Parse(dummySQL)
if err != nil {
// If parsing fails, return empty string to trigger fallback
return ""
}

// Deparse to get normalized form
deparsed, err := pg_query.Deparse(parseResult)
if err != nil {
return ""
}

// Extract the expression from "SELECT expr" format
if after, found := strings.CutPrefix(deparsed, "SELECT "); found {
normalized := strings.TrimSpace(after)
// Remove redundant numeric type casts from literals
normalized = removeRedundantNumericCasts(normalized)
return normalized
}

return ""
}

// removeRedundantNumericCasts removes type casts from numeric literals
// e.g., "0::numeric" -> "0", "123::integer" -> "123"
func removeRedundantNumericCasts(expr string) string {
// Pattern: number::numeric_type -> number
// This handles: 0::numeric, 123::integer, 45.67::numeric, etc.
patterns := []string{
`(\d+(?:\.\d+)?)::numeric\b`,
`(\d+)::integer\b`,
`(\d+)::bigint\b`,
`(\d+)::smallint\b`,
`(\d+(?:\.\d+)?)::decimal\b`,
`(\d+(?:\.\d+)?)::real\b`,
`(\d+(?:\.\d+)?)::double\s+precision\b`,
}

result := expr
for _, pattern := range patterns {
re := regexp.MustCompile(pattern)
result = re.ReplaceAllString(result, "$1")
}

return result
}

// applyLegacyCheckNormalizations applies the existing normalization patterns
func applyLegacyCheckNormalizations(clause string) string {
// Convert PostgreSQL's "= ANY (ARRAY[...])" format to "IN (...)" format
Expand All @@ -944,66 +856,6 @@ func applyLegacyCheckNormalizations(clause string) string {
return clause
}

// removeUnnecessaryTableQualifiers removes table qualifiers from column references
// when they are unambiguous (i.e., when there's only one table in the FROM clause)
func removeUnnecessaryTableQualifiers(definition string) (string, error) {
// Parse the SQL definition to validate and extract table information
parseResult, err := pg_query.Parse(definition)
if err != nil {
return definition, err
}

if len(parseResult.Stmts) == 0 {
return definition, fmt.Errorf("no statements found")
}

// Get the first statement (should be a SELECT)
stmt := parseResult.Stmts[0]
selectStmt := stmt.Stmt.GetSelectStmt()
if selectStmt == nil {
return definition, fmt.Errorf("not a SELECT statement")
}

// Extract table names from FROM clause
tables := extractTablesFromFromClause(selectStmt.FromClause)

// If there's more than one table, keep qualifiers as they might be necessary
if len(tables) != 1 {
return definition, fmt.Errorf("multiple tables found, keeping original")
}

tableName := tables[0]

// Use regex-based replacement to preserve formatting while removing qualifiers
// This approach maintains the original PostgreSQL pretty-printing format
qualifierRegex := regexp.MustCompile(`\b` + regexp.QuoteMeta(tableName) + `\.([a-zA-Z_][a-zA-Z0-9_]*)\b`)
normalized := qualifierRegex.ReplaceAllString(definition, "$1")

return normalized, nil
}

// extractTablesFromFromClause extracts table names or aliases from the FROM clause
func extractTablesFromFromClause(fromClause []*pg_query.Node) []string {
var tables []string

for _, fromItem := range fromClause {
if rangeVar := fromItem.GetRangeVar(); rangeVar != nil {
if rangeVar.Relname != "" {
// Use alias if present, otherwise use the table name
if rangeVar.Alias != nil && rangeVar.Alias.Aliasname != "" {
tables = append(tables, rangeVar.Alias.Aliasname)
} else {
tables = append(tables, rangeVar.Relname)
}
}
}
// TODO: Handle other FROM clause types like JOINs, subqueries, etc.
// For now, we only handle simple table references
}

return tables
}

// convertAnyArrayToIn converts PostgreSQL's "column = ANY (ARRAY[...])" format
// to the more readable "column IN (...)" format
func convertAnyArrayToIn(expr string) string {
Expand Down Expand Up @@ -1042,141 +894,3 @@ func convertAnyArrayToIn(expr string) string {
return fmt.Sprintf("%s IN (%s)", columnName, strings.Join(cleanValues, ", "))
}

// normalizeViewWithAST applies all AST-based normalizations in a single pass
// This includes converting "= ANY (ARRAY[...])" to "IN (...)" and normalizing ORDER BY
func normalizeViewWithAST(definition string, viewSchema string) string {
if definition == "" {
return definition
}

// Parse the view definition
parseResult, err := pg_query.Parse(definition)
if err != nil {
return definition
}

if len(parseResult.Stmts) == 0 {
return definition
}

stmt := parseResult.Stmts[0]
selectStmt := stmt.Stmt.GetSelectStmt()
if selectStmt == nil {
return definition
}

// Step 1: Normalize ORDER BY clauses (modify AST if needed)
if len(selectStmt.SortClause) > 0 {
// Build reverse alias map (expression -> alias) from target list
exprToAliasMap := buildExpressionToAliasMap(selectStmt.TargetList)

// Transform ORDER BY clauses: replace complex expressions with aliases when possible
for _, sortItem := range selectStmt.SortClause {
if sortBy := sortItem.GetSortBy(); sortBy != nil {
normalizeOrderByExpressionToAlias(sortBy, exprToAliasMap)
}
}
}

// Step 2: Check if we need to use custom formatter for normalization
// Use custom formatter only if the view definition contains "= ANY" (needs conversion to IN)
// For other cases, preserve the original definition to avoid breaking complex expressions
if strings.Contains(definition, "= ANY") {
// Use custom formatter to normalize the query
// The formatter will handle:
// - Converting "= ANY (ARRAY[...])" to "IN (...)"
// - Proper formatting of all expressions
// - Applying proper schema qualification rules
formatter := newPostgreSQLFormatter(viewSchema)
formatted := formatter.formatQueryNode(stmt.Stmt)
if formatted != "" {
return formatted
}
}

return definition
}

// buildExpressionToAliasMap creates a map from expression fingerprints to their aliases
// This helps convert ORDER BY expressions back to column aliases
func buildExpressionToAliasMap(targetList []*pg_query.Node) map[string]string {
exprToAlias := make(map[string]string)

for _, target := range targetList {
if resTarget := target.GetResTarget(); resTarget != nil && resTarget.Name != "" && resTarget.Val != nil {
// Create a fingerprint of the expression by deparsing it
if fingerprint := getExpressionFingerprint(resTarget.Val); fingerprint != "" {
exprToAlias[fingerprint] = resTarget.Name
}
}
}

return exprToAlias
}

// normalizeOrderByExpressionToAlias converts ORDER BY expressions back to aliases when possible
// Returns true if the expression was modified
func normalizeOrderByExpressionToAlias(sortBy *pg_query.SortBy, exprToAliasMap map[string]string) bool {
if sortBy.Node == nil {
return false
}

// Get the fingerprint of the current ORDER BY expression
fingerprint := getExpressionFingerprint(sortBy.Node)
if fingerprint == "" {
return false
}

// Check if this expression matches one of our aliased expressions
if alias, exists := exprToAliasMap[fingerprint]; exists {
// Replace the complex expression with a simple ColumnRef to the alias
sortBy.Node = &pg_query.Node{
Node: &pg_query.Node_ColumnRef{
ColumnRef: &pg_query.ColumnRef{
Fields: []*pg_query.Node{{
Node: &pg_query.Node_String_{
String_: &pg_query.String{Sval: alias},
},
}},
},
},
}
return true
}

return false
}

// getExpressionFingerprint creates a normalized fingerprint of an expression
// This is used to match expressions between SELECT list and ORDER BY
func getExpressionFingerprint(expr *pg_query.Node) string {
if expr == nil {
return ""
}

// Create a temporary SELECT statement with just this expression to deparse it
tempSelect := &pg_query.SelectStmt{
TargetList: []*pg_query.Node{{
Node: &pg_query.Node_ResTarget{
ResTarget: &pg_query.ResTarget{Val: expr},
},
}},
}
tempResult := &pg_query.ParseResult{
Stmts: []*pg_query.RawStmt{{
Stmt: &pg_query.Node{
Node: &pg_query.Node_SelectStmt{SelectStmt: tempSelect},
},
}},
}

if deparsed, err := pg_query.Deparse(tempResult); err == nil {
// Extract just the expression part from "SELECT expression"
if expr, found := strings.CutPrefix(deparsed, "SELECT "); found {
// Normalize the fingerprint by removing extra whitespace and lowercasing
return strings.ToLower(strings.ReplaceAll(strings.TrimSpace(expr), " ", ""))
}
}

return ""
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS departments (
created_at timestamp DEFAULT now(),
CONSTRAINT departments_pkey PRIMARY KEY (id),
CONSTRAINT departments_company_id_fkey FOREIGN KEY (company_id) REFERENCES companies (id),
CONSTRAINT departments_budget_check CHECK (budget > 0)
CONSTRAINT departments_budget_check CHECK (budget > 0::numeric)
);

CREATE INDEX IF NOT EXISTS idx_departments_name ON departments (name);
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"path": "public.companies"
},
{
"sql": "CREATE TABLE IF NOT EXISTS departments (\n id integer,\n name text NOT NULL,\n company_id integer NOT NULL,\n budget numeric(10,2),\n created_at timestamp DEFAULT now(),\n CONSTRAINT departments_pkey PRIMARY KEY (id),\n CONSTRAINT departments_company_id_fkey FOREIGN KEY (company_id) REFERENCES companies (id),\n CONSTRAINT departments_budget_check CHECK (budget > 0)\n);",
"sql": "CREATE TABLE IF NOT EXISTS departments (\n id integer,\n name text NOT NULL,\n company_id integer NOT NULL,\n budget numeric(10,2),\n created_at timestamp DEFAULT now(),\n CONSTRAINT departments_pkey PRIMARY KEY (id),\n CONSTRAINT departments_company_id_fkey FOREIGN KEY (company_id) REFERENCES companies (id),\n CONSTRAINT departments_budget_check CHECK (budget > 0::numeric)\n);",
"type": "table",
"operation": "create",
"path": "public.departments"
Expand Down
Loading