Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff {
for _, funcName := range funcNames {
function := dbSchema.Functions[funcName]
// Use schema.name(arguments) as key to distinguish functions with different signatures
key := function.Schema + "." + funcName + "(" + function.Arguments + ")"
key := function.Schema + "." + funcName + "(" + function.GetArguments() + ")"
oldFunctions[key] = function
}
}
Expand All @@ -503,7 +503,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff {
for _, funcName := range funcNames {
function := dbSchema.Functions[funcName]
// Use schema.name(arguments) as key to distinguish functions with different signatures
key := function.Schema + "." + funcName + "(" + function.Arguments + ")"
key := function.Schema + "." + funcName + "(" + function.GetArguments() + ")"
newFunctions[key] = function
}
}
Expand Down
92 changes: 34 additions & 58 deletions internal/diff/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,8 @@ func generateDropFunctionsSQL(functions []*ir.Function, targetSchema string, col
functionName := qualifyEntityName(function.Schema, function.Name, targetSchema)
var sql string

// Build argument list for DROP statement using normalized Parameters array
var argsList string
if len(function.Parameters) > 0 {
// Format parameters for DROP (omit names and defaults, include only types)
// Per PostgreSQL docs, DROP FUNCTION only needs input arguments (IN, INOUT, VARIADIC)
// Exclude OUT and TABLE mode parameters as they're part of the return signature
var argTypes []string
for _, param := range function.Parameters {
// Include only input parameter modes: IN (empty/implicit), INOUT, VARIADIC
if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" {
argTypes = append(argTypes, param.DataType)
}
}
argsList = strings.Join(argTypes, ", ")
} else if function.Arguments != "" {
argsList = function.Arguments
}
// Build argument list for DROP statement using GetArguments()
argsList := function.GetArguments()

