Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 85 additions & 50 deletions internal/diff/procedure.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func generateModifyProceduresSQL(diffs []*procedureDiff, targetSchema string, co
procedureName := qualifyEntityName(diff.Old.Schema, diff.Old.Name, targetSchema)
var dropSQL string

// For DROP statements, we need just the parameter types, not names
paramTypes := extractParameterTypes(diff.Old)
if paramTypes != "" {
dropSQL = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s(%s);", procedureName, paramTypes)
// For DROP statements, we need the full parameter signature including modes and names
paramSignature := formatProcedureParametersForDrop(diff.Old)
if paramSignature != "" {
dropSQL = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s(%s);", procedureName, paramSignature)
} else {
dropSQL = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s();", procedureName)
}
Expand Down Expand Up @@ -88,11 +88,11 @@ func generateDropProceduresSQL(procedures []*ir.Procedure, targetSchema string,
procedureName := qualifyEntityName(procedure.Schema, procedure.Name, targetSchema)
var sql string

// For DROP statements, we need just the parameter types, not names
// Extract types from the arguments/signature
paramTypes := extractParameterTypes(procedure)
if paramTypes != "" {
sql = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s(%s);", procedureName, paramTypes)
// For DROP statements, we need the full parameter signature including modes and names
// Extract the complete signature from the procedure
paramSignature := formatProcedureParametersForDrop(procedure)
if paramSignature != "" {
sql = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s(%s);", procedureName, paramSignature)
} else {
sql = fmt.Sprintf("DROP PROCEDURE IF EXISTS %s();", procedureName)
}
Expand All @@ -110,6 +110,29 @@ func generateDropProceduresSQL(procedures []*ir.Procedure, targetSchema string,
}
}

// formatParameterString formats a single parameter with mode, name, type, and optional default value
// includeDefault controls whether DEFAULT clauses are included in the output
func formatParameterString(param *ir.Parameter, includeDefault bool) string {
var part string
// Always include mode for clarity (IN is default but we make it explicit)
if param.Mode != "" {
part = param.Mode + " "
} else {
part = "IN "
}
// Add parameter name and type
if param.Name != "" {
part += param.Name + " " + param.DataType
} else {
part += param.DataType
}
// Add DEFAULT value if present and requested
if includeDefault && param.DefaultValue != nil {
part += " DEFAULT " + *param.DefaultValue
}
return part
}

// generateProcedureSQL generates CREATE OR REPLACE PROCEDURE SQL for a procedure
func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string {
var stmt strings.Builder
Expand All @@ -118,8 +141,21 @@ func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string {
procedureName := qualifyEntityName(procedure.Schema, procedure.Name, targetSchema)
stmt.WriteString(fmt.Sprintf("CREATE OR REPLACE PROCEDURE %s", procedureName))

// Add parameters using detailed signature if available
if procedure.Signature != "" {
// Add parameters - prefer structured Parameters array, then signature, then arguments
if len(procedure.Parameters) > 0 {
// Build parameter list from structured Parameters array
// Always include mode explicitly (matching pg_dump behavior)
var paramParts []string
for _, param := range procedure.Parameters {
paramParts = append(paramParts, formatParameterString(param, true))
}
if len(paramParts) > 0 {
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n ")))
} else {
stmt.WriteString("()")
}
} else if procedure.Signature != "" {
// Use detailed signature if available
stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(procedure.Signature, ", ", ",\n ")))
} else if procedure.Arguments != "" {
// Format Arguments field with newlines if it contains multiple parameters
Expand Down Expand Up @@ -223,55 +259,54 @@ func proceduresEqual(old, new *ir.Procedure) bool {
return true
}

// extractParameterTypes extracts just the parameter types from a procedure's signature or arguments
// For example: "order_id integer, amount numeric" becomes "integer, numeric"
func extractParameterTypes(procedure *ir.Procedure) string {
// Try to use Arguments field first as it should contain just types
if procedure.Arguments != "" {
// If Arguments contains parameter names (e.g., "order_id integer, amount numeric"),
// extract just the types
args := procedure.Arguments
if strings.Contains(args, " ") {
// This suggests parameter names are included, extract types
var types []string
params := strings.Split(args, ",")
for _, param := range params {
param = strings.TrimSpace(param)
// Split by spaces and take the last part (the type)
parts := strings.Fields(param)
if len(parts) >= 2 {
// Take the type (usually the second part: "name type")
types = append(types, parts[1])
} else if len(parts) == 1 {
// If only one part, assume it's the type
types = append(types, parts[0])
}
}
return strings.Join(types, ", ")
// formatProcedureParametersForDrop formats procedure parameters for DROP PROCEDURE statements
// Returns the full parameter signature including mode and name (e.g., "IN order_id integer, IN amount numeric")
// This is necessary for proper procedure identification in PostgreSQL
func formatProcedureParametersForDrop(procedure *ir.Procedure) string {
// First, try to use the structured Parameters array if available
if len(procedure.Parameters) > 0 {
var paramParts []string
for _, param := range procedure.Parameters {
// Use helper function with includeDefault=false for DROP statements
paramParts = append(paramParts, formatParameterString(param, false))
}
// If no spaces, assume Arguments already contains just types
return args
return strings.Join(paramParts, ", ")
}

// Fallback to Signature field
// Fallback to Signature field if Parameters not available
if procedure.Signature != "" {
var types []string
// Signature should already have the mode information
// Just need to remove DEFAULT clauses
var paramParts []string
params := strings.Split(procedure.Signature, ",")
for _, param := range params {
param = strings.TrimSpace(param)
// Remove DEFAULT clauses and extract type
if strings.Contains(param, " DEFAULT ") {
param = strings.Split(param, " DEFAULT ")[0]
// Remove DEFAULT clauses
if idx := strings.Index(param, " DEFAULT "); idx != -1 {
param = param[:idx]
}
paramParts = append(paramParts, param)
}
return strings.Join(paramParts, ", ")
}

// Last resort: try to parse Arguments field and add IN mode
if procedure.Arguments != "" {
var paramParts []string
params := strings.Split(procedure.Arguments, ",")
for _, param := range params {
param = strings.TrimSpace(param)
// Remove DEFAULT clauses
if idx := strings.Index(param, " DEFAULT "); idx != -1 {
param = param[:idx]
}
// Split by spaces and take the type part
parts := strings.Fields(param)
if len(parts) >= 2 {
types = append(types, parts[1])
} else if len(parts) == 1 {
types = append(types, parts[0])
// Add IN mode prefix if not already present
if !strings.HasPrefix(param, "IN ") && !strings.HasPrefix(param, "OUT ") && !strings.HasPrefix(param, "INOUT ") {
param = "IN " + param
}
paramParts = append(paramParts, param)
}
return strings.Join(types, ", ")
return strings.Join(paramParts, ", ")
}

return ""
Expand Down
33 changes: 24 additions & 9 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema

// parseParametersFromSignature parses function signature string into Parameter structs
// Example signature: "order_id integer, discount_percent numeric DEFAULT 0"
// Or with modes: "IN order_id integer, OUT result integer"
func (i *Inspector) parseParametersFromSignature(signature string) []*Parameter {
if signature == "" {
return nil
Expand All @@ -1187,11 +1188,11 @@ func (i *Inspector) parseParametersFromSignature(signature string) []*Parameter
}

param := &Parameter{
Mode: "IN", // Default mode for inspector
Mode: "IN", // Default mode
Position: position,
}

// Look for DEFAULT clause
// Look for DEFAULT clause first
defaultIdx := strings.Index(strings.ToUpper(paramStr), " DEFAULT ")
if defaultIdx != -1 {
// Extract default value
Expand All @@ -1200,14 +1201,28 @@ func (i *Inspector) parseParametersFromSignature(signature string) []*Parameter
paramStr = strings.TrimSpace(paramStr[:defaultIdx])
}

// Split into name and type
// Split into parts and check for mode prefix
parts := strings.Fields(paramStr)
if len(parts) >= 2 {
param.Name = parts[0]
param.DataType = strings.Join(parts[1:], " ")
} else if len(parts) == 1 {
// Only type, no name (shouldn't happen but handle gracefully)
param.DataType = parts[0]
if len(parts) == 0 {
continue
}

// Check if first part is a mode keyword (IN, OUT, INOUT, VARIADIC, TABLE)
firstPart := strings.ToUpper(parts[0])
startIdx := 0
if firstPart == "IN" || firstPart == "OUT" || firstPart == "INOUT" || firstPart == "VARIADIC" || firstPart == "TABLE" {
param.Mode = firstPart
startIdx = 1
}

// Parse name and type from remaining parts
remainingParts := parts[startIdx:]
if len(remainingParts) >= 2 {
param.Name = remainingParts[0]
param.DataType = strings.Join(remainingParts[1:], " ")
} else if len(remainingParts) == 1 {
// Only type, no name
param.DataType = remainingParts[0]
}

parameters = append(parameters, param)
Expand Down
27 changes: 27 additions & 0 deletions ir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1852,10 +1852,14 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err
parameters := p.extractFunctionParametersFromAST(funcStmt)

// Convert parameters to argument string for Procedure struct
// Also build signature with explicit modes
var arguments string
var signature string
if len(parameters) > 0 {
var argParts []string
var sigParts []string
for _, param := range parameters {
// For Arguments field (legacy, without mode)
if param.Name != "" {
argPart := param.Name + " " + param.DataType
// Add DEFAULT value if present
Expand All @@ -1866,8 +1870,29 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err
} else {
argParts = append(argParts, param.DataType)
}

// For Signature field (with explicit mode)
var sigPart string
if param.Mode != "" && param.Mode != "IN" {
// Include non-default modes (OUT, INOUT, VARIADIC)
sigPart = param.Mode + " "
} else {
// Include IN mode (default) explicitly to match pg_dump behavior
sigPart = "IN "
}
Comment on lines 1876 to 1882
Copy link

Copilot AI Oct 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment on line 1877 contradicts the implementation. The comment states 'Only include non-default modes explicitly' but the code always includes 'IN' mode explicitly in the else branch, which contradicts the 'only' statement.

Copilot uses AI. Check for mistakes.
if param.Name != "" {
sigPart += param.Name + " " + param.DataType
} else {
sigPart += param.DataType
}
// Add DEFAULT value if present
if param.DefaultValue != nil {
sigPart += " DEFAULT " + *param.DefaultValue
}
sigParts = append(sigParts, sigPart)
}
arguments = strings.Join(argParts, ", ")
signature = strings.Join(sigParts, ", ")
}

// Create procedure
Expand All @@ -1876,6 +1901,8 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err
Name: procName,
Language: language,
Arguments: arguments,
Signature: signature,
Parameters: parameters,
Definition: definition,
}

Expand Down
17 changes: 5 additions & 12 deletions testdata/diff/create_procedure/add_procedure/diff.sql
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
CREATE OR REPLACE PROCEDURE update_user_status(
user_id integer,
new_status text
CREATE OR REPLACE PROCEDURE example_procedure(
IN input_value integer,
OUT output_value integer
)
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE users
SET status = new_status, updated_at = NOW()
WHERE id = user_id;

IF NOT FOUND THEN
RAISE EXCEPTION 'User not found: %', user_id;
END IF;

COMMIT;
RAISE NOTICE 'Input value is: %', input_value;
output_value := input_value + 1;
END;
$$;
17 changes: 5 additions & 12 deletions testdata/diff/create_procedure/add_procedure/new.sql
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
CREATE PROCEDURE update_user_status(
user_id integer,
new_status text
CREATE PROCEDURE example_procedure(
IN input_value integer,
OUT output_value integer
)
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE users
SET status = new_status, updated_at = NOW()
WHERE id = user_id;

IF NOT FOUND THEN
RAISE EXCEPTION 'User not found: %', user_id;
END IF;

COMMIT;
RAISE NOTICE 'Input value is: %', input_value;
output_value := input_value + 1;
END;
$$;
6 changes: 3 additions & 3 deletions testdata/diff/create_procedure/add_procedure/plan.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"version": "1.0.0",
"pgschema_version": "1.1.1",
"pgschema_version": "1.2.1",
"created_at": "1970-01-01T00:00:00Z",
"source_fingerprint": {
"hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085"
Expand All @@ -9,10 +9,10 @@
{
"steps": [
{
"sql": "CREATE OR REPLACE PROCEDURE update_user_status(\n user_id integer,\n new_status text\n)\nLANGUAGE plpgsql\nAS $$\nBEGIN\n UPDATE users \n SET status = new_status, updated_at = NOW() \n WHERE id = user_id;\n \n IF NOT FOUND THEN\n RAISE EXCEPTION 'User not found: %', user_id;\n END IF;\n \n COMMIT;\nEND;\n$$;",
"sql": "CREATE OR REPLACE PROCEDURE example_procedure(\n IN input_value integer,\n OUT output_value integer\n)\nLANGUAGE plpgsql\nAS $$\nBEGIN\n RAISE NOTICE 'Input value is: %', input_value;\n output_value := input_value + 1;\nEND;\n$$;",
"type": "procedure",
"operation": "create",
"path": "public.update_user_status"
"path": "public.example_procedure"
}
]
}
Expand Down
17 changes: 5 additions & 12 deletions testdata/diff/create_procedure/add_procedure/plan.sql
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
CREATE OR REPLACE PROCEDURE update_user_status(
user_id integer,
new_status text
CREATE OR REPLACE PROCEDURE example_procedure(
IN input_value integer,
OUT output_value integer
)
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE users
SET status = new_status, updated_at = NOW()
WHERE id = user_id;

IF NOT FOUND THEN
RAISE EXCEPTION 'User not found: %', user_id;
END IF;

COMMIT;
RAISE NOTICE 'Input value is: %', input_value;
output_value := input_value + 1;
END;
$$;
Loading