diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 2aea9d08..7055a311 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -502,6 +502,10 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { for _, param := range fn.Parameters { param.DataType = replaceString(param.DataType) } + // Normalize function dependencies for topological sorting + for i := range fn.Dependencies { + fn.Dependencies[i] = replaceString(fn.Dependencies[i]) + } } // Procedures diff --git a/internal/diff/function.go b/internal/diff/function.go index 48130cd1..3d34ef84 100644 --- a/internal/diff/function.go +++ b/internal/diff/function.go @@ -2,7 +2,6 @@ package diff import ( "fmt" - "sort" "strings" "github.com/pgschema/pgschema/ir" @@ -10,12 +9,8 @@ import ( // generateCreateFunctionsSQL generates CREATE FUNCTION statements func generateCreateFunctionsSQL(functions []*ir.Function, targetSchema string, collector *diffCollector) { - // Sort functions by name for consistent ordering - sortedFunctions := make([]*ir.Function, len(functions)) - copy(sortedFunctions, functions) - sort.Slice(sortedFunctions, func(i, j int) bool { - return sortedFunctions[i].Name < sortedFunctions[j].Name - }) + // Sort functions by dependency order (topological sort) + sortedFunctions := topologicallySortFunctions(functions) for _, function := range sortedFunctions { sql := generateFunctionSQL(function, targetSchema) @@ -127,12 +122,8 @@ func generateModifyFunctionsSQL(diffs []*functionDiff, targetSchema string, coll // generateDropFunctionsSQL generates DROP FUNCTION statements func generateDropFunctionsSQL(functions []*ir.Function, targetSchema string, collector *diffCollector) { - // Sort functions by name for consistent ordering - sortedFunctions := make([]*ir.Function, len(functions)) - copy(sortedFunctions, functions) - sort.Slice(sortedFunctions, func(i, j int) bool { - return sortedFunctions[i].Name < sortedFunctions[j].Name - }) + // Sort functions by reverse dependency order (drop dependents before dependencies) + sortedFunctions := reverseSlice(topologicallySortFunctions(functions)) for _, function := range sortedFunctions { functionName := qualifyEntityName(function.Schema, function.Name, targetSchema) diff --git a/internal/diff/topological.go b/internal/diff/topological.go index 9282173a..3664eb57 100644 --- a/internal/diff/topological.go +++ b/internal/diff/topological.go @@ -403,3 +403,110 @@ func findLastDot(s string) int { } return -1 } + +// topologicallySortFunctions sorts functions across all schemas in dependency order +// Functions that are referenced by other functions will come before the functions that reference them +func topologicallySortFunctions(functions []*ir.Function) []*ir.Function { + if len(functions) <= 1 { + return functions + } + + // Build maps for efficient lookup + funcMap := make(map[string]*ir.Function) + var insertionOrder []string + for _, fn := range functions { + key := fn.Schema + "." + fn.Name + "(" + fn.GetArguments() + ")" + funcMap[key] = fn + insertionOrder = append(insertionOrder, key) + } + + // Build dependency graph + inDegree := make(map[string]int) + adjList := make(map[string][]string) + + // Initialize + for key := range funcMap { + inDegree[key] = 0 + adjList[key] = []string{} + } + + // Build edges: if funcA depends on funcB, add edge funcB -> funcA + for keyA, funcA := range funcMap { + for _, depKey := range funcA.Dependencies { + // depKey is already schema-qualified: schema.name(args) + if _, exists := funcMap[depKey]; exists && keyA != depKey { + adjList[depKey] = append(adjList[depKey], keyA) + inDegree[keyA]++ + } + } + } + + // Kahn's algorithm with deterministic cycle breaking + var queue []string + var result []string + processed := make(map[string]bool, len(funcMap)) + + // Seed queue with nodes that have no incoming edges + for key, degree := range inDegree { + if degree == 0 { + queue = append(queue, key) + } + } + sort.Strings(queue) + + for len(result) < len(funcMap) { + if len(queue) == 0 { + // Cycle detected: pick the next unprocessed function using original insertion order + // + // CYCLE BREAKING STRATEGY FOR FUNCTIONS: + // Setting inDegree[next] = 0 effectively declares "this function has no remaining dependencies" + // for the purpose of breaking the cycle. This is safe because: + // + // 1. The 'processed' map prevents any function from being added to the result twice, even if + // its inDegree becomes zero or negative multiple times (see processed[current] check below). + // + // 2. PostgreSQL allows mutually recursive functions through CREATE OR REPLACE FUNCTION. + // When functions A and B call each other, the creation order doesn't matter because + // PostgreSQL validates function bodies at call time, not at creation time (for most languages). + // + // 3. Using insertion order (alphabetical by schema.name(args)) ensures deterministic output + // when multiple valid orderings exist. + // + // This approach aligns with how PostgreSQL handles function dependencies - it doesn't + // require strict ordering for mutually dependent functions. + next := nextInOrder(insertionOrder, processed) + if next == "" { + break + } + queue = append(queue, next) + inDegree[next] = 0 + } + + current := queue[0] + queue = queue[1:] + if processed[current] { + continue + } + processed[current] = true + result = append(result, current) + + neighbors := append([]string(nil), adjList[current]...) + sort.Strings(neighbors) + + for _, neighbor := range neighbors { + inDegree[neighbor]-- + if inDegree[neighbor] <= 0 && !processed[neighbor] { + queue = append(queue, neighbor) + sort.Strings(queue) + } + } + } + + // Convert result back to function slice + sortedFunctions := make([]*ir.Function, 0, len(result)) + for _, key := range result { + sortedFunctions = append(sortedFunctions, funcMap[key]) + } + + return sortedFunctions +} diff --git a/ir/inspector.go b/ir/inspector.go index 37c481b8..4d39426a 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -120,6 +120,11 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro return nil, err } + // Build function dependencies after functions are loaded + if err := i.buildFunctionDependencies(ctx, schema, targetSchema); err != nil { + return nil, err + } + // Group 3 runs after table details are loaded if err := i.executeConcurrentGroup(ctx, schema, targetSchema, group3); err != nil { return nil, err @@ -976,6 +981,46 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema return nil } +func (i *Inspector) buildFunctionDependencies(ctx context.Context, schema *IR, targetSchema string) error { + deps, err := i.queries.GetFunctionDependencies(ctx, sql.NullString{String: targetSchema, Valid: true}) + if err != nil { + return err + } + + dbSchema := schema.Schemas[targetSchema] + if dbSchema == nil { + return nil + } + + // Build a map of dependencies by dependent function key + depMap := make(map[string][]string) + for _, dep := range deps { + dependentArgs := "" + if dep.DependentArgs.Valid { + dependentArgs = dep.DependentArgs.String + } + dependentKey := dep.DependentName + "(" + dependentArgs + ")" + + referencedArgs := "" + if dep.ReferencedArgs.Valid { + referencedArgs = dep.ReferencedArgs.String + } + + // Store as schema.name(args) for cross-schema support + referencedKey := dep.ReferencedSchema + "." + dep.ReferencedName + "(" + referencedArgs + ")" + depMap[dependentKey] = append(depMap[dependentKey], referencedKey) + } + + // Assign dependencies to functions + for funcKey, fn := range dbSchema.Functions { + if deps, ok := depMap[funcKey]; ok { + fn.Dependencies = deps + } + } + + return nil +} + // splitParameterString splits a parameter string by commas, but respects quotes, // parentheses, and brackets. This handles complex defaults like '{1,2,3}' or '{"key": "value"}' func splitParameterString(signature string) []string { diff --git a/ir/ir.go b/ir/ir.go index 55e1d6ee..50497c41 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -142,6 +142,7 @@ type Function struct { IsLeakproof bool `json:"is_leakproof,omitempty"` // LEAKPROOF Parallel string `json:"parallel,omitempty"` // SAFE, UNSAFE, RESTRICTED SearchPath string `json:"search_path,omitempty"` // SET search_path value + Dependencies []string `json:"dependencies,omitempty"` // Function keys (name(args)) this function depends on } // GetArguments returns the function arguments string (types only) for function identification. diff --git a/ir/queries/queries.sql b/ir/queries/queries.sql index 83da41dc..ad0fb171 100644 --- a/ir/queries/queries.sql +++ b/ir/queries/queries.sql @@ -1386,4 +1386,23 @@ SELECT (aclexplode(acl)).privilege_type AS privilege_type, (aclexplode(acl)).is_grantable AS is_grantable FROM column_acls -ORDER BY table_name, column_name, grantee_oid, privilege_type; \ No newline at end of file +ORDER BY table_name, column_name, grantee_oid, privilege_type; + +-- GetFunctionDependencies retrieves function-to-function dependencies for topological sorting +-- name: GetFunctionDependencies :many +SELECT + dependent_ns.nspname AS dependent_schema, + dependent_proc.proname AS dependent_name, + pg_get_function_identity_arguments(dependent_proc.oid) AS dependent_args, + referenced_ns.nspname AS referenced_schema, + referenced_proc.proname AS referenced_name, + pg_get_function_identity_arguments(referenced_proc.oid) AS referenced_args +FROM pg_depend d +JOIN pg_proc dependent_proc ON d.objid = dependent_proc.oid +JOIN pg_namespace dependent_ns ON dependent_proc.pronamespace = dependent_ns.oid +JOIN pg_proc referenced_proc ON d.refobjid = referenced_proc.oid +JOIN pg_namespace referenced_ns ON referenced_proc.pronamespace = referenced_ns.oid +WHERE d.classid = 'pg_proc'::regclass + AND d.refclassid = 'pg_proc'::regclass + AND d.deptype = 'n' + AND dependent_ns.nspname = $1; \ No newline at end of file diff --git a/ir/queries/queries.sql.go b/ir/queries/queries.sql.go index 151a41a7..647114d9 100644 --- a/ir/queries/queries.sql.go +++ b/ir/queries/queries.sql.go @@ -3234,3 +3234,62 @@ func (q *Queries) GetViewsForSchema(ctx context.Context, dollar_1 sql.NullString } return items, nil } + +const getFunctionDependencies = `-- name: GetFunctionDependencies :many +SELECT + dependent_ns.nspname AS dependent_schema, + dependent_proc.proname AS dependent_name, + pg_get_function_identity_arguments(dependent_proc.oid) AS dependent_args, + referenced_ns.nspname AS referenced_schema, + referenced_proc.proname AS referenced_name, + pg_get_function_identity_arguments(referenced_proc.oid) AS referenced_args +FROM pg_depend d +JOIN pg_proc dependent_proc ON d.objid = dependent_proc.oid +JOIN pg_namespace dependent_ns ON dependent_proc.pronamespace = dependent_ns.oid +JOIN pg_proc referenced_proc ON d.refobjid = referenced_proc.oid +JOIN pg_namespace referenced_ns ON referenced_proc.pronamespace = referenced_ns.oid +WHERE d.classid = 'pg_proc'::regclass + AND d.refclassid = 'pg_proc'::regclass + AND d.deptype = 'n' + AND dependent_ns.nspname = $1 +` + +type GetFunctionDependenciesRow struct { + DependentSchema string `db:"dependent_schema" json:"dependent_schema"` + DependentName string `db:"dependent_name" json:"dependent_name"` + DependentArgs sql.NullString `db:"dependent_args" json:"dependent_args"` + ReferencedSchema string `db:"referenced_schema" json:"referenced_schema"` + ReferencedName string `db:"referenced_name" json:"referenced_name"` + ReferencedArgs sql.NullString `db:"referenced_args" json:"referenced_args"` +} + +// GetFunctionDependencies retrieves function-to-function dependencies for topological sorting +func (q *Queries) GetFunctionDependencies(ctx context.Context, dollar_1 sql.NullString) ([]GetFunctionDependenciesRow, error) { + rows, err := q.db.QueryContext(ctx, getFunctionDependencies, dollar_1) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetFunctionDependenciesRow + for rows.Next() { + var i GetFunctionDependenciesRow + if err := rows.Scan( + &i.DependentSchema, + &i.DependentName, + &i.DependentArgs, + &i.ReferencedSchema, + &i.ReferencedName, + &i.ReferencedArgs, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/testdata/diff/create_function/drop_function/diff.sql b/testdata/diff/create_function/drop_function/diff.sql index 827c3d35..52298e17 100644 --- a/testdata/diff/create_function/drop_function/diff.sql +++ b/testdata/diff/create_function/drop_function/diff.sql @@ -1,4 +1,4 @@ REVOKE EXECUTE ON FUNCTION process_order(order_id integer, discount_percent numeric) FROM api_role; -DROP FUNCTION IF EXISTS get_user_stats(integer); -DROP FUNCTION IF EXISTS process_order(integer, numeric); DROP FUNCTION IF EXISTS process_payment(integer, text); +DROP FUNCTION IF EXISTS process_order(integer, numeric); +DROP FUNCTION IF EXISTS get_user_stats(integer); diff --git a/testdata/diff/create_function/drop_function/plan.json b/testdata/diff/create_function/drop_function/plan.json index e971f41d..725f6930 100644 --- a/testdata/diff/create_function/drop_function/plan.json +++ b/testdata/diff/create_function/drop_function/plan.json @@ -15,10 +15,10 @@ "path": "privileges.FUNCTION.process_order(order_id integer, discount_percent numeric).api_role" }, { - "sql": "DROP FUNCTION IF EXISTS get_user_stats(integer);", + "sql": "DROP FUNCTION IF EXISTS process_payment(integer, text);", "type": "function", "operation": "drop", - "path": "public.get_user_stats" + "path": "public.process_payment" }, { "sql": "DROP FUNCTION IF EXISTS process_order(integer, numeric);", @@ -27,10 +27,10 @@ "path": "public.process_order" }, { - "sql": "DROP FUNCTION IF EXISTS process_payment(integer, text);", + "sql": "DROP FUNCTION IF EXISTS get_user_stats(integer);", "type": "function", "operation": "drop", - "path": "public.process_payment" + "path": "public.get_user_stats" } ] } diff --git a/testdata/diff/create_function/drop_function/plan.sql b/testdata/diff/create_function/drop_function/plan.sql index 7d2aa9cc..acaab384 100644 --- a/testdata/diff/create_function/drop_function/plan.sql +++ b/testdata/diff/create_function/drop_function/plan.sql @@ -1,7 +1,7 @@ REVOKE EXECUTE ON FUNCTION process_order(order_id integer, discount_percent numeric) FROM api_role; -DROP FUNCTION IF EXISTS get_user_stats(integer); +DROP FUNCTION IF EXISTS process_payment(integer, text); DROP FUNCTION IF EXISTS process_order(integer, numeric); -DROP FUNCTION IF EXISTS process_payment(integer, text); +DROP FUNCTION IF EXISTS get_user_stats(integer); diff --git a/testdata/diff/create_function/drop_function/plan.txt b/testdata/diff/create_function/drop_function/plan.txt index 814100ab..dbe61f36 100644 --- a/testdata/diff/create_function/drop_function/plan.txt +++ b/testdata/diff/create_function/drop_function/plan.txt @@ -17,8 +17,8 @@ DDL to be executed: REVOKE EXECUTE ON FUNCTION process_order(order_id integer, discount_percent numeric) FROM api_role; -DROP FUNCTION IF EXISTS get_user_stats(integer); +DROP FUNCTION IF EXISTS process_payment(integer, text); DROP FUNCTION IF EXISTS process_order(integer, numeric); -DROP FUNCTION IF EXISTS process_payment(integer, text); +DROP FUNCTION IF EXISTS get_user_stats(integer); diff --git a/testdata/diff/dependency/function_to_function/diff.sql b/testdata/diff/dependency/function_to_function/diff.sql new file mode 100644 index 00000000..351cb7d9 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/diff.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE FUNCTION get_raw_result() +RETURNS integer +LANGUAGE sql +VOLATILE +RETURN 42; + +CREATE OR REPLACE FUNCTION process_result( + val integer DEFAULT get_raw_result() +) +RETURNS text +LANGUAGE sql +VOLATILE +RETURN ('Processed: '::text || (val)::text); diff --git a/testdata/diff/dependency/function_to_function/new.sql b/testdata/diff/dependency/function_to_function/new.sql new file mode 100644 index 00000000..c22a6066 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/new.sql @@ -0,0 +1,12 @@ +-- Base function that returns a simple type +CREATE OR REPLACE FUNCTION public.get_raw_result() +RETURNS integer +LANGUAGE SQL +RETURN 42; + +-- Function with default value that references first function +-- PostgreSQL tracks this dependency via pg_depend +CREATE OR REPLACE FUNCTION public.process_result(val integer DEFAULT get_raw_result()) +RETURNS text +LANGUAGE SQL +RETURN ('Processed: '::text || val::text); diff --git a/testdata/diff/dependency/function_to_function/old.sql b/testdata/diff/dependency/function_to_function/old.sql new file mode 100644 index 00000000..8943c537 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/old.sql @@ -0,0 +1 @@ +-- Empty schema (no functions) diff --git a/testdata/diff/dependency/function_to_function/plan.json b/testdata/diff/dependency/function_to_function/plan.json new file mode 100644 index 00000000..b09bcab6 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/plan.json @@ -0,0 +1,26 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.6.1", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085" + }, + "groups": [ + { + "steps": [ + { + "sql": "CREATE OR REPLACE FUNCTION get_raw_result()\nRETURNS integer\nLANGUAGE sql\nVOLATILE\nRETURN 42;", + "type": "function", + "operation": "create", + "path": "public.get_raw_result" + }, + { + "sql": "CREATE OR REPLACE FUNCTION process_result(\n val integer DEFAULT get_raw_result()\n)\nRETURNS text\nLANGUAGE sql\nVOLATILE\nRETURN ('Processed: '::text || (val)::text);", + "type": "function", + "operation": "create", + "path": "public.process_result" + } + ] + } + ] +} diff --git a/testdata/diff/dependency/function_to_function/plan.sql b/testdata/diff/dependency/function_to_function/plan.sql new file mode 100644 index 00000000..351cb7d9 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/plan.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE FUNCTION get_raw_result() +RETURNS integer +LANGUAGE sql +VOLATILE +RETURN 42; + +CREATE OR REPLACE FUNCTION process_result( + val integer DEFAULT get_raw_result() +) +RETURNS text +LANGUAGE sql +VOLATILE +RETURN ('Processed: '::text || (val)::text); diff --git a/testdata/diff/dependency/function_to_function/plan.txt b/testdata/diff/dependency/function_to_function/plan.txt new file mode 100644 index 00000000..2e4d8ff2 --- /dev/null +++ b/testdata/diff/dependency/function_to_function/plan.txt @@ -0,0 +1,25 @@ +Plan: 2 to add. + +Summary by type: + functions: 2 to add + +Functions: + + get_raw_result + + process_result + +DDL to be executed: +-------------------------------------------------- + +CREATE OR REPLACE FUNCTION get_raw_result() +RETURNS integer +LANGUAGE sql +VOLATILE +RETURN 42; + +CREATE OR REPLACE FUNCTION process_result( + val integer DEFAULT get_raw_result() +) +RETURNS text +LANGUAGE sql +VOLATILE +RETURN ('Processed: '::text || (val)::text);