if argsList != "" {
sql = fmt.Sprintf("DROP FUNCTION IF EXISTS %s(%s);", functionName, argsList)
Expand Down Expand Up @@ -109,25 +94,16 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string {
functionName := qualifyEntityName(function.Schema, function.Name, targetSchema)
stmt.WriteString(fmt.Sprintf("CREATE OR REPLACE FUNCTION %s", functionName))

// Add parameters - prefer structured Parameters array for normalized types
if len(function.Parameters) > 0 {
// Build parameter list from structured Parameters array
// Exclude TABLE mode parameters as they're part of RETURNS clause
var paramParts []string
for _, param := range function.Parameters {
if param.Mode != "TABLE" {
paramParts = append(paramParts, formatFunctionParameter(param, true))
}
}
if len(paramParts) > 0 {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n ")))
} else {
stmt.WriteString("()")
// Add parameters from structured Parameters array
// Exclude TABLE mode parameters as they're part of RETURNS clause
var paramParts []string
for _, param := range function.Parameters {
if param.Mode != "TABLE" {
paramParts = append(paramParts, formatFunctionParameter(param, true))
}
} else if function.Signature != "" {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(function.Signature, ", ", ",\n ")))
} else if function.Arguments != "" {
stmt.WriteString(fmt.Sprintf("(%s)", function.Arguments))
}
if len(paramParts) > 0 {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n ")))
} else {
stmt.WriteString("()")
}
Expand Down Expand Up @@ -260,35 +236,35 @@ func functionsEqual(old, new *ir.Function) bool {
if old.Language != new.Language {
return false
}

// For RETURNS TABLE functions, the Parameters array includes TABLE output columns
// which can cause comparison issues. In this case, rely on ReturnType comparison instead.
isTableReturn := strings.HasPrefix(old.ReturnType, "TABLE(") || strings.HasPrefix(new.ReturnType, "TABLE(")

if !isTableReturn {
// For non-TABLE functions, compare using normalized Parameters array
// This ensures type aliases like "character varying" vs "varchar" are treated as equal
hasOldParams := len(old.Parameters) > 0
hasNewParams := len(new.Parameters) > 0

if hasOldParams && hasNewParams {
// Both have Parameters - compare them
return parametersEqual(old.Parameters, new.Parameters)
} else if hasOldParams || hasNewParams {
// One has Parameters, one doesn't - they're different
return false
}
if old.Volatility != new.Volatility {
return false
}

// For TABLE functions or functions without Parameters, fall back to Arguments/Signature
if old.Arguments != new.Arguments {
if old.IsStrict != new.IsStrict {
return false
}
if old.Signature != new.Signature {
if old.IsSecurityDefiner != new.IsSecurityDefiner {
return false
}

return true
// Compare using normalized Parameters array
// This ensures type aliases like "character varying" vs "varchar" are treated as equal
// For RETURNS TABLE functions, exclude TABLE mode parameters (they're in ReturnType)
// Only compare input parameters (IN, INOUT, VARIADIC, OUT)
oldInputParams := filterNonTableParameters(old.Parameters)
newInputParams := filterNonTableParameters(new.Parameters)
return parametersEqual(oldInputParams, newInputParams)
}

// filterNonTableParameters filters out TABLE mode parameters
// TABLE parameters are output columns in RETURNS TABLE() and shouldn't be compared as input parameters
func filterNonTableParameters(params []*ir.Parameter) []*ir.Parameter {
var filtered []*ir.Parameter
for _, param := range params {
if param.Mode != "TABLE" {
filtered = append(filtered, param)
}
}
return filtered
}

// parametersEqual compares two parameter arrays for equality
Expand Down
102 changes: 27 additions & 75 deletions internal/diff/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,31 +137,14 @@ func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string {
procedureName := qualifyEntityName(procedure.Schema, procedure.Name, targetSchema)
stmt.WriteString(fmt.Sprintf("CREATE OR REPLACE PROCEDURE %s", procedureName))

// Add parameters - prefer structured Parameters array, then signature, then arguments
if len(procedure.Parameters) > 0 {
// Build parameter list from structured Parameters array
// Always include mode explicitly (matching pg_dump behavior)
var paramParts []string
for _, param := range procedure.Parameters {
paramParts = append(paramParts, formatParameterString(param, true))
}
if len(paramParts) > 0 {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n ")))
} else {
stmt.WriteString("()")
}
} else if procedure.Signature != "" {
// Use detailed signature if available
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(procedure.Signature, ", ", ",\n ")))
} else if procedure.Arguments != "" {
// Format Arguments field with newlines if it contains multiple parameters
args := procedure.Arguments
if strings.Contains(args, ", ") {
args = strings.ReplaceAll(args, ", ", ",\n ")
stmt.WriteString(fmt.Sprintf("(\n %s\n)", args))
} else {
stmt.WriteString(fmt.Sprintf("(%s)", args))
}
// Add parameters from structured Parameters array
// Always include mode explicitly (matching pg_dump behavior)
var paramParts []string
for _, param := range procedure.Parameters {
paramParts = append(paramParts, formatParameterString(param, true))
}
if len(paramParts) > 0 {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n ")))
} else {
stmt.WriteString("()")
}
Expand Down Expand Up @@ -246,64 +229,33 @@ func proceduresEqual(old, new *ir.Procedure) bool {
if old.Language != new.Language {
return false
}
if old.Arguments != new.Arguments {
return false
}
if old.Signature != new.Signature {

// Compare using normalized Parameters array instead of Signature
// This ensures proper comparison regardless of how parameters are specified
hasOldParams := len(old.Parameters) > 0
hasNewParams := len(new.Parameters) > 0

if hasOldParams && hasNewParams {
// Both have Parameters - compare them
return parametersEqual(old.Parameters, new.Parameters)
} else if hasOldParams || hasNewParams {
// One has Parameters, one doesn't - they're different
return false
}

// Both have no parameters - they're equal
return true
}

// formatProcedureParametersForDrop formats procedure parameters for DROP PROCEDURE statements
// Returns the full parameter signature including mode and name (e.g., "IN order_id integer, IN amount numeric")
// This is necessary for proper procedure identification in PostgreSQL
func formatProcedureParametersForDrop(procedure *ir.Procedure) string {
// First, try to use the structured Parameters array if available
if len(procedure.Parameters) > 0 {
var paramParts []string
for _, param := range procedure.Parameters {
// Use helper function with includeDefault=false for DROP statements
paramParts = append(paramParts, formatParameterString(param, false))
}
return strings.Join(paramParts, ", ")
}

// Fallback to Signature field if Parameters not available
if procedure.Signature != "" {
// Signature should already have the mode information
// Just need to remove DEFAULT clauses
var paramParts []string
params := strings.Split(procedure.Signature, ",")
for _, param := range params {
param = strings.TrimSpace(param)
// Remove DEFAULT clauses
if idx := strings.Index(param, " DEFAULT "); idx != -1 {
param = param[:idx]
}
paramParts = append(paramParts, param)
}
return strings.Join(paramParts, ", ")
// Use the structured Parameters array
var paramParts []string
for _, param := range procedure.Parameters {
// Use helper function with includeDefault=false for DROP statements
paramParts = append(paramParts, formatParameterString(param, false))
}

// Last resort: try to parse Arguments field and add IN mode
if procedure.Arguments != "" {
var paramParts []string
params := strings.Split(procedure.Arguments, ",")
for _, param := range params {
param = strings.TrimSpace(param)
// Remove DEFAULT clauses
if idx := strings.Index(param, " DEFAULT "); idx != -1 {
param = param[:idx]
}
// Add IN mode prefix if not already present
if !strings.HasPrefix(param, "IN ") && !strings.HasPrefix(param, "OUT ") && !strings.HasPrefix(param, "INOUT ") {
param = "IN " + param
}
paramParts = append(paramParts, param)
}
return strings.Join(paramParts, ", ")
}

return ""
return strings.Join(paramParts, ", ")
}
10 changes: 0 additions & 10 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,6 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema
if fn.FunctionComment.Valid {
comment = fn.FunctionComment.String
}
arguments := i.safeInterfaceToString(fn.FunctionArguments)
signature := i.safeInterfaceToString(fn.FunctionSignature)

// Check if function should be ignored
Expand Down Expand Up @@ -1157,8 +1156,6 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema
Definition: definition,
ReturnType: i.safeInterfaceToString(fn.DataType),
Language: i.safeInterfaceToString(fn.ExternalLanguage),
Arguments: arguments,
Signature: signature,
Comment: comment,
Parameters: parameters,
Volatility: volatility,
Expand Down Expand Up @@ -1339,7 +1336,6 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem
if proc.ProcedureComment.Valid {
comment = proc.ProcedureComment.String
}
arguments := i.safeInterfaceToString(proc.ProcedureArguments)
signature := i.safeInterfaceToString(proc.ProcedureSignature)

// Check if procedure should be ignored
Expand All @@ -1360,8 +1356,6 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem
Name: procedureName,
Definition: definition,
Language: i.safeInterfaceToString(proc.ExternalLanguage),
Arguments: arguments,
Signature: signature,
Comment: comment,
Parameters: parameters,
}
Expand All @@ -1385,8 +1379,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem
if agg.AggregateComment.Valid {
comment = agg.AggregateComment.String
}
arguments := i.safeInterfaceToString(agg.AggregateArguments)
signature := i.safeInterfaceToString(agg.AggregateSignature)
returnType := i.safeInterfaceToString(agg.AggregateReturnType)
transitionFunction := i.safeInterfaceToString(agg.TransitionFunction)
transitionFunctionSchema := i.safeInterfaceToString(agg.TransitionFunctionSchema)
Expand All @@ -1400,8 +1392,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem
aggregate := &Aggregate{
Schema: schemaName,
Name: aggregateName,
Arguments: arguments,
Signature: signature,
ReturnType: returnType,
TransitionFunction: transitionFunction,
TransitionFunctionSchema: transitionFunctionSchema,
Expand Down
31 changes: 24 additions & 7 deletions ir/ir.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ir

import "sync"
import (
"strings"
"sync"
)

// IR represents the complete database schema intermediate representation
type IR struct {
Expand Down Expand Up @@ -125,15 +128,33 @@ type Function struct {
Definition string `json:"definition"`
ReturnType string `json:"return_type"`
Language string `json:"language"`
Arguments string `json:"arguments,omitempty"`
Signature string `json:"signature,omitempty"`
Parameters []*Parameter `json:"parameters,omitempty"`
Comment string `json:"comment,omitempty"`
Volatility string `json:"volatility,omitempty"` // IMMUTABLE, STABLE, VOLATILE
IsStrict bool `json:"is_strict,omitempty"` // STRICT or null behavior
IsSecurityDefiner bool `json:"is_security_definer,omitempty"` // SECURITY DEFINER
}

// GetArguments returns the function arguments string (types only) for function identification.
// This is built dynamically from the Parameters array to ensure it uses normalized types.
// Per PostgreSQL DROP FUNCTION syntax, only input parameters are included (IN, INOUT, VARIADIC).
func (f *Function) GetArguments() string {
if len(f.Parameters) == 0 {
return ""
}

var argTypes []string
for _, param := range f.Parameters {
// Include only input parameter modes for DROP FUNCTION compatibility
// Exclude OUT and TABLE mode parameters (they're part of return signature)
if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" {
argTypes = append(argTypes, param.DataType)
}
}

return strings.Join(argTypes, ", ")
}

// Parameter represents a function parameter
type Parameter struct {
Name string `json:"name"`
Expand Down Expand Up @@ -336,8 +357,6 @@ type Type struct {
type Aggregate struct {
Schema string `json:"schema"`
Name string `json:"name"`
Arguments string `json:"arguments,omitempty"`
Signature string `json:"signature,omitempty"`
ReturnType string `json:"return_type"`
TransitionFunction string `json:"transition_function"`
TransitionFunctionSchema string `json:"transition_function_schema,omitempty"`
Expand All @@ -354,8 +373,6 @@ type Procedure struct {
Name string `json:"name"`
Definition string `json:"definition"`
Language string `json:"language"`
Arguments string `json:"arguments,omitempty"`
Signature string `json:"signature,omitempty"`
Parameters []*Parameter `json:"parameters,omitempty"`
Comment string `json:"comment,omitempty"`
}
Expand Down
Loading