From c8c364810b15d78faf3145710eadcc94ee03035f Mon Sep 17 00:00:00 2001 From: tianzhou Date: Thu, 23 Oct 2025 17:14:50 +0800 Subject: [PATCH] chore: simplify view comparison and remove pg_query_go --- go.mod | 2 - go.sum | 9 - internal/diff/view.go | 757 +------------------------- internal/diff/view_comparison_test.go | 569 ------------------- 4 files changed, 7 insertions(+), 1330 deletions(-) delete mode 100644 internal/diff/view_comparison_test.go diff --git a/go.mod b/go.mod index 11fdd51c..21da726b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/google/go-cmp v0.7.0 github.com/jackc/pgx/v5 v5.7.5 github.com/joho/godotenv v1.5.1 - github.com/pganalyze/pg_query_go/v6 v6.1.0 github.com/pgschema/pgschema/ir v0.0.0 github.com/spf13/cobra v1.9.1 ) @@ -27,7 +26,6 @@ require ( golang.org/x/crypto v0.37.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.24.0 // indirect - google.golang.org/protobuf v1.36.5 // indirect ) replace github.com/pgschema/pgschema/ir => ./ir diff --git a/go.sum b/go.sum index b8bf8b5e..0f9521b6 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fergusstrange/embedded-postgres v1.29.0 h1:Uv8hdhoiaNMuH0w8UuGXDHr60VoAQPFdgx7Qf3bzXJM= github.com/fergusstrange/embedded-postgres v1.29.0/go.mod h1:t/MLs0h9ukYM6FSt99R7InCHs1nW0ordoVCcnzmpTYw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -24,8 +22,6 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/pganalyze/pg_query_go/v6 v6.1.0 h1:jG5ZLhcVgL1FAw4C/0VNQaVmX1SUJx71wBGdtTtBvls= -github.com/pganalyze/pg_query_go/v6 v6.1.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -48,11 +44,6 @@ golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/diff/view.go b/internal/diff/view.go index 05d457b4..14c9fee1 100644 --- a/internal/diff/view.go +++ b/internal/diff/view.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - pg_query "github.com/pganalyze/pg_query_go/v6" "github.com/pgschema/pgschema/ir" ) @@ -136,8 +135,9 @@ func generateModifyViewsSQL(diffs []*viewDiff, targetSchema string, collector *d continue // Skip the normal processing for this view } - // Check if only the comment changed and definition is semantically identical - definitionsEqual := diff.Old.Definition == diff.New.Definition || compareViewDefinitionsSemantically(diff.Old.Definition, diff.New.Definition) + // Check if only the comment changed and definition is identical + // Both IRs come from pg_get_viewdef() at the same PostgreSQL version, so string comparison is sufficient + definitionsEqual := diff.Old.Definition == diff.New.Definition commentOnlyChange := diff.CommentChanged && definitionsEqual && diff.Old.Materialized == diff.New.Materialized // Check if only indexes changed (for materialized views) @@ -305,7 +305,8 @@ func generateViewSQL(view *ir.View, targetSchema string) string { return fmt.Sprintf("%s %s AS\n%s;", createClause, viewName, view.Definition) } -// viewsEqual compares two views for equality using semantic comparison +// viewsEqual compares two views for equality +// Both IRs come from pg_get_viewdef() at the same PostgreSQL version, so string comparison is sufficient func viewsEqual(old, new *ir.View) bool { if old.Schema != new.Schema { return false @@ -319,13 +320,8 @@ func viewsEqual(old, new *ir.View) bool { return false } - // Quick path: if string definitions are identical, they're equal - if old.Definition == new.Definition { - return true - } - - // Use semantic comparison using AST analysis (assumes valid SQL) - return compareViewDefinitionsSemantically(old.Definition, new.Definition) + // Both definitions come from pg_get_viewdef(), so they are already normalized + return old.Definition == new.Definition } // viewDependsOnView checks if viewA depends on viewB @@ -334,742 +330,3 @@ func viewDependsOnView(viewA *ir.View, viewBName string) bool { // This can be enhanced with proper dependency parsing later return strings.Contains(strings.ToLower(viewA.Definition), strings.ToLower(viewBName)) } - -// compareViewDefinitionsSemantically compares two SQL view definitions semantically -// using AST comparison rather than string comparison to handle formatting differences -// Assumes valid SQL syntax is always passed -func compareViewDefinitionsSemantically(def1, def2 string) bool { - if def1 == def2 { - return true // Quick path for identical strings - } - - // Parse both definitions into ASTs (assuming valid SQL) - result1, err1 := pg_query.Parse(def1) - result2, err2 := pg_query.Parse(def2) - - if err1 != nil || err2 != nil { - return false - } - - // Both should have exactly one statement (the SELECT for the view) - if len(result1.Stmts) != 1 || len(result2.Stmts) != 1 { - return false - } - - // Compare the SELECT statements semantically - equal, _ := compareSelectStatements(result1.Stmts[0], result2.Stmts[0]) - return equal -} - -// compareSelectStatements compares two SELECT statement ASTs for semantic equivalence -func compareSelectStatements(stmt1, stmt2 *pg_query.RawStmt) (bool, error) { - // Extract SelectStmt from RawStmt - selectStmt1 := stmt1.Stmt.GetSelectStmt() - selectStmt2 := stmt2.Stmt.GetSelectStmt() - - if selectStmt1 == nil || selectStmt2 == nil { - return false, fmt.Errorf("expected SELECT statements") - } - - // Compare key components of SELECT statements - if !compareTargetLists(selectStmt1.TargetList, selectStmt2.TargetList) { - return false, nil - } - - if !compareFromClauses(selectStmt1.FromClause, selectStmt2.FromClause) { - return false, nil - } - - if !compareWhereClauses(selectStmt1.WhereClause, selectStmt2.WhereClause) { - return false, nil - } - - // Compare GROUP BY clause - if !compareGroupByClauses(selectStmt1.GroupClause, selectStmt2.GroupClause) { - return false, nil - } - - // Compare HAVING clause - if !compareHavingClauses(selectStmt1.HavingClause, selectStmt2.HavingClause) { - return false, nil - } - - // Compare ORDER BY clause - if !compareSortClauses(selectStmt1.SortClause, selectStmt2.SortClause) { - return false, nil - } - - return true, nil -} - -// compareTargetLists compares SELECT target lists (column expressions) -func compareTargetLists(list1, list2 []*pg_query.Node) bool { - if len(list1) != len(list2) { - return false - } - - for i, target1 := range list1 { - target2 := list2[i] - if !compareResTargets(target1.GetResTarget(), target2.GetResTarget()) { - return false - } - } - - return true -} - -// compareResTargets compares individual SELECT targets (columns/expressions) -func compareResTargets(target1, target2 *pg_query.ResTarget) bool { - if target1 == nil || target2 == nil { - return target1 == target2 - } - - // Compare target names (aliases) - if target1.Name != target2.Name { - return false - } - - // Compare target expressions - return compareExpressions(target1.Val, target2.Val) -} - -// compareFromClauses compares FROM clauses including JOINs -func compareFromClauses(from1, from2 []*pg_query.Node) bool { - if len(from1) != len(from2) { - return false - } - - for i, node1 := range from1 { - node2 := from2[i] - if !compareFromClauseNode(node1, node2) { - return false - } - } - - return true -} - -// compareFromClauseNode compares individual FROM clause nodes (tables, JOINs, etc.) -func compareFromClauseNode(node1, node2 *pg_query.Node) bool { - // Handle JoinExpr (the main case we're fixing) - if join1 := node1.GetJoinExpr(); join1 != nil { - join2 := node2.GetJoinExpr() - if join2 == nil { - return false - } - return compareJoinExprs(join1, join2) - } - - // Handle RangeVar (simple table references) - if rangeVar1 := node1.GetRangeVar(); rangeVar1 != nil { - rangeVar2 := node2.GetRangeVar() - if rangeVar2 == nil { - return false - } - return compareRangeVars(rangeVar1, rangeVar2) - } - - // TODO: Add other FROM clause node types as needed - - return false -} - -// compareJoinExprs compares JOIN expressions - this is the key function for our issue -func compareJoinExprs(join1, join2 *pg_query.JoinExpr) bool { - if join1 == nil || join2 == nil { - return join1 == join2 - } - - // Compare join type - if join1.Jointype != join2.Jointype { - return false - } - - // Compare left and right operands - if !compareFromClauseNode(join1.Larg, join2.Larg) { - return false - } - - if !compareFromClauseNode(join1.Rarg, join2.Rarg) { - return false - } - - // Compare join conditions - this is where the parentheses differences occur - return compareExpressions(join1.Quals, join2.Quals) -} - -// compareRangeVars compares table references -func compareRangeVars(rv1, rv2 *pg_query.RangeVar) bool { - if rv1 == nil || rv2 == nil { - return rv1 == rv2 - } - - // Normalize schema names - empty string should be treated as "public" - schema1 := rv1.Schemaname - schema2 := rv2.Schemaname - if schema1 == "" { - schema1 = "public" - } - if schema2 == "" { - schema2 = "public" - } - - // Compare normalized schema and table names - return schema1 == schema2 && - rv1.Relname == rv2.Relname && - rv1.Alias.GetAliasname() == rv2.Alias.GetAliasname() -} - -// compareExpressions compares SQL expressions semantically -func compareExpressions(expr1, expr2 *pg_query.Node) bool { - if expr1 == nil || expr2 == nil { - return expr1 == expr2 - } - - // Handle TypeCast expressions using normalized comparison - if expr1.GetTypeCast() != nil || expr2.GetTypeCast() != nil { - return compareExpressionsWithTypeCast(expr1, expr2) - } - - // Handle BoolExpr (AND, OR, NOT) - if boolExpr1 := expr1.GetBoolExpr(); boolExpr1 != nil { - boolExpr2 := expr2.GetBoolExpr() - if boolExpr2 == nil { - return false - } - return compareBoolExprs(boolExpr1, boolExpr2) - } - - // Handle A_Expr (comparison operators like =, <, >, etc.) - if aExpr1 := expr1.GetAExpr(); aExpr1 != nil { - aExpr2 := expr2.GetAExpr() - if aExpr2 == nil { - return false - } - return compareAExprs(aExpr1, aExpr2) - } - - // Handle ColumnRef (column references) - if colRef1 := expr1.GetColumnRef(); colRef1 != nil { - colRef2 := expr2.GetColumnRef() - if colRef2 == nil { - return false - } - return compareColumnRefs(colRef1, colRef2) - } - - // Handle A_Const (constants) - if const1 := expr1.GetAConst(); const1 != nil { - const2 := expr2.GetAConst() - if const2 == nil { - return false - } - return compareAConsts(const1, const2) - } - - // Handle FuncCall (function calls) - if funcCall1 := expr1.GetFuncCall(); funcCall1 != nil { - funcCall2 := expr2.GetFuncCall() - if funcCall2 == nil { - return false - } - return compareFuncCalls(funcCall1, funcCall2) - } - - // Handle CaseExpr (CASE expressions) - if caseExpr1 := expr1.GetCaseExpr(); caseExpr1 != nil { - caseExpr2 := expr2.GetCaseExpr() - if caseExpr2 == nil { - return false - } - return compareCaseExprs(caseExpr1, caseExpr2) - } - - // Handle CoalesceExpr (COALESCE expressions) - if coalesceExpr1 := expr1.GetCoalesceExpr(); coalesceExpr1 != nil { - coalesceExpr2 := expr2.GetCoalesceExpr() - if coalesceExpr2 == nil { - return false - } - return compareCoalesceExprs(coalesceExpr1, coalesceExpr2) - } - - // Handle NullTest (IS NULL, IS NOT NULL) - if nullTest1 := expr1.GetNullTest(); nullTest1 != nil { - nullTest2 := expr2.GetNullTest() - if nullTest2 == nil { - return false - } - return compareNullTests(nullTest1, nullTest2) - } - - // TODO: Add other expression types as needed - - return false -} - -// compareBoolExprs compares boolean expressions (AND, OR, NOT) -func compareBoolExprs(bool1, bool2 *pg_query.BoolExpr) bool { - if bool1 == nil || bool2 == nil { - return bool1 == bool2 - } - - // Must have same boolean operation type - if bool1.Boolop != bool2.Boolop { - return false - } - - // Must have same number of arguments - if len(bool1.Args) != len(bool2.Args) { - return false - } - - // Compare each argument - for i, arg1 := range bool1.Args { - arg2 := bool2.Args[i] - if !compareExpressions(arg1, arg2) { - return false - } - } - - return true -} - -// compareAExprs compares A_Expr nodes (comparison operators) -func compareAExprs(expr1, expr2 *pg_query.A_Expr) bool { - if expr1 == nil || expr2 == nil { - return expr1 == expr2 - } - - // Compare operator names - if !compareOperatorNames(expr1.Name, expr2.Name) { - return false - } - - // Compare left and right operands - return compareExpressions(expr1.Lexpr, expr2.Lexpr) && - compareExpressions(expr1.Rexpr, expr2.Rexpr) -} - -// compareOperatorNames compares operator names -func compareOperatorNames(names1, names2 []*pg_query.Node) bool { - if len(names1) != len(names2) { - return false - } - - for i, name1 := range names1 { - name2 := names2[i] - str1 := name1.GetString_() - str2 := name2.GetString_() - if str1 == nil || str2 == nil || str1.Sval != str2.Sval { - return false - } - } - - return true -} - -// compareColumnRefs compares column references -func compareColumnRefs(col1, col2 *pg_query.ColumnRef) bool { - if col1 == nil || col2 == nil { - return col1 == col2 - } - - // Quick path: if they have same structure, compare directly - if len(col1.Fields) == len(col2.Fields) { - allMatch := true - for i, field1 := range col1.Fields { - field2 := col2.Fields[i] - str1 := field1.GetString_() - str2 := field2.GetString_() - if str1 == nil || str2 == nil || str1.Sval != str2.Sval { - allMatch = false - break - } - } - if allMatch { - return true - } - } - - // Handle alias expansion: compare "alias.column" vs "column" - // Extract the final column name from each reference - colName1 := getColumnName(col1) - colName2 := getColumnName(col2) - - // If the column names match, consider them equivalent - // This handles cases like "e.id" vs "id" - return colName1 == colName2 -} - -// getColumnName extracts the final column name from a ColumnRef -func getColumnName(colRef *pg_query.ColumnRef) string { - if colRef == nil || len(colRef.Fields) == 0 { - return "" - } - - // Get the last field (the actual column name) - lastField := colRef.Fields[len(colRef.Fields)-1] - if str := lastField.GetString_(); str != nil { - return str.Sval - } - - return "" -} - -// compareNullTests compares NULL test expressions (IS NULL, IS NOT NULL) -func compareNullTests(null1, null2 *pg_query.NullTest) bool { - if null1 == nil || null2 == nil { - return null1 == null2 - } - - // Must have the same null test type (IS NULL vs IS NOT NULL) - if null1.Nulltesttype != null2.Nulltesttype { - return false - } - - // Compare the argument expressions - return compareExpressions(null1.Arg, null2.Arg) -} - -// compareAConsts compares constant values -func compareAConsts(const1, const2 *pg_query.A_Const) bool { - if const1 == nil || const2 == nil { - return const1 == const2 - } - - // Compare the actual values, not the string representation (which includes location info) - switch val1 := const1.Val.(type) { - case *pg_query.A_Const_Sval: - if val2, ok := const2.Val.(*pg_query.A_Const_Sval); ok { - return val1.Sval.Sval == val2.Sval.Sval - } - case *pg_query.A_Const_Ival: - if val2, ok := const2.Val.(*pg_query.A_Const_Ival); ok { - return val1.Ival.Ival == val2.Ival.Ival - } - case *pg_query.A_Const_Fval: - if val2, ok := const2.Val.(*pg_query.A_Const_Fval); ok { - return val1.Fval.Fval == val2.Fval.Fval - } - case *pg_query.A_Const_Boolval: - if val2, ok := const2.Val.(*pg_query.A_Const_Boolval); ok { - return val1.Boolval.Boolval == val2.Boolval.Boolval - } - case *pg_query.A_Const_Bsval: - if val2, ok := const2.Val.(*pg_query.A_Const_Bsval); ok { - return val1.Bsval.Bsval == val2.Bsval.Bsval - } - } - - // Fallback to string comparison if types don't match or are unknown - return const1.String() == const2.String() -} - -// compareWhereClauses compares WHERE clauses -func compareWhereClauses(where1, where2 *pg_query.Node) bool { - return compareExpressions(where1, where2) -} - -// compareGroupByClauses compares GROUP BY clauses -func compareGroupByClauses(group1, group2 []*pg_query.Node) bool { - if len(group1) != len(group2) { - return false - } - - for i, expr1 := range group1 { - expr2 := group2[i] - if !compareExpressions(expr1, expr2) { - return false - } - } - - return true -} - -// compareHavingClauses compares HAVING clauses -func compareHavingClauses(having1, having2 *pg_query.Node) bool { - return compareExpressions(having1, having2) -} - -// compareSortClauses compares ORDER BY clauses -func compareSortClauses(sort1, sort2 []*pg_query.Node) bool { - if len(sort1) != len(sort2) { - return false - } - - for i, node1 := range sort1 { - node2 := sort2[i] - if !compareSortBy(node1.GetSortBy(), node2.GetSortBy()) { - return false - } - } - - return true -} - -// compareSortBy compares individual sort specifications -func compareSortBy(sort1, sort2 *pg_query.SortBy) bool { - if sort1 == nil || sort2 == nil { - return sort1 == sort2 - } - - // Compare sort expression - if !compareExpressions(sort1.Node, sort2.Node) { - return false - } - - // Compare sort direction - if sort1.SortbyDir != sort2.SortbyDir { - return false - } - - // Compare null ordering - if sort1.SortbyNulls != sort2.SortbyNulls { - return false - } - - return true -} - -// compareFuncCalls compares function call expressions -func compareFuncCalls(func1, func2 *pg_query.FuncCall) bool { - if func1 == nil || func2 == nil { - return func1 == func2 - } - - // Compare function names - if !compareFuncNames(func1.Funcname, func2.Funcname) { - return false - } - - // Compare arguments - if len(func1.Args) != len(func2.Args) { - return false - } - - for i, arg1 := range func1.Args { - arg2 := func2.Args[i] - if !compareExpressions(arg1, arg2) { - return false - } - } - - // Ignore other function properties like location, agg_star for now - // We can add them later if needed - - return true -} - -// compareFuncNames compares function name lists -func compareFuncNames(names1, names2 []*pg_query.Node) bool { - if len(names1) != len(names2) { - return false - } - - for i, name1 := range names1 { - name2 := names2[i] - str1 := name1.GetString_() - str2 := name2.GetString_() - if str1 == nil || str2 == nil || str1.Sval != str2.Sval { - return false - } - } - - return true -} - -// compareCaseExprs compares CASE expressions -func compareCaseExprs(case1, case2 *pg_query.CaseExpr) bool { - if case1 == nil || case2 == nil { - return case1 == case2 - } - - // Compare the case expression argument (the expression after CASE, if any) - if !compareExpressions(case1.Arg, case2.Arg) { - return false - } - - // Compare WHEN clauses - if len(case1.Args) != len(case2.Args) { - return false - } - - for i, when1 := range case1.Args { - when2 := case2.Args[i] - if !compareCaseWhenClauses(when1.GetCaseWhen(), when2.GetCaseWhen()) { - return false - } - } - - // Compare ELSE clause (default result) - return compareExpressions(case1.Defresult, case2.Defresult) -} - -// compareCaseWhenClauses compares individual WHEN clauses in CASE expressions -func compareCaseWhenClauses(when1, when2 *pg_query.CaseWhen) bool { - if when1 == nil || when2 == nil { - return when1 == when2 - } - - // Compare the WHEN condition - if !compareExpressions(when1.Expr, when2.Expr) { - return false - } - - // Compare the THEN result - return compareExpressions(when1.Result, when2.Result) -} - -// compareCoalesceExprs compares COALESCE expressions -func compareCoalesceExprs(coalesce1, coalesce2 *pg_query.CoalesceExpr) bool { - if coalesce1 == nil || coalesce2 == nil { - return coalesce1 == coalesce2 - } - - // Compare number of arguments - if len(coalesce1.Args) != len(coalesce2.Args) { - return false - } - - // Compare each argument - for i, arg1 := range coalesce1.Args { - if !compareExpressions(arg1, coalesce2.Args[i]) { - return false - } - } - - return true -} - -// compareExpressionsWithTypeCast compares expressions where at least one has a type cast -// This handles PostgreSQL's automatic type casting behavior in a normalized way -func compareExpressionsWithTypeCast(expr1, expr2 *pg_query.Node) bool { - typeCast1 := expr1.GetTypeCast() - typeCast2 := expr2.GetTypeCast() - - // Case 1: Both expressions are TypeCasts - if typeCast1 != nil && typeCast2 != nil { - return compareTypeCasts(typeCast1, typeCast2) - } - - // Case 2: Only one expression is a TypeCast - if typeCast1 != nil { - // expr1 is TypeCast, expr2 is not - argCompare := compareExpressions(typeCast1.Arg, expr2) - if argCompare { - return isImplicitCast(typeCast1) - } - return false - } - - if typeCast2 != nil { - // expr2 is TypeCast, expr1 is not - argCompare := compareExpressions(expr1, typeCast2.Arg) - if argCompare { - return isImplicitCast(typeCast2) - } - return false - } - - // This should never happen as we check for TypeCast existence before calling this function - return false -} - -// compareTypeCasts compares two TypeCast expressions -func compareTypeCasts(cast1, cast2 *pg_query.TypeCast) bool { - if cast1 == nil || cast2 == nil { - return cast1 == cast2 - } - - // Compare the arguments being cast - if !compareExpressions(cast1.Arg, cast2.Arg) { - return false - } - - // Compare the target types - consider compatible types as equivalent - return areCompatibleTypes(cast1.TypeName, cast2.TypeName) -} - -// isImplicitCast checks if a type cast is likely an implicit cast added by PostgreSQL -func isImplicitCast(typeCast *pg_query.TypeCast) bool { - if typeCast.TypeName == nil || len(typeCast.TypeName.Names) == 0 { - return false - } - - // Get the target type name - var typeName string - if str := typeCast.TypeName.Names[len(typeCast.TypeName.Names)-1].GetString_(); str != nil { - typeName = str.Sval - } - - // PostgreSQL commonly adds these implicit casts - implicitCastTypes := map[string]bool{ - "text": true, - "varchar": true, - "character": true, - "character varying": true, - "char": true, - "int4": true, - "int8": true, - "integer": true, - "bigint": true, - "numeric": true, - "bool": true, - "boolean": true, - "regconfig": true, - "regclass": true, - "oid": true, - } - - return implicitCastTypes[typeName] -} - -// areCompatibleTypes checks if two type names are compatible for comparison -func areCompatibleTypes(type1, type2 *pg_query.TypeName) bool { - if type1 == nil || type2 == nil { - return type1 == type2 - } - - // Extract type names - typeName1 := getTypeName(type1) - typeName2 := getTypeName(type2) - - // Exact match - if typeName1 == typeName2 { - return true - } - - // Check for compatible text types - textTypes := map[string]bool{ - "text": true, "varchar": true, "char": true, "character varying": true, - } - if textTypes[typeName1] && textTypes[typeName2] { - return true - } - - // Check for compatible integer types - intTypes := map[string]bool{ - "int4": true, "integer": true, "int": true, - "int8": true, "bigint": true, - } - if intTypes[typeName1] && intTypes[typeName2] { - return true - } - - return false -} - -// getTypeName extracts the type name from a TypeName node -func getTypeName(typeName *pg_query.TypeName) string { - if typeName == nil || len(typeName.Names) == 0 { - return "" - } - - // Get the last name in the list (the actual type name) - if str := typeName.Names[len(typeName.Names)-1].GetString_(); str != nil { - return str.Sval - } - - return "" -} diff --git a/internal/diff/view_comparison_test.go b/internal/diff/view_comparison_test.go deleted file mode 100644 index bd59cb83..00000000 --- a/internal/diff/view_comparison_test.go +++ /dev/null @@ -1,569 +0,0 @@ -package diff - -import ( - "testing" - - "github.com/pgschema/pgschema/ir" -) - -func TestViewSemanticComparison(t *testing.T) { - tests := []struct { - name string - definition1 string - definition2 string - expectEqual bool - }{ - { - name: "identical views", - definition1: ` SELECT - emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no`, - definition2: ` SELECT - emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no`, - expectEqual: true, - }, - { - name: "formatting differences - semicolon and line breaks", - definition1: ` SELECT emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no;`, - definition2: ` SELECT - emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no`, - expectEqual: true, - }, - { - name: "complex view with joins - formatting differences", - definition1: ` SELECT l.emp_no, - d.dept_no, - l.from_date, - l.to_date - FROM dept_emp d - JOIN dept_emp_latest_date l ON d.emp_no = l.emp_no AND d.from_date = l.from_date AND l.to_date = d.to_date;`, - definition2: ` SELECT - l.emp_no, - d.dept_no, - l.from_date, - l.to_date - FROM dept_emp d - JOIN dept_emp_latest_date l ON d.emp_no = l.emp_no AND d.from_date = l.from_date AND l.to_date = d.to_date`, - expectEqual: true, - }, - { - name: "different column order", - definition1: ` SELECT emp_no, name FROM users`, - definition2: ` SELECT name, emp_no FROM users`, - expectEqual: false, - }, - { - name: "different function", - definition1: ` SELECT emp_no, max(from_date) FROM dept_emp GROUP BY emp_no`, - definition2: ` SELECT emp_no, min(from_date) FROM dept_emp GROUP BY emp_no`, - expectEqual: false, - }, - { - name: "whitespace and indentation differences", - definition1: `SELECT emp_no,max(from_date)AS from_date FROM dept_emp GROUP BY emp_no`, - definition2: ` SELECT - emp_no, - max(from_date) AS from_date - FROM dept_emp - GROUP BY emp_no`, - expectEqual: true, - }, - { - name: "case sensitivity in SQL keywords should be ignored", - definition1: ` select emp_no, max(from_date) as from_date from dept_emp group by emp_no`, - definition2: ` SELECT emp_no, MAX(from_date) AS from_date FROM dept_emp GROUP BY emp_no`, - expectEqual: true, - }, - { - name: "different table names", - definition1: ` SELECT emp_no FROM employees`, - definition2: ` SELECT emp_no FROM users`, - expectEqual: false, - }, - { - name: "different column names", - definition1: ` SELECT emp_no FROM employees`, - definition2: ` SELECT user_id FROM employees`, - expectEqual: false, - }, - { - name: "different WHERE clauses", - definition1: ` SELECT emp_no FROM employees WHERE active = true`, - definition2: ` SELECT emp_no FROM employees WHERE active = false`, - expectEqual: false, - }, - { - name: "missing WHERE clause", - definition1: ` SELECT emp_no FROM employees WHERE active = true`, - definition2: ` SELECT emp_no FROM employees`, - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test the views as IR objects - view1 := &ir.View{ - Schema: "public", - Name: "test_view", - Definition: tt.definition1, - } - view2 := &ir.View{ - Schema: "public", - Name: "test_view", - Definition: tt.definition2, - } - - result := viewsEqual(view1, view2) - if result != tt.expectEqual { - t.Errorf("viewsEqual() = %v, expected %v", result, tt.expectEqual) - t.Logf("Definition 1:\n%s", tt.definition1) - t.Logf("Definition 2:\n%s", tt.definition2) - } - - // Also test the semantic comparison function directly - semanticResult := compareViewDefinitionsSemantically(tt.definition1, tt.definition2) - if semanticResult != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", semanticResult, tt.expectEqual) - } - }) - } -} - -func TestFunctionCallComparison(t *testing.T) { - // Specific test for the issue we fixed - function calls with different location metadata - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "max function with different formatting", - sql1: " SELECT emp_no, max(from_date) AS from_date FROM dept_emp GROUP BY emp_no", - sql2: " SELECT\n emp_no,\n max(from_date) AS from_date\n FROM dept_emp\n GROUP BY emp_no", - expectEqual: true, - }, - { - name: "count function with different formatting", - sql1: " SELECT count(*) FROM users", - sql2: " SELECT\n count(*)\n FROM users", - expectEqual: true, - }, - { - name: "multiple function calls with formatting differences", - sql1: " SELECT count(*), sum(salary), avg(age) FROM employees", - sql2: " SELECT\n count(*),\n sum(salary),\n avg(age)\n FROM employees", - expectEqual: true, - }, - { - name: "nested function calls", - sql1: " SELECT upper(concat(first_name, ' ', last_name)) FROM users", - sql2: " SELECT\n upper(concat(first_name, ' ', last_name))\n FROM users", - expectEqual: true, - }, - { - name: "function with multiple arguments", - sql1: " SELECT substring(name, 1, 10) FROM users", - sql2: " SELECT\n substring(name, 1, 10)\n FROM users", - expectEqual: true, - }, - { - name: "different function names", - sql1: " SELECT max(salary) FROM employees", - sql2: " SELECT min(salary) FROM employees", - expectEqual: false, - }, - { - name: "different function arguments", - sql1: " SELECT substring(name, 1, 10) FROM users", - sql2: " SELECT substring(name, 1, 5) FROM users", - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestJoinComparison(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "inner join with formatting differences", - sql1: " SELECT u.name, p.title FROM users u JOIN profiles p ON u.id = p.user_id", - sql2: " SELECT\n u.name,\n p.title\n FROM users u\n JOIN profiles p ON u.id = p.user_id", - expectEqual: true, - }, - { - name: "left join with parentheses differences", - sql1: " SELECT u.name FROM users u LEFT JOIN profiles p ON (u.id = p.user_id)", - sql2: " SELECT u.name FROM users u LEFT JOIN profiles p ON u.id = p.user_id", - expectEqual: true, - }, - { - name: "multiple joins with formatting", - sql1: " SELECT u.name, p.title, r.name FROM users u JOIN profiles p ON u.id = p.user_id JOIN roles r ON u.role_id = r.id", - sql2: " SELECT\n u.name,\n p.title,\n r.name\n FROM users u\n JOIN profiles p ON u.id = p.user_id\n JOIN roles r ON u.role_id = r.id", - expectEqual: true, - }, - { - name: "different join types", - sql1: " SELECT u.name FROM users u JOIN profiles p ON u.id = p.user_id", - sql2: " SELECT u.name FROM users u LEFT JOIN profiles p ON u.id = p.user_id", - expectEqual: false, - }, - { - name: "different join conditions", - sql1: " SELECT u.name FROM users u JOIN profiles p ON u.id = p.user_id", - sql2: " SELECT u.name FROM users u JOIN profiles p ON u.email = p.email", - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestSubqueryComparison(t *testing.T) { - // Note: Subquery comparison is not yet fully implemented in our semantic comparison logic. - // This test documents the current limitations and expected behavior. - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - note string - }{ - { - name: "subquery with formatting differences", - sql1: " SELECT emp_no FROM (SELECT emp_no, salary FROM employees WHERE active = true) sub WHERE salary > 50000", - sql2: " SELECT\n emp_no\n FROM (\n SELECT\n emp_no,\n salary\n FROM employees\n WHERE active = true\n ) sub\n WHERE salary > 50000", - expectEqual: false, // Currently not supported - would require RangeSubselect handling - note: "Subquery formatting differences not yet supported", - }, - { - name: "different subquery conditions", - sql1: " SELECT name FROM users WHERE id IN (SELECT user_id FROM orders WHERE total > 100)", - sql2: " SELECT name FROM users WHERE id IN (SELECT user_id FROM orders WHERE total > 200)", - expectEqual: false, // This should correctly detect differences - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - if tt.note != "" { - t.Logf("Expected limitation: %s", tt.note) - t.Logf("This is a known limitation in the current implementation") - } else { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - } - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestGroupByAndHavingComparison(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "group by with formatting differences", - sql1: " SELECT dept_id, count(*) FROM employees GROUP BY dept_id", - sql2: " SELECT\n dept_id,\n count(*)\n FROM employees\n GROUP BY dept_id", - expectEqual: true, - }, - { - name: "multiple group by columns", - sql1: " SELECT dept_id, status, count(*) FROM employees GROUP BY dept_id, status", - sql2: " SELECT\n dept_id,\n status,\n count(*)\n FROM employees\n GROUP BY dept_id, status", - expectEqual: true, - }, - { - name: "different group by columns", - sql1: " SELECT dept_id, count(*) FROM employees GROUP BY dept_id", - sql2: " SELECT dept_id, count(*) FROM employees GROUP BY status", - expectEqual: false, - }, - { - name: "missing group by", - sql1: " SELECT dept_id, count(*) FROM employees GROUP BY dept_id", - sql2: " SELECT dept_id, count(*) FROM employees", - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestConstantComparison(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "string constants with formatting", - sql1: " SELECT name FROM users WHERE status = 'active'", - sql2: " SELECT\n name\n FROM users\n WHERE status = 'active'", - expectEqual: true, - }, - { - name: "numeric constants", - sql1: " SELECT name FROM users WHERE age > 18", - sql2: " SELECT\n name\n FROM users\n WHERE age > 18", - expectEqual: true, - }, - { - name: "boolean constants", - sql1: " SELECT name FROM users WHERE active = true", - sql2: " SELECT\n name\n FROM users\n WHERE active = true", - expectEqual: true, - }, - { - name: "different string constants", - sql1: " SELECT name FROM users WHERE status = 'active'", - sql2: " SELECT name FROM users WHERE status = 'inactive'", - expectEqual: false, - }, - { - name: "different numeric constants", - sql1: " SELECT name FROM users WHERE age > 18", - sql2: " SELECT name FROM users WHERE age > 21", - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestColumnAliasComparison(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "column alias with AS keyword", - sql1: " SELECT emp_no AS employee_id FROM employees", - sql2: " SELECT\n emp_no AS employee_id\n FROM employees", - expectEqual: true, - }, - { - name: "column alias without AS keyword", - sql1: " SELECT emp_no employee_id FROM employees", - sql2: " SELECT\n emp_no employee_id\n FROM employees", - expectEqual: true, - }, - { - name: "different column aliases", - sql1: " SELECT emp_no AS employee_id FROM employees", - sql2: " SELECT emp_no AS emp_id FROM employees", - expectEqual: false, - }, - { - name: "missing alias", - sql1: " SELECT emp_no AS employee_id FROM employees", - sql2: " SELECT emp_no FROM employees", - expectEqual: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestComplexRealWorldViews(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - }{ - { - name: "pg_dump style vs manual formatting - dept_emp_latest_date", - sql1: ` SELECT emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no;`, - sql2: ` SELECT - emp_no, - max(from_date) AS from_date, - max(to_date) AS to_date - FROM dept_emp - GROUP BY emp_no`, - expectEqual: true, - }, - { - name: "pg_dump style vs manual formatting - current_dept_emp with complex joins", - sql1: ` SELECT l.emp_no, - d.dept_no, - l.from_date, - l.to_date - FROM dept_emp d - JOIN dept_emp_latest_date l ON d.emp_no = l.emp_no AND d.from_date = l.from_date AND l.to_date = d.to_date;`, - sql2: ` SELECT - l.emp_no, - d.dept_no, - l.from_date, - l.to_date - FROM dept_emp d - JOIN dept_emp_latest_date l ON d.emp_no = l.emp_no AND d.from_date = l.from_date AND l.to_date = d.to_date`, - expectEqual: true, - }, - { - name: "view with window functions and formatting differences", - sql1: ` SELECT emp_no, salary, rank() OVER (PARTITION BY dept_id ORDER BY salary DESC) AS salary_rank FROM employees`, - sql2: ` SELECT - emp_no, - salary, - rank() OVER (PARTITION BY dept_id ORDER BY salary DESC) AS salary_rank - FROM employees`, - expectEqual: true, - }, - { - name: "view with CASE expressions", - sql1: ` SELECT emp_no, CASE WHEN salary > 50000 THEN 'high' ELSE 'low' END AS salary_level FROM employees`, - sql2: ` SELECT - emp_no, - CASE - WHEN salary > 50000 THEN 'high' - ELSE 'low' - END AS salary_level - FROM employees`, - expectEqual: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - t.Logf("SQL 1:\n%s", tt.sql1) - t.Logf("SQL 2:\n%s", tt.sql2) - } - }) - } -} - -func TestEdgeCases(t *testing.T) { - tests := []struct { - name string - sql1 string - sql2 string - expectEqual bool - note string - }{ - { - name: "empty definitions", - sql1: "", - sql2: "", - expectEqual: true, - }, - { - name: "one empty definition", sql1: " SELECT 1", - sql2: "", - expectEqual: false, - }, - { - name: "whitespace only", - sql1: " ", - sql2: "\n\t \n", - expectEqual: false, // Known limitation: pure whitespace fails parsing - note: "Pure whitespace strings fail SQL parsing", - }, - { - name: "comments should be ignored (if parser handles them)", - sql1: " SELECT emp_no /* comment */ FROM employees", - sql2: " SELECT emp_no FROM employees", - expectEqual: true, // pg_query should strip comments - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := compareViewDefinitionsSemantically(tt.sql1, tt.sql2) - if result != tt.expectEqual { - if tt.note != "" { - t.Logf("Expected limitation: %s", tt.note) - } else { - t.Errorf("compareViewDefinitionsSemantically() = %v, expected %v", result, tt.expectEqual) - } - t.Logf("SQL 1: '%s'", tt.sql1) - t.Logf("SQL 2: '%s'", tt.sql2) - } - }) - } -}