Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/diff/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (cd *ColumnDiff) generateColumnSQL(tableSchema, tableName string, targetSch
}

// Handle default value changes
// Default values are already normalized by ir.normalizeColumn
oldDefault := cd.Old.DefaultValue
newDefault := cd.New.DefaultValue

Expand Down Expand Up @@ -67,7 +68,7 @@ func columnsEqual(old, new *ir.Column) bool {
return false
}

// Compare default values
// Compare default values (already normalized by ir.normalizeColumn)
if (old.DefaultValue == nil) != (new.DefaultValue == nil) {
return false
}
Expand Down
70 changes: 64 additions & 6 deletions internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,22 +943,45 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto
}
}

// Create tables with co-located indexes, policies, and RLS and collect deferred work
deferredPolicies, deferredConstraints := generateCreateTablesSQL(d.addedTables, targetSchema, collector, existingTables, shouldDeferPolicy)
// Separate tables into those that depend on new functions and those that don't
// This ensures we create functions before tables that use them in defaults/checks
tablesWithoutFunctionDeps := []*ir.Table{}
tablesWithFunctionDeps := []*ir.Table{}

for _, table := range d.addedTables {
if tableReferencesNewFunction(table, newFunctionLookup) {
tablesWithFunctionDeps = append(tablesWithFunctionDeps, table)
} else {
tablesWithoutFunctionDeps = append(tablesWithoutFunctionDeps, table)
}
}

// Create tables WITHOUT function dependencies first (functions may reference these)
deferredPolicies1, deferredConstraints1 := generateCreateTablesSQL(tablesWithoutFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy)

// Add deferred foreign key constraints now that referenced tables exist
generateDeferredConstraintsSQL(deferredConstraints, targetSchema, collector)
// Add deferred foreign key constraints from first batch
generateDeferredConstraintsSQL(deferredConstraints1, targetSchema, collector)

// Create functions (functions may depend on tables)
// Create functions (functions may depend on tables created above)
generateCreateFunctionsSQL(d.addedFunctions, targetSchema, collector)

// Create procedures (procedures may depend on tables)
generateCreateProceduresSQL(d.addedProcedures, targetSchema, collector)

// Create tables WITH function dependencies (now that functions exist)
deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy)

// Add deferred foreign key constraints from second batch
generateDeferredConstraintsSQL(deferredConstraints2, targetSchema, collector)

// Merge deferred policies from both batches
allDeferredPolicies := append(deferredPolicies1, deferredPolicies2...)

// Create policies after functions/procedures to satisfy dependencies
generateCreatePoliciesSQL(deferredPolicies, targetSchema, collector)
generateCreatePoliciesSQL(allDeferredPolicies, targetSchema, collector)

// Create triggers (triggers may depend on functions/procedures)
// Note: We need to create triggers for ALL tables, not just the original d.addedTables
Copy link

Copilot AI Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "We need to create triggers for ALL tables, not just the original d.addedTables", but the code on line 985 still only passes d.addedTables. Either the comment is misleading or the code needs to be updated. If triggers should be created for all newly added tables (including both tablesWithoutFunctionDeps and tablesWithFunctionDeps), the code is already correct since d.addedTables contains all of them. Consider clarifying the comment or removing it if it's not accurate.

Suggested change
// Note: We need to create triggers for ALL tables, not just the original d.addedTables

Copilot uses AI. Check for mistakes.
generateCreateTriggersFromTables(d.addedTables, targetSchema, collector)

