Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions cmd/plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/pgschema/pgschema/cmd/util"
"github.com/pgschema/pgschema/internal/diff"
Expand Down Expand Up @@ -397,6 +398,8 @@ func processOutput(migrationPlan *plan.Plan, output outputSpec, cmd *cobra.Comma
// Without this normalization, generated DDL would reference non-existent temporary schemas
// and fail when applied to the target database.
func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) {
replaceString := newSchemaStringReplacer(fromSchema, toSchema)

// Normalize schema names in Schemas map
if schema, exists := irData.Schemas[fromSchema]; exists {
delete(irData.Schemas, fromSchema)
Expand All @@ -418,6 +421,7 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) {
if constraint.ReferencedSchema == fromSchema {
constraint.ReferencedSchema = toSchema
}
constraint.CheckClause = replaceString(constraint.CheckClause)
}

// Normalize schema references in table dependencies
Expand All @@ -427,64 +431,127 @@ func normalizeSchemaNames(irData *ir.IR, fromSchema, toSchema string) {
}
}

// Normalize column data types and expressions
for _, column := range table.Columns {
column.DataType = replaceString(column.DataType)
if column.DefaultValue != nil {
*column.DefaultValue = replaceString(*column.DefaultValue)
}
if column.GeneratedExpr != nil {
*column.GeneratedExpr = replaceString(*column.GeneratedExpr)
}
}

// Normalize schema names in indexes
for _, index := range table.Indexes {
if index.Schema == fromSchema {
index.Schema = toSchema
}
index.Where = replaceString(index.Where)
}

// Normalize schema names in triggers
for _, trigger := range table.Triggers {
if trigger.Schema == fromSchema {
trigger.Schema = toSchema
}
trigger.Function = replaceString(trigger.Function)
trigger.Condition = replaceString(trigger.Condition)
}

// Normalize schema names in RLS policies
for _, policy := range table.Policies {
if policy.Schema == fromSchema {
policy.Schema = toSchema
}
policy.Using = replaceString(policy.Using)
policy.WithCheck = replaceString(policy.WithCheck)
}
}

// Views
for _, view := range schema.Views {
view.Schema = toSchema
view.Definition = replaceString(view.Definition)

// Normalize schema names in materialized view indexes
for _, index := range view.Indexes {
if index.Schema == fromSchema {
index.Schema = toSchema
}
index.Where = replaceString(index.Where)
}
}

// Functions
for _, fn := range schema.Functions {
fn.Schema = toSchema
fn.ReturnType = replaceString(fn.ReturnType)
fn.Definition = replaceString(fn.Definition)
for _, param := range fn.Parameters {
param.DataType = replaceString(param.DataType)
}
}

// Procedures
for _, proc := range schema.Procedures {
proc.Schema = toSchema
proc.Definition = replaceString(proc.Definition)
for _, param := range proc.Parameters {
param.DataType = replaceString(param.DataType)
}
}

// Types
for _, typ := range schema.Types {
typ.Schema = toSchema
typ.BaseType = replaceString(typ.BaseType)
typ.Default = replaceString(typ.Default)
for _, col := range typ.Columns {
col.DataType = replaceString(col.DataType)
}
for _, constraint := range typ.Constraints {
constraint.Definition = replaceString(constraint.Definition)
}
}

// Sequences
for _, seq := range schema.Sequences {
seq.Schema = toSchema
seq.DataType = replaceString(seq.DataType)
seq.OwnedByTable = replaceString(seq.OwnedByTable)
}

// Aggregates
for _, agg := range schema.Aggregates {
agg.Schema = toSchema
agg.ReturnType = replaceString(agg.ReturnType)
agg.TransitionFunction = replaceString(agg.TransitionFunction)
agg.StateType = replaceString(agg.StateType)
agg.InitialCondition = replaceString(agg.InitialCondition)
agg.FinalFunction = replaceString(agg.FinalFunction)
}
}
}

func newSchemaStringReplacer(fromSchema, toSchema string) func(string) string {
if fromSchema == "" || toSchema == "" || fromSchema == toSchema {
return func(s string) string { return s }
}

replacements := []string{
fmt.Sprintf(`"%s".`, fromSchema), fmt.Sprintf(`"%s".`, toSchema),
fmt.Sprintf(`%s.`, fromSchema), fmt.Sprintf(`%s.`, toSchema),
fmt.Sprintf(`"%s"`, fromSchema), fmt.Sprintf(`"%s"`, toSchema),
fromSchema, toSchema,
}
Comment on lines +542 to +547
Copy link

Copilot AI Nov 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The replacement order in strings.NewReplacer can cause incorrect results. More specific patterns (like \"schema\".) should be replaced before less specific ones (like bare schema). However, strings.NewReplacer processes replacements in order, and a bare schema name could be replaced first, breaking more specific patterns. For example, if fromSchema is "temp" and toSchema is "public", the string \"temp\".table might first match the bare temp replacement, becoming \"public\".table instead of \"public\".table. While this specific example works, consider cases where partial matches could cause issues. The safer approach is to process replacements from most specific to least specific, which requires either reordering the slice to ensure longer patterns are first, or using a different replacement strategy like regexp with word boundaries.

Copilot uses AI. Check for mistakes.

replacer := strings.NewReplacer(replacements...)
return func(input string) string {
if input == "" {
return input
}
return replacer.Replace(input)
}
}

Expand Down
Loading