Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 148 additions & 71 deletions internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,36 +250,36 @@ type Diff struct {
}

type ddlDiff struct {
addedSchemas []*ir.Schema
droppedSchemas []*ir.Schema
modifiedSchemas []*schemaDiff
addedTables []*ir.Table
droppedTables []*ir.Table
modifiedTables []*tableDiff
addedViews []*ir.View
droppedViews []*ir.View
modifiedViews []*viewDiff
addedFunctions []*ir.Function
droppedFunctions []*ir.Function
modifiedFunctions []*functionDiff
addedProcedures []*ir.Procedure
droppedProcedures []*ir.Procedure
modifiedProcedures []*procedureDiff
addedTypes []*ir.Type
droppedTypes []*ir.Type
modifiedTypes []*typeDiff
addedSchemas []*ir.Schema
droppedSchemas []*ir.Schema
modifiedSchemas []*schemaDiff
addedTables []*ir.Table
droppedTables []*ir.Table
modifiedTables []*tableDiff
addedViews []*ir.View
droppedViews []*ir.View
modifiedViews []*viewDiff
addedFunctions []*ir.Function
droppedFunctions []*ir.Function
modifiedFunctions []*functionDiff
addedProcedures []*ir.Procedure
droppedProcedures []*ir.Procedure
modifiedProcedures []*procedureDiff
addedTypes []*ir.Type
droppedTypes []*ir.Type
modifiedTypes []*typeDiff
addedSequences []*ir.Sequence
droppedSequences []*ir.Sequence
modifiedSequences []*sequenceDiff
addedDefaultPrivileges []*ir.DefaultPrivilege
droppedDefaultPrivileges []*ir.DefaultPrivilege
modifiedDefaultPrivileges []*defaultPrivilegeDiff
// Explicit object privileges
addedPrivileges []*ir.Privilege
droppedPrivileges []*ir.Privilege
modifiedPrivileges []*privilegeDiff
addedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege
droppedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege
addedPrivileges []*ir.Privilege
droppedPrivileges []*ir.Privilege
modifiedPrivileges []*privilegeDiff
addedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege
droppedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege
// Column-level privileges
addedColumnPrivileges []*ir.ColumnPrivilege
droppedColumnPrivileges []*ir.ColumnPrivilege
Expand Down Expand Up @@ -411,38 +411,38 @@ type rlsChange struct {
// GenerateMigration compares two IR schemas and returns the SQL differences
func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff {
diff := &ddlDiff{
addedSchemas: []*ir.Schema{},
droppedSchemas: []*ir.Schema{},
modifiedSchemas: []*schemaDiff{},
addedTables: []*ir.Table{},
droppedTables: []*ir.Table{},
modifiedTables: []*tableDiff{},
addedViews: []*ir.View{},
droppedViews: []*ir.View{},
modifiedViews: []*viewDiff{},
addedFunctions: []*ir.Function{},
droppedFunctions: []*ir.Function{},
modifiedFunctions: []*functionDiff{},
addedProcedures: []*ir.Procedure{},
droppedProcedures: []*ir.Procedure{},
modifiedProcedures: []*procedureDiff{},
addedTypes: []*ir.Type{},
droppedTypes: []*ir.Type{},
modifiedTypes: []*typeDiff{},
addedSequences: []*ir.Sequence{},
droppedSequences: []*ir.Sequence{},
modifiedSequences: []*sequenceDiff{},
addedDefaultPrivileges: []*ir.DefaultPrivilege{},
droppedDefaultPrivileges: []*ir.DefaultPrivilege{},
modifiedDefaultPrivileges: []*defaultPrivilegeDiff{},
addedPrivileges: []*ir.Privilege{},
droppedPrivileges: []*ir.Privilege{},
modifiedPrivileges: []*privilegeDiff{},
addedRevokedDefaultPrivs: []*ir.RevokedDefaultPrivilege{},
droppedRevokedDefaultPrivs: []*ir.RevokedDefaultPrivilege{},
addedColumnPrivileges: []*ir.ColumnPrivilege{},
droppedColumnPrivileges: []*ir.ColumnPrivilege{},
modifiedColumnPrivileges: []*columnPrivilegeDiff{},
addedSchemas: []*ir.Schema{},
droppedSchemas: []*ir.Schema{},
modifiedSchemas: []*schemaDiff{},
addedTables: []*ir.Table{},
droppedTables: []*ir.Table{},
modifiedTables: []*tableDiff{},
addedViews: []*ir.View{},
droppedViews: []*ir.View{},
modifiedViews: []*viewDiff{},
addedFunctions: []*ir.Function{},
droppedFunctions: []*ir.Function{},
modifiedFunctions: []*functionDiff{},
addedProcedures: []*ir.Procedure{},
droppedProcedures: []*ir.Procedure{},
modifiedProcedures: []*procedureDiff{},
addedTypes: []*ir.Type{},
droppedTypes: []*ir.Type{},
modifiedTypes: []*typeDiff{},
addedSequences: []*ir.Sequence{},
droppedSequences: []*ir.Sequence{},
modifiedSequences: []*sequenceDiff{},
addedDefaultPrivileges: []*ir.DefaultPrivilege{},
droppedDefaultPrivileges: []*ir.DefaultPrivilege{},
modifiedDefaultPrivileges: []*defaultPrivilegeDiff{},
addedPrivileges: []*ir.Privilege{},
droppedPrivileges: []*ir.Privilege{},
modifiedPrivileges: []*privilegeDiff{},
addedRevokedDefaultPrivs: []*ir.RevokedDefaultPrivilege{},
droppedRevokedDefaultPrivs: []*ir.RevokedDefaultPrivilege{},
addedColumnPrivileges: []*ir.ColumnPrivilege{},
droppedColumnPrivileges: []*ir.ColumnPrivilege{},
modifiedColumnPrivileges: []*columnPrivilegeDiff{},
}

// Compare schemas first in deterministic order
Expand Down Expand Up @@ -1411,8 +1411,31 @@ func (d *ddlDiff) generatePreDropMaterializedViewsSQL(targetSchema string, colle
func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollector) {
// Note: Schema creation is out of scope for schema-level comparisons

// Create types
generateCreateTypesSQL(d.addedTypes, targetSchema, collector)
// Build function lookup early - needed for both domain and table dependency checks
newFunctionLookup := buildFunctionLookup(d.addedFunctions)

// Separate types into domains with/without function dependencies
// Domains with function deps (e.g., CHECK constraints referencing functions) must be created after functions
typesWithoutFunctionDeps := []*ir.Type{}
domainsWithFunctionDeps := []*ir.Type{}
deferredDomainLookup := make(map[string]struct{})

for _, typeObj := range d.addedTypes {
if typeObj.Kind == ir.TypeKindDomain && domainReferencesNewFunction(typeObj, newFunctionLookup) {
domainsWithFunctionDeps = append(domainsWithFunctionDeps, typeObj)
// Track deferred domains so we can defer tables that use them
deferredDomainLookup[strings.ToLower(typeObj.Name)] = struct{}{}
if typeObj.Schema != "" {
qualified := fmt.Sprintf("%s.%s", strings.ToLower(typeObj.Schema), strings.ToLower(typeObj.Name))
deferredDomainLookup[qualified] = struct{}{}
}
} else {
typesWithoutFunctionDeps = append(typesWithoutFunctionDeps, typeObj)
}
}

// Create types WITHOUT function dependencies (enum, composite, and domains without function deps)
generateCreateTypesSQL(typesWithoutFunctionDeps, targetSchema, collector)

// Create sequences
generateCreateSequencesSQL(d.addedSequences, targetSchema, collector)
Expand All @@ -1423,39 +1446,41 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto
key := fmt.Sprintf("%s.%s", tableDiff.Table.Schema, tableDiff.Table.Name)
existingTables[key] = true
}

newFunctionLookup := buildFunctionLookup(d.addedFunctions)
var shouldDeferPolicy func(*ir.RLSPolicy) bool
if len(newFunctionLookup) > 0 {
shouldDeferPolicy = func(policy *ir.RLSPolicy) bool {
return policyReferencesNewFunction(policy, newFunctionLookup)
}
}

// Separate tables into those that depend on new functions and those that don't
// This ensures we create functions before tables that use them in defaults/checks
tablesWithoutFunctionDeps := []*ir.Table{}
tablesWithFunctionDeps := []*ir.Table{}
// Separate tables into those that depend on new functions/deferred domains and those that don't
// This ensures we create functions and domains before tables that use them
tablesWithoutDeps := []*ir.Table{}
tablesWithDeps := []*ir.Table{}

for _, table := range d.addedTables {
if tableReferencesNewFunction(table, newFunctionLookup) {
tablesWithFunctionDeps = append(tablesWithFunctionDeps, table)
if tableReferencesNewFunction(table, newFunctionLookup) || tableUsesDeferredDomain(table, deferredDomainLookup) {
tablesWithDeps = append(tablesWithDeps, table)
} else {
tablesWithoutFunctionDeps = append(tablesWithoutFunctionDeps, table)
tablesWithoutDeps = append(tablesWithoutDeps, table)
}
}

// Create tables WITHOUT function dependencies first (functions may reference these)
deferredPolicies1, deferredConstraints1 := generateCreateTablesSQL(tablesWithoutFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy)
// Create tables WITHOUT function/domain dependencies first (functions may reference these)
deferredPolicies1, deferredConstraints1 := generateCreateTablesSQL(tablesWithoutDeps, targetSchema, collector, existingTables, shouldDeferPolicy)

// Create functions (functions may depend on tables created above)
generateCreateFunctionsSQL(d.addedFunctions, targetSchema, collector)

// Create procedures (procedures may depend on tables)
// Create domains WITH function dependencies (now that functions exist)
// These domains have CHECK constraints that reference functions
generateCreateTypesSQL(domainsWithFunctionDeps, targetSchema, collector)

// Create procedures (procedures may depend on tables and domains)
generateCreateProceduresSQL(d.addedProcedures, targetSchema, collector)

// Create tables WITH function dependencies (now that functions exist)
deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy)
// Create tables WITH function/domain dependencies (now that functions and deferred domains exist)
deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithDeps, targetSchema, collector, existingTables, shouldDeferPolicy)

// Add deferred foreign key constraints from BOTH batches AFTER all tables are created
// This ensures FK references to tables in the second batch (function-dependent tables) work correctly
Expand Down Expand Up @@ -1778,6 +1803,58 @@ func policyReferencesNewFunction(policy *ir.RLSPolicy, newFunctions map[string]s
return false
}

// tableUsesDeferredDomain determines if a table uses any deferred domain types in its columns.
func tableUsesDeferredDomain(table *ir.Table, deferredDomains map[string]struct{}) bool {
if len(deferredDomains) == 0 || table == nil {
return false
}

for _, col := range table.Columns {
if col.DataType == "" {
continue
}
// Normalize the type name for lookup
typeName := strings.ToLower(col.DataType)
if _, ok := deferredDomains[typeName]; ok {
return true
}
// Try with table's schema prefix
if table.Schema != "" && !strings.Contains(typeName, ".") {
qualified := fmt.Sprintf("%s.%s", strings.ToLower(table.Schema), typeName)
if _, ok := deferredDomains[qualified]; ok {
return true
}
}
}
return false
}

// domainReferencesNewFunction determines if a domain references any newly added functions
// in its CHECK constraints or default value.
func domainReferencesNewFunction(typeObj *ir.Type, newFunctions map[string]struct{}) bool {
if len(newFunctions) == 0 || typeObj == nil || typeObj.Kind != ir.TypeKindDomain {
return false
}

// Check default value
if typeObj.Default != "" {
if referencesNewFunction(typeObj.Default, typeObj.Schema, newFunctions) {
return true
}
}

// Check CHECK constraints
for _, constraint := range typeObj.Constraints {
if constraint.Definition != "" {
if referencesNewFunction(constraint.Definition, typeObj.Schema, newFunctions) {
return true
}
}
}

return false
}

func referencesNewFunction(expr, defaultSchema string, newFunctions map[string]struct{}) bool {
if expr == "" || len(newFunctions) == 0 {
return false
Expand Down
3 changes: 3 additions & 0 deletions internal/diff/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (

// generateCreateFunctionsSQL generates CREATE FUNCTION statements
func generateCreateFunctionsSQL(functions []*ir.Function, targetSchema string, collector *diffCollector) {
// Build dependencies from function bodies (supplements pg_depend, which doesn't track SQL function body references)
buildFunctionBodyDependencies(functions)

// Sort functions by dependency order (topological sort)
sortedFunctions := topologicallySortFunctions(functions)

Expand Down
79 changes: 79 additions & 0 deletions internal/diff/topological.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package diff

import (
"sort"
"strings"

"github.com/pgschema/pgschema/ir"
)
Expand Down Expand Up @@ -653,3 +654,81 @@ func constraintMatchesFKReference(uniqueConstraint, fkConstraint *ir.Constraint)

return true
}

// buildFunctionBodyDependencies scans function bodies for function calls and populates
// the Dependencies field. This supplements dependencies from pg_depend, which doesn't
// track references inside SQL function bodies.
func buildFunctionBodyDependencies(functions []*ir.Function) {
if len(functions) <= 1 {
return
}

// Build lookup maps by function name (both qualified and unqualified)
// Map to the full key format used by Dependencies: schema.name(args)
type funcInfo struct {
fn *ir.Function
key string
}
functionLookup := make(map[string]funcInfo)

for _, fn := range functions {
key := fn.Schema + "." + fn.Name + "(" + fn.GetArguments() + ")"
name := strings.ToLower(fn.Name)

// Store under unqualified name
functionLookup[name] = funcInfo{fn: fn, key: key}

// Store under qualified name
if fn.Schema != "" {
qualified := strings.ToLower(fn.Schema) + "." + name
functionLookup[qualified] = funcInfo{fn: fn, key: key}
}
}

// For each function, scan its body for function calls
for _, fn := range functions {
if fn.Definition == "" {
continue
}

fnKey := fn.Schema + "." + fn.Name + "(" + fn.GetArguments() + ")"

matches := functionCallRegex.FindAllStringSubmatch(fn.Definition, -1)
for _, match := range matches {
if len(match) < 2 {
continue
}
identifier := strings.ToLower(match[1])
if identifier == "" {
continue
}

// Try to find the referenced function
var info funcInfo
var found bool

if info, found = functionLookup[identifier]; !found {
// Try with schema prefix if identifier is unqualified
if !strings.Contains(identifier, ".") && fn.Schema != "" {
qualified := strings.ToLower(fn.Schema) + "." + identifier
info, found = functionLookup[qualified]
}
}

// If found and not self-reference, add dependency
if found && info.key != fnKey {
// Check if dependency already exists
alreadyExists := false
for _, existing := range fn.Dependencies {
if existing == info.key {
alreadyExists = true
break
}
}
if !alreadyExists {
fn.Dependencies = append(fn.Dependencies, info.key)
}
}
}
}
}
Loading