// Create views
Expand Down Expand Up @@ -1173,6 +1196,41 @@ func buildFunctionLookup(functions []*ir.Function) map[string]struct{} {

var functionCallRegex = regexp.MustCompile(`(?i)([a-z_][a-z0-9_$]*(?:\.[a-z_][a-z0-9_$]*)*)\s*\(`)

// tableReferencesNewFunction determines if a table references any newly added functions
// in column defaults, generated columns, or CHECK constraints.
func tableReferencesNewFunction(table *ir.Table, newFunctions map[string]struct{}) bool {
if len(newFunctions) == 0 || table == nil {
return false
}

// Check column defaults and generated expressions
for _, col := range table.Columns {
// Check default value
if col.DefaultValue != nil && *col.DefaultValue != "" {
if referencesNewFunction(*col.DefaultValue, table.Schema, newFunctions) {
return true
}
}
// Check generated column expression
if col.GeneratedExpr != nil && *col.GeneratedExpr != "" {
if referencesNewFunction(*col.GeneratedExpr, table.Schema, newFunctions) {
return true
}
}
}

// Check CHECK constraints
for _, constraint := range table.Constraints {
if constraint.Type == ir.ConstraintTypeCheck && constraint.CheckClause != "" {
if referencesNewFunction(constraint.CheckClause, table.Schema, newFunctions) {
return true
}
}
}

return false
}

// policyReferencesNewFunction determines if a policy references any newly added functions.
func policyReferencesNewFunction(policy *ir.RLSPolicy, newFunctions map[string]struct{}) bool {
if len(newFunctions) == 0 || policy == nil {
Expand Down
17 changes: 3 additions & 14 deletions internal/diff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -1184,20 +1184,9 @@ func buildColumnClauses(column *ir.Column, isPartOfAnyPK bool, tableSchema strin

// 2. DEFAULT (skip for SERIAL, identity, or generated columns)
if column.DefaultValue != nil && column.Identity == nil && !column.IsGenerated && !isSerialColumn(column) {
defaultValue := *column.DefaultValue
// Handle schema-agnostic sequence references in defaults
if strings.Contains(defaultValue, "nextval") {
// Remove schema qualifiers from sequence references in the target schema
schemaToRemove := targetSchema
if schemaToRemove == "" {
schemaToRemove = tableSchema
}
schemaPrefix := schemaToRemove + "."
defaultValue = strings.ReplaceAll(defaultValue, schemaPrefix, "")
}

// Type casts are now preserved (from pg_query.Deparse) for canonical representation
parts = append(parts, fmt.Sprintf("DEFAULT %s", defaultValue))
// DefaultValue is already normalized by ir.normalizeColumn
// (schema qualifiers and sequence references are handled there)
parts = append(parts, fmt.Sprintf("DEFAULT %s", *column.DefaultValue))
}

// 3. Generated column syntax (must come before constraints)
Expand Down
32 changes: 23 additions & 9 deletions ir/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ func normalizeTable(table *Table) {
return
}

// Normalize columns
// Normalize columns (pass table schema for context)
for _, column := range table.Columns {
normalizeColumn(column)
normalizeColumn(column, table.Schema)
}

// Normalize policies
Expand All @@ -92,17 +92,19 @@ func normalizeTable(table *Table) {
}

// normalizeColumn normalizes column default values
func normalizeColumn(column *Column) {
// tableSchema is used to strip same-schema qualifiers from function calls
func normalizeColumn(column *Column, tableSchema string) {
if column == nil || column.DefaultValue == nil {
return
}

normalized := normalizeDefaultValue(*column.DefaultValue)
normalized := normalizeDefaultValue(*column.DefaultValue, tableSchema)
column.DefaultValue = &normalized
}

// normalizeDefaultValue normalizes default values for semantic comparison
func normalizeDefaultValue(value string) string {
// tableSchema is used to strip same-schema qualifiers from function calls
func normalizeDefaultValue(value string, tableSchema string) string {
// Remove unnecessary whitespace
value = strings.TrimSpace(value)

Expand All @@ -118,6 +120,18 @@ func normalizeDefaultValue(value string) string {
return value
}

// Normalize function calls - remove schema qualifiers for functions in the same schema
// This matches PostgreSQL's pg_get_expr() behavior which strips same-schema qualifiers
// Example: public.get_status() -> get_status() (when tableSchema is "public")
// other_schema.get_status() -> other_schema.get_status() (preserved)
if tableSchema != "" && strings.Contains(value, tableSchema+".") {
// Pattern: schema.function_name(
// Replace "tableSchema." with "" when followed by identifier and (
prefix := tableSchema + "."
pattern := regexp.MustCompile(regexp.QuoteMeta(prefix) + `([a-zA-Z_][a-zA-Z0-9_]*)\(`)
Copy link

Copilot AI Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex pattern for matching function names should include $ in the character class to match PostgreSQL's identifier rules. PostgreSQL allows $ in identifiers, and the functionCallRegex in internal/diff/diff.go (line 1197) correctly includes it: [a-z0-9_$].

Update the pattern to:

pattern := regexp.MustCompile(regexp.QuoteMeta(prefix) + `([a-zA-Z_][a-zA-Z0-9_$]*)\(`)

This ensures consistent handling of function names that contain $ characters (e.g., public.my$function()).

Suggested change
pattern := regexp.MustCompile(regexp.QuoteMeta(prefix) + `([a-zA-Z_][a-zA-Z0-9_]*)\(`)
pattern := regexp.MustCompile(regexp.QuoteMeta(prefix) + `([a-zA-Z_][a-zA-Z0-9_$]*)\(`)

Copilot uses AI. Check for mistakes.
value = pattern.ReplaceAllString(value, `${1}(`)
}

// Handle type casting - remove explicit type casts that are semantically equivalent
if strings.Contains(value, "::") {
// Handle NULL::type -> NULL
Expand Down Expand Up @@ -249,9 +263,9 @@ func normalizeFunction(function *Function) {
if param.Mode == "" {
param.Mode = "IN"
}
// Normalize default values
// Normalize default values (pass function schema for context)
if param.DefaultValue != nil {
normalized := normalizeDefaultValue(*param.DefaultValue)
normalized := normalizeDefaultValue(*param.DefaultValue, function.Schema)
param.DefaultValue = &normalized
}
}
Expand Down Expand Up @@ -296,9 +310,9 @@ func normalizeProcedure(procedure *Procedure) {
if param.Mode == "" {
param.Mode = "IN"
}
// Normalize default values
// Normalize default values (pass procedure schema for context)
if param.DefaultValue != nil {
normalized := normalizeDefaultValue(*param.DefaultValue)
normalized := normalizeDefaultValue(*param.DefaultValue, procedure.Schema)
param.DefaultValue = &normalized
}
}
Expand Down
17 changes: 17 additions & 0 deletions testdata/diff/dependency/function_to_table/diff.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
CREATE OR REPLACE FUNCTION get_default_status()
RETURNS text
LANGUAGE plpgsql
SECURITY INVOKER
VOLATILE
AS $$
BEGIN
RETURN 'active';
END;
$$;

CREATE TABLE IF NOT EXISTS users (
id SERIAL,
name text NOT NULL,
status text DEFAULT get_default_status(),
CONSTRAINT users_pkey PRIMARY KEY (id)
);
16 changes: 16 additions & 0 deletions testdata/diff/dependency/function_to_table/new.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Simple function that returns a default value
CREATE OR REPLACE FUNCTION public.get_default_status()
RETURNS text
LANGUAGE plpgsql
AS $$
BEGIN
RETURN 'active';
END;
$$;

-- Table with column default that uses the function
CREATE TABLE public.users (
id serial PRIMARY KEY,
name text NOT NULL,
status text DEFAULT get_default_status()
);
1 change: 1 addition & 0 deletions testdata/diff/dependency/function_to_table/old.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Empty schema (no objects)
26 changes: 26 additions & 0 deletions testdata/diff/dependency/function_to_table/plan.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"version": "1.0.0",
"pgschema_version": "1.4.3",
"created_at": "1970-01-01T00:00:00Z",
"source_fingerprint": {
"hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085"
},
"groups": [
{
"steps": [
{
"sql": "CREATE OR REPLACE FUNCTION get_default_status()\nRETURNS text\nLANGUAGE plpgsql\nSECURITY INVOKER\nVOLATILE\nAS $$\nBEGIN\n RETURN 'active';\nEND;\n$$;",
"type": "function",
"operation": "create",
"path": "public.get_default_status"
},
{
"sql": "CREATE TABLE IF NOT EXISTS users (\n id SERIAL,\n name text NOT NULL,\n status text DEFAULT get_default_status(),\n CONSTRAINT users_pkey PRIMARY KEY (id)\n);",
"type": "table",
"operation": "create",
"path": "public.users"
}
]
}
]
}
17 changes: 17 additions & 0 deletions testdata/diff/dependency/function_to_table/plan.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
CREATE OR REPLACE FUNCTION get_default_status()
RETURNS text
LANGUAGE plpgsql
SECURITY INVOKER
VOLATILE
AS $$
BEGIN
RETURN 'active';
END;
$$;

CREATE TABLE IF NOT EXISTS users (
id SERIAL,
name text NOT NULL,
status text DEFAULT get_default_status(),
CONSTRAINT users_pkey PRIMARY KEY (id)
);
32 changes: 32 additions & 0 deletions testdata/diff/dependency/function_to_table/plan.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Plan: 2 to add.

Summary by type:
functions: 1 to add
tables: 1 to add

Functions:
+ get_default_status

Tables:
+ users

DDL to be executed:
--------------------------------------------------

CREATE OR REPLACE FUNCTION get_default_status()
RETURNS text
LANGUAGE plpgsql
SECURITY INVOKER
VOLATILE
AS $$
BEGIN
RETURN 'active';
END;
$$;

CREATE TABLE IF NOT EXISTS users (
id SERIAL,
name text NOT NULL,
status text DEFAULT get_default_status(),
CONSTRAINT users_pkey PRIMARY KEY (id)
);