diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index ab6173bb..3f02b3e5 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/diff" @@ -397,6 +398,8 @@ func processOutput(migrationPlan *plan.Plan, output outputSpec, cmd *cobra.Comma // Without this normalization, generated DDL would reference non-existent temporary schemas // and fail when applied to the target database. func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { + replaceString := newSchemaStringReplacer(fromSchema, toSchema) + // Normalize schema names in Schemas map if schema, exists := irData.Schemas[fromSchema]; exists { delete(irData.Schemas, fromSchema) @@ -418,6 +421,7 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { if constraint.ReferencedSchema == fromSchema { constraint.ReferencedSchema = toSchema } + constraint.CheckClause = replaceString(constraint.CheckClause) } // Normalize schema references in table dependencies @@ -427,11 +431,23 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { } } + // Normalize column data types and expressions + for _, column := range table.Columns { + column.DataType = replaceString(column.DataType) + if column.DefaultValue != nil { + *column.DefaultValue = replaceString(*column.DefaultValue) + } + if column.GeneratedExpr != nil { + *column.GeneratedExpr = replaceString(*column.GeneratedExpr) + } + } + // Normalize schema names in indexes for _, index := range table.Indexes { if index.Schema == fromSchema { index.Schema = toSchema } + index.Where = replaceString(index.Where) } // Normalize schema names in triggers @@ -439,6 +455,8 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { if trigger.Schema == fromSchema { trigger.Schema = toSchema } + trigger.Function = replaceString(trigger.Function) + trigger.Condition = replaceString(trigger.Condition) } // Normalize schema names in RLS policies @@ -446,45 +464,94 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) { if policy.Schema == fromSchema { policy.Schema = toSchema } + policy.Using = replaceString(policy.Using) + policy.WithCheck = replaceString(policy.WithCheck) } } // Views for _, view := range schema.Views { view.Schema = toSchema + view.Definition = replaceString(view.Definition) // Normalize schema names in materialized view indexes for _, index := range view.Indexes { if index.Schema == fromSchema { index.Schema = toSchema } + index.Where = replaceString(index.Where) } } // Functions for _, fn := range schema.Functions { fn.Schema = toSchema + fn.ReturnType = replaceString(fn.ReturnType) + fn.Definition = replaceString(fn.Definition) + for _, param := range fn.Parameters { + param.DataType = replaceString(param.DataType) + } } // Procedures for _, proc := range schema.Procedures { proc.Schema = toSchema + proc.Definition = replaceString(proc.Definition) + for _, param := range proc.Parameters { + param.DataType = replaceString(param.DataType) + } } // Types for _, typ := range schema.Types { typ.Schema = toSchema + typ.BaseType = replaceString(typ.BaseType) + typ.Default = replaceString(typ.Default) + for _, col := range typ.Columns { + col.DataType = replaceString(col.DataType) + } + for _, constraint := range typ.Constraints { + constraint.Definition = replaceString(constraint.Definition) + } } // Sequences for _, seq := range schema.Sequences { seq.Schema = toSchema + seq.DataType = replaceString(seq.DataType) + seq.OwnedByTable = replaceString(seq.OwnedByTable) } // Aggregates for _, agg := range schema.Aggregates { agg.Schema = toSchema + agg.ReturnType = replaceString(agg.ReturnType) + agg.TransitionFunction = replaceString(agg.TransitionFunction) + agg.StateType = replaceString(agg.StateType) + agg.InitialCondition = replaceString(agg.InitialCondition) + agg.FinalFunction = replaceString(agg.FinalFunction) + } + } +} + +func newSchemaStringReplacer(fromSchema, toSchema string) func(string) string { + if fromSchema == "" || toSchema == "" || fromSchema == toSchema { + return func(s string) string { return s } + } + + replacements := []string{ + fmt.Sprintf(`"%s".`, fromSchema), fmt.Sprintf(`"%s".`, toSchema), + fmt.Sprintf(`%s.`, fromSchema), fmt.Sprintf(`%s.`, toSchema), + fmt.Sprintf(`"%s"`, fromSchema), fmt.Sprintf(`"%s"`, toSchema), + fromSchema, toSchema, + } + + replacer := strings.NewReplacer(replacements...) + return func(input string) string { + if input == "" { + return input } + return replacer.Replace(input) } }