From d94caf0730d10d91dd2ec80e76eeeac7b86c370f Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Sun, 19 Oct 2025 22:16:28 +0800 Subject: [PATCH 1/2] fix: remove arguments and avoid normalization --- internal/diff/diff.go | 4 +- internal/diff/function.go | 35 +++----- internal/diff/procedure.go | 31 ------- ir/inspector.go | 6 -- ir/ir.go | 28 +++++- ir/normalize.go | 86 +++---------------- ir/parser.go | 47 +--------- .../create_function/add_function/diff.sql | 3 +- .../diff/create_function/add_function/new.sql | 3 +- .../create_function/add_function/plan.json | 2 +- .../create_function/add_function/plan.sql | 3 +- .../create_function/add_function/plan.txt | 3 +- .../plan.json | 2 +- .../alter_function_same_signature/plan.json | 2 +- .../create_function/drop_function/plan.json | 2 +- 15 files changed, 67 insertions(+), 190 deletions(-) diff --git a/internal/diff/diff.go b/internal/diff/diff.go index f0de5ba2..c17d7e75 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -492,7 +492,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { for _, funcName := range funcNames { function := dbSchema.Functions[funcName] // Use schema.name(arguments) as key to distinguish functions with different signatures - key := function.Schema + "." + funcName + "(" + function.Arguments + ")" + key := function.Schema + "." + funcName + "(" + function.GetArguments() + ")" oldFunctions[key] = function } } @@ -503,7 +503,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { for _, funcName := range funcNames { function := dbSchema.Functions[funcName] // Use schema.name(arguments) as key to distinguish functions with different signatures - key := function.Schema + "." + funcName + "(" + function.Arguments + ")" + key := function.Schema + "." + funcName + "(" + function.GetArguments() + ")" newFunctions[key] = function } } diff --git a/internal/diff/function.go b/internal/diff/function.go index ad09cd66..1cde7668 100644 --- a/internal/diff/function.go +++ b/internal/diff/function.go @@ -64,23 +64,8 @@ func generateDropFunctionsSQL(functions []*ir.Function, targetSchema string, col functionName := qualifyEntityName(function.Schema, function.Name, targetSchema) var sql string - // Build argument list for DROP statement using normalized Parameters array - var argsList string - if len(function.Parameters) > 0 { - // Format parameters for DROP (omit names and defaults, include only types) - // Per PostgreSQL docs, DROP FUNCTION only needs input arguments (IN, INOUT, VARIADIC) - // Exclude OUT and TABLE mode parameters as they're part of the return signature - var argTypes []string - for _, param := range function.Parameters { - // Include only input parameter modes: IN (empty/implicit), INOUT, VARIADIC - if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" { - argTypes = append(argTypes, param.DataType) - } - } - argsList = strings.Join(argTypes, ", ") - } else if function.Arguments != "" { - argsList = function.Arguments - } + // Build argument list for DROP statement using GetArguments() + argsList := function.GetArguments() if argsList != "" { sql = fmt.Sprintf("DROP FUNCTION IF EXISTS %s(%s);", functionName, argsList) @@ -126,8 +111,6 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string { } } else if function.Signature != "" { stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(function.Signature, ", ", ",\n "))) - } else if function.Arguments != "" { - stmt.WriteString(fmt.Sprintf("(%s)", function.Arguments)) } else { stmt.WriteString("()") } @@ -260,6 +243,15 @@ func functionsEqual(old, new *ir.Function) bool { if old.Language != new.Language { return false } + if old.Volatility != new.Volatility { + return false + } + if old.IsStrict != new.IsStrict { + return false + } + if old.IsSecurityDefiner != new.IsSecurityDefiner { + return false + } // For RETURNS TABLE functions, the Parameters array includes TABLE output columns // which can cause comparison issues. In this case, rely on ReturnType comparison instead. @@ -280,10 +272,7 @@ func functionsEqual(old, new *ir.Function) bool { } } - // For TABLE functions or functions without Parameters, fall back to Arguments/Signature - if old.Arguments != new.Arguments { - return false - } + // For TABLE functions or functions without Parameters, fall back to Signature comparison if old.Signature != new.Signature { return false } diff --git a/internal/diff/procedure.go b/internal/diff/procedure.go index 8238f683..af88f203 100644 --- a/internal/diff/procedure.go +++ b/internal/diff/procedure.go @@ -153,15 +153,6 @@ func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string { } else if procedure.Signature != "" { // Use detailed signature if available stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(procedure.Signature, ", ", ",\n "))) - } else if procedure.Arguments != "" { - // Format Arguments field with newlines if it contains multiple parameters - args := procedure.Arguments - if strings.Contains(args, ", ") { - args = strings.ReplaceAll(args, ", ", ",\n ") - stmt.WriteString(fmt.Sprintf("(\n %s\n)", args)) - } else { - stmt.WriteString(fmt.Sprintf("(%s)", args)) - } } else { stmt.WriteString("()") } @@ -246,9 +237,6 @@ func proceduresEqual(old, new *ir.Procedure) bool { if old.Language != new.Language { return false } - if old.Arguments != new.Arguments { - return false - } if old.Signature != new.Signature { return false } @@ -286,24 +274,5 @@ func formatProcedureParametersForDrop(procedure *ir.Procedure) string { return strings.Join(paramParts, ", ") } - // Last resort: try to parse Arguments field and add IN mode - if procedure.Arguments != "" { - var paramParts []string - params := strings.Split(procedure.Arguments, ",") - for _, param := range params { - param = strings.TrimSpace(param) - // Remove DEFAULT clauses - if idx := strings.Index(param, " DEFAULT "); idx != -1 { - param = param[:idx] - } - // Add IN mode prefix if not already present - if !strings.HasPrefix(param, "IN ") && !strings.HasPrefix(param, "OUT ") && !strings.HasPrefix(param, "INOUT ") { - param = "IN " + param - } - paramParts = append(paramParts, param) - } - return strings.Join(paramParts, ", ") - } - return "" } diff --git a/ir/inspector.go b/ir/inspector.go index 338fb0a2..69622873 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -1120,7 +1120,6 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema if fn.FunctionComment.Valid { comment = fn.FunctionComment.String } - arguments := i.safeInterfaceToString(fn.FunctionArguments) signature := i.safeInterfaceToString(fn.FunctionSignature) // Check if function should be ignored @@ -1157,7 +1156,6 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema Definition: definition, ReturnType: i.safeInterfaceToString(fn.DataType), Language: i.safeInterfaceToString(fn.ExternalLanguage), - Arguments: arguments, Signature: signature, Comment: comment, Parameters: parameters, @@ -1339,7 +1337,6 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem if proc.ProcedureComment.Valid { comment = proc.ProcedureComment.String } - arguments := i.safeInterfaceToString(proc.ProcedureArguments) signature := i.safeInterfaceToString(proc.ProcedureSignature) // Check if procedure should be ignored @@ -1360,7 +1357,6 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem Name: procedureName, Definition: definition, Language: i.safeInterfaceToString(proc.ExternalLanguage), - Arguments: arguments, Signature: signature, Comment: comment, Parameters: parameters, @@ -1385,7 +1381,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem if agg.AggregateComment.Valid { comment = agg.AggregateComment.String } - arguments := i.safeInterfaceToString(agg.AggregateArguments) signature := i.safeInterfaceToString(agg.AggregateSignature) returnType := i.safeInterfaceToString(agg.AggregateReturnType) transitionFunction := i.safeInterfaceToString(agg.TransitionFunction) @@ -1400,7 +1395,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem aggregate := &Aggregate{ Schema: schemaName, Name: aggregateName, - Arguments: arguments, Signature: signature, ReturnType: returnType, TransitionFunction: transitionFunction, diff --git a/ir/ir.go b/ir/ir.go index dffcff91..b29e4dee 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -1,6 +1,9 @@ package ir -import "sync" +import ( + "strings" + "sync" +) // IR represents the complete database schema intermediate representation type IR struct { @@ -125,7 +128,6 @@ type Function struct { Definition string `json:"definition"` ReturnType string `json:"return_type"` Language string `json:"language"` - Arguments string `json:"arguments,omitempty"` Signature string `json:"signature,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Comment string `json:"comment,omitempty"` @@ -134,6 +136,26 @@ type Function struct { IsSecurityDefiner bool `json:"is_security_definer,omitempty"` // SECURITY DEFINER } +// GetArguments returns the function arguments string (types only) for function identification. +// This is built dynamically from the Parameters array to ensure it uses normalized types. +// Per PostgreSQL DROP FUNCTION syntax, only input parameters are included (IN, INOUT, VARIADIC). +func (f *Function) GetArguments() string { + if len(f.Parameters) == 0 { + return "" + } + + var argTypes []string + for _, param := range f.Parameters { + // Include only input parameter modes for DROP FUNCTION compatibility + // Exclude OUT and TABLE mode parameters (they're part of return signature) + if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" { + argTypes = append(argTypes, param.DataType) + } + } + + return strings.Join(argTypes, ", ") +} + // Parameter represents a function parameter type Parameter struct { Name string `json:"name"` @@ -336,7 +358,6 @@ type Type struct { type Aggregate struct { Schema string `json:"schema"` Name string `json:"name"` - Arguments string `json:"arguments,omitempty"` Signature string `json:"signature,omitempty"` ReturnType string `json:"return_type"` TransitionFunction string `json:"transition_function"` @@ -354,7 +375,6 @@ type Procedure struct { Name string `json:"name"` Definition string `json:"definition"` Language string `json:"language"` - Arguments string `json:"arguments,omitempty"` Signature string `json:"signature,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Comment string `json:"comment,omitempty"` diff --git a/ir/normalize.go b/ir/normalize.go index 0d62029b..fbb6d76d 100644 --- a/ir/normalize.go +++ b/ir/normalize.go @@ -240,10 +240,15 @@ func normalizeFunction(function *Function) { function.Language = strings.ToLower(function.Language) // Normalize return type to handle PostgreSQL-specific formats function.ReturnType = normalizeFunctionReturnType(function.ReturnType) - // Normalize parameter types + // Normalize parameter types and default values for _, param := range function.Parameters { if param != nil { param.DataType = normalizePostgreSQLType(param.DataType) + // Normalize default values + if param.DefaultValue != nil { + normalized := normalizeDefaultValue(*param.DefaultValue) + param.DefaultValue = &normalized + } } } // Normalize function body to handle whitespace differences @@ -278,80 +283,17 @@ func normalizeProcedure(procedure *Procedure) { // Normalize language to lowercase (PLPGSQL → plpgsql) procedure.Language = strings.ToLower(procedure.Language) - // Normalize arguments field when signature is present - // Inspector provides: Arguments: "integer, text", Signature: "IN user_id integer, IN new_status text" - // Parser provides: Arguments: "user_id integer, new_status text", no Signature - // We need to make inspector match parser format - if procedure.Signature != "" && procedure.Arguments != "" { - // Extract parameter names and types from signature - procedure.Arguments = normalizeProcedureArguments(procedure.Signature) - // Clear signature as parser doesn't set it - procedure.Signature = "" - } -} - -// normalizeProcedureArguments extracts parameter names and types from a procedure signature -func normalizeProcedureArguments(signature string) string { - if signature == "" { - return "" - } - - // Parse signature like "IN user_id integer, IN new_status text" - // to "user_id integer, new_status text" - params := strings.Split(signature, ",") - var normalizedParams []string - - for _, param := range params { - param = strings.TrimSpace(param) - if param == "" { - continue - } - - // Remove IN/OUT/INOUT modifiers - param = regexp.MustCompile(`^(IN|OUT|INOUT)\s+`).ReplaceAllString(param, "") - - // Handle DEFAULT values - need to remove redundant type casts - if strings.Contains(param, " DEFAULT ") { - parts := strings.Split(param, " DEFAULT ") - if len(parts) == 2 { - // Parse the parameter name and type - paramDef := strings.TrimSpace(parts[0]) - defaultValue := strings.TrimSpace(parts[1]) - - // Remove redundant type casts from string literals - // e.g., 'credit_card'::text -> 'credit_card' - defaultValue = regexp.MustCompile(`'([^']+)'::text\b`).ReplaceAllString(defaultValue, "'$1'") - - param = paramDef + " DEFAULT " + defaultValue - } - } - - // Extract name and type - fields := strings.Fields(param) - if len(fields) >= 2 { - // Check if this contains DEFAULT - defaultIdx := -1 - for i, field := range fields { - if field == "DEFAULT" { - defaultIdx = i - break - } - } - - if defaultIdx > 0 && defaultIdx >= 2 { - // Format as "name type DEFAULT value" - name := fields[0] - typeStr := strings.Join(fields[1:defaultIdx], " ") - defaultStr := strings.Join(fields[defaultIdx:], " ") - normalizedParams = append(normalizedParams, name+" "+typeStr+" "+defaultStr) - } else { - // Format as "name type" - normalizedParams = append(normalizedParams, fields[0]+" "+strings.Join(fields[1:], " ")) + // Normalize parameter types and default values + for _, param := range procedure.Parameters { + if param != nil { + param.DataType = normalizePostgreSQLType(param.DataType) + // Normalize default values + if param.DefaultValue != nil { + normalized := normalizeDefaultValue(*param.DefaultValue) + param.DefaultValue = &normalized } } } - - return strings.Join(normalizedParams, ", ") } // normalizeFunctionSignature normalizes function signatures for consistent comparison diff --git a/ir/parser.go b/ir/parser.go index ddd3f452..77b3eab9 100644 --- a/ir/parser.go +++ b/ir/parser.go @@ -1773,8 +1773,7 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro definition := p.extractFunctionDefinitionFromAST(funcStmt) parameters := p.extractFunctionParametersFromAST(funcStmt) - // Build Arguments and Signature strings from parameters - var argParts []string + // Build Signature string from parameters var sigParts []string for _, param := range parameters { @@ -1784,9 +1783,6 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro continue } - // Arguments string (for function identification) - types only - argParts = append(argParts, param.DataType) - // Signature string (for CREATE statement) - names and types if param.Name != "" { sigPart := fmt.Sprintf("%s %s", param.Name, param.DataType) @@ -1800,7 +1796,6 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro } } - arguments := strings.Join(argParts, ", ") signature := strings.Join(sigParts, ", ") // Extract function options (volatility, security, strict) @@ -1815,7 +1810,6 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro Definition: definition, ReturnType: returnType, Language: language, - Arguments: arguments, Signature: signature, Parameters: parameters, Volatility: volatility, @@ -1866,26 +1860,11 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err definition := p.extractFunctionDefinitionFromAST(funcStmt) parameters := p.extractFunctionParametersFromAST(funcStmt) - // Convert parameters to argument string for Procedure struct - // Also build signature with explicit modes - var arguments string + // Build signature with explicit modes var signature string if len(parameters) > 0 { - var argParts []string var sigParts []string for _, param := range parameters { - // For Arguments field (legacy, without mode) - if param.Name != "" { - argPart := param.Name + " " + param.DataType - // Add DEFAULT value if present - if param.DefaultValue != nil { - argPart += " DEFAULT " + *param.DefaultValue - } - argParts = append(argParts, argPart) - } else { - argParts = append(argParts, param.DataType) - } - // For Signature field (with explicit mode) var sigPart string if param.Mode != "" && param.Mode != "IN" { @@ -1906,7 +1885,6 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err } sigParts = append(sigParts, sigPart) } - arguments = strings.Join(argParts, ", ") signature = strings.Join(sigParts, ", ") } @@ -1915,7 +1893,6 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err Schema: schemaName, Name: procName, Language: language, - Arguments: arguments, Signature: signature, Parameters: parameters, Definition: definition, @@ -3128,24 +3105,7 @@ func (p *Parser) parseCreateAggregate(defineStmt *pg_query.DefineStmt) error { // Get or create schema dbSchema := p.schema.getOrCreateSchema(schemaName) - // Extract aggregate arguments - var arguments string - if len(defineStmt.Args) > 0 { - if listNode := defineStmt.Args[0].GetList(); listNode != nil { - var argTypes []string - for _, item := range listNode.Items { - if funcParam := item.GetFunctionParameter(); funcParam != nil { - if funcParam.ArgType != nil { - argType := p.parseTypeName(funcParam.ArgType) - argTypes = append(argTypes, argType) - } - } - } - if len(argTypes) > 0 { - arguments = argTypes[0] // For now, just take the first argument type - } - } - } + // Arguments field has been removed - aggregates will use Signature field if needed // Extract aggregate options from definition var stateFunction string @@ -3183,7 +3143,6 @@ func (p *Parser) parseCreateAggregate(defineStmt *pg_query.DefineStmt) error { aggregate := &Aggregate{ Schema: schemaName, Name: aggregateName, - Arguments: arguments, ReturnType: returnType, StateType: stateType, TransitionFunction: stateFunction, diff --git a/testdata/diff/create_function/add_function/diff.sql b/testdata/diff/create_function/add_function/diff.sql index 98e6f79a..b0129734 100644 --- a/testdata/diff/create_function/add_function/diff.sql +++ b/testdata/diff/create_function/add_function/diff.sql @@ -1,6 +1,7 @@ CREATE OR REPLACE FUNCTION process_order( order_id integer, - discount_percent numeric DEFAULT 0 + discount_percent numeric DEFAULT 0, + note varchar DEFAULT '' ) RETURNS numeric LANGUAGE plpgsql diff --git a/testdata/diff/create_function/add_function/new.sql b/testdata/diff/create_function/add_function/new.sql index 843b3d24..a430daee 100644 --- a/testdata/diff/create_function/add_function/new.sql +++ b/testdata/diff/create_function/add_function/new.sql @@ -1,6 +1,7 @@ CREATE FUNCTION process_order( order_id integer, - discount_percent numeric DEFAULT 0 + discount_percent numeric DEFAULT 0, + note varchar DEFAULT '' ) RETURNS numeric LANGUAGE plpgsql diff --git a/testdata/diff/create_function/add_function/plan.json b/testdata/diff/create_function/add_function/plan.json index 4322bfb9..1fc13b93 100644 --- a/testdata/diff/create_function/add_function/plan.json +++ b/testdata/diff/create_function/add_function/plan.json @@ -9,7 +9,7 @@ { "steps": [ { - "sql": "CREATE OR REPLACE FUNCTION process_order(\n order_id integer,\n discount_percent numeric DEFAULT 0\n)\nRETURNS numeric\nLANGUAGE plpgsql\nSECURITY DEFINER\nVOLATILE\nSTRICT\nAS $$\nDECLARE\n total numeric;\nBEGIN\n SELECT amount INTO total FROM orders WHERE id = order_id;\n RETURN total - (total * discount_percent / 100);\nEND;\n$$;", + "sql": "CREATE OR REPLACE FUNCTION process_order(\n order_id integer,\n discount_percent numeric DEFAULT 0,\n note varchar DEFAULT ''\n)\nRETURNS numeric\nLANGUAGE plpgsql\nSECURITY DEFINER\nVOLATILE\nSTRICT\nAS $$\nDECLARE\n total numeric;\nBEGIN\n SELECT amount INTO total FROM orders WHERE id = order_id;\n RETURN total - (total * discount_percent / 100);\nEND;\n$$;", "type": "function", "operation": "create", "path": "public.process_order" diff --git a/testdata/diff/create_function/add_function/plan.sql b/testdata/diff/create_function/add_function/plan.sql index 98e6f79a..b0129734 100644 --- a/testdata/diff/create_function/add_function/plan.sql +++ b/testdata/diff/create_function/add_function/plan.sql @@ -1,6 +1,7 @@ CREATE OR REPLACE FUNCTION process_order( order_id integer, - discount_percent numeric DEFAULT 0 + discount_percent numeric DEFAULT 0, + note varchar DEFAULT '' ) RETURNS numeric LANGUAGE plpgsql diff --git a/testdata/diff/create_function/add_function/plan.txt b/testdata/diff/create_function/add_function/plan.txt index b9441546..eca4bd93 100644 --- a/testdata/diff/create_function/add_function/plan.txt +++ b/testdata/diff/create_function/add_function/plan.txt @@ -11,7 +11,8 @@ DDL to be executed: CREATE OR REPLACE FUNCTION process_order( order_id integer, - discount_percent numeric DEFAULT 0 + discount_percent numeric DEFAULT 0, + note varchar DEFAULT '' ) RETURNS numeric LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_different_signature/plan.json b/testdata/diff/create_function/alter_function_different_signature/plan.json index f9b4a91c..a5777634 100644 --- a/testdata/diff/create_function/alter_function_different_signature/plan.json +++ b/testdata/diff/create_function/alter_function_different_signature/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "912618f23b4d55b8abcf56d0155b6be724ffe7e43548af3391ec0b7cbcff5f6d" + "hash": "c849c80ad70c026d11e0d262b5e1a1ae96a98435b116346aed0fc2521ca0f510" }, "groups": [ { diff --git a/testdata/diff/create_function/alter_function_same_signature/plan.json b/testdata/diff/create_function/alter_function_same_signature/plan.json index 1b10169b..72b9915d 100644 --- a/testdata/diff/create_function/alter_function_same_signature/plan.json +++ b/testdata/diff/create_function/alter_function_same_signature/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "912618f23b4d55b8abcf56d0155b6be724ffe7e43548af3391ec0b7cbcff5f6d" + "hash": "c849c80ad70c026d11e0d262b5e1a1ae96a98435b116346aed0fc2521ca0f510" }, "groups": [ { diff --git a/testdata/diff/create_function/drop_function/plan.json b/testdata/diff/create_function/drop_function/plan.json index 32a0ce45..87ebf0bc 100644 --- a/testdata/diff/create_function/drop_function/plan.json +++ b/testdata/diff/create_function/drop_function/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "088aa747b7e56f2117edee2741301754f9890301077cbc83f29339c7aeccb41e" + "hash": "e886d70cfd1f1ba9b9d04415434fd246e8d2683ede276ecc59daa1e5f92f447b" }, "groups": [ { From 3a483107fb03748914c2ad7966ac81d873be0a8a Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Sun, 19 Oct 2025 22:45:28 +0800 Subject: [PATCH 2/2] fix: remove signature field --- internal/diff/function.go | 63 +++++++--------- internal/diff/procedure.go | 71 +++++++------------ ir/inspector.go | 4 -- ir/ir.go | 3 - ir/normalize.go | 33 ++++----- ir/parser.go | 55 -------------- .../diff.sql | 2 +- .../new.sql | 2 +- .../old.sql | 2 +- .../plan.json | 4 +- .../plan.sql | 2 +- .../plan.txt | 2 +- .../alter_function_same_signature/plan.json | 2 +- .../create_function/drop_function/plan.json | 2 +- .../alter_procedure/plan.json | 2 +- .../create_procedure/drop_procedure/plan.json | 2 +- testdata/diff/migrate/v5/plan.json | 2 +- 17 files changed, 76 insertions(+), 177 deletions(-) diff --git a/internal/diff/function.go b/internal/diff/function.go index 1cde7668..b352ab6c 100644 --- a/internal/diff/function.go +++ b/internal/diff/function.go @@ -94,23 +94,16 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string { functionName := qualifyEntityName(function.Schema, function.Name, targetSchema) stmt.WriteString(fmt.Sprintf("CREATE OR REPLACE FUNCTION %s", functionName)) - // Add parameters - prefer structured Parameters array for normalized types - if len(function.Parameters) > 0 { - // Build parameter list from structured Parameters array - // Exclude TABLE mode parameters as they're part of RETURNS clause - var paramParts []string - for _, param := range function.Parameters { - if param.Mode != "TABLE" { - paramParts = append(paramParts, formatFunctionParameter(param, true)) - } + // Add parameters from structured Parameters array + // Exclude TABLE mode parameters as they're part of RETURNS clause + var paramParts []string + for _, param := range function.Parameters { + if param.Mode != "TABLE" { + paramParts = append(paramParts, formatFunctionParameter(param, true)) } - if len(paramParts) > 0 { - stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n "))) - } else { - stmt.WriteString("()") - } - } else if function.Signature != "" { - stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(function.Signature, ", ", ",\n "))) + } + if len(paramParts) > 0 { + stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n "))) } else { stmt.WriteString("()") } @@ -253,31 +246,25 @@ func functionsEqual(old, new *ir.Function) bool { return false } - // For RETURNS TABLE functions, the Parameters array includes TABLE output columns - // which can cause comparison issues. In this case, rely on ReturnType comparison instead. - isTableReturn := strings.HasPrefix(old.ReturnType, "TABLE(") || strings.HasPrefix(new.ReturnType, "TABLE(") - - if !isTableReturn { - // For non-TABLE functions, compare using normalized Parameters array - // This ensures type aliases like "character varying" vs "varchar" are treated as equal - hasOldParams := len(old.Parameters) > 0 - hasNewParams := len(new.Parameters) > 0 + // Compare using normalized Parameters array + // This ensures type aliases like "character varying" vs "varchar" are treated as equal + // For RETURNS TABLE functions, exclude TABLE mode parameters (they're in ReturnType) + // Only compare input parameters (IN, INOUT, VARIADIC, OUT) + oldInputParams := filterNonTableParameters(old.Parameters) + newInputParams := filterNonTableParameters(new.Parameters) + return parametersEqual(oldInputParams, newInputParams) +} - if hasOldParams && hasNewParams { - // Both have Parameters - compare them - return parametersEqual(old.Parameters, new.Parameters) - } else if hasOldParams || hasNewParams { - // One has Parameters, one doesn't - they're different - return false +// filterNonTableParameters filters out TABLE mode parameters +// TABLE parameters are output columns in RETURNS TABLE() and shouldn't be compared as input parameters +func filterNonTableParameters(params []*ir.Parameter) []*ir.Parameter { + var filtered []*ir.Parameter + for _, param := range params { + if param.Mode != "TABLE" { + filtered = append(filtered, param) } } - - // For TABLE functions or functions without Parameters, fall back to Signature comparison - if old.Signature != new.Signature { - return false - } - - return true + return filtered } // parametersEqual compares two parameter arrays for equality diff --git a/internal/diff/procedure.go b/internal/diff/procedure.go index af88f203..24a46740 100644 --- a/internal/diff/procedure.go +++ b/internal/diff/procedure.go @@ -137,22 +137,14 @@ func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string { procedureName := qualifyEntityName(procedure.Schema, procedure.Name, targetSchema) stmt.WriteString(fmt.Sprintf("CREATE OR REPLACE PROCEDURE %s", procedureName)) - // Add parameters - prefer structured Parameters array, then signature, then arguments - if len(procedure.Parameters) > 0 { - // Build parameter list from structured Parameters array - // Always include mode explicitly (matching pg_dump behavior) - var paramParts []string - for _, param := range procedure.Parameters { - paramParts = append(paramParts, formatParameterString(param, true)) - } - if len(paramParts) > 0 { - stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n "))) - } else { - stmt.WriteString("()") - } - } else if procedure.Signature != "" { - // Use detailed signature if available - stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.ReplaceAll(procedure.Signature, ", ", ",\n "))) + // Add parameters from structured Parameters array + // Always include mode explicitly (matching pg_dump behavior) + var paramParts []string + for _, param := range procedure.Parameters { + paramParts = append(paramParts, formatParameterString(param, true)) + } + if len(paramParts) > 0 { + stmt.WriteString(fmt.Sprintf("(\n %s\n)", strings.Join(paramParts, ",\n "))) } else { stmt.WriteString("()") } @@ -237,9 +229,21 @@ func proceduresEqual(old, new *ir.Procedure) bool { if old.Language != new.Language { return false } - if old.Signature != new.Signature { + + // Compare using normalized Parameters array instead of Signature + // This ensures proper comparison regardless of how parameters are specified + hasOldParams := len(old.Parameters) > 0 + hasNewParams := len(new.Parameters) > 0 + + if hasOldParams && hasNewParams { + // Both have Parameters - compare them + return parametersEqual(old.Parameters, new.Parameters) + } else if hasOldParams || hasNewParams { + // One has Parameters, one doesn't - they're different return false } + + // Both have no parameters - they're equal return true } @@ -247,32 +251,11 @@ func proceduresEqual(old, new *ir.Procedure) bool { // Returns the full parameter signature including mode and name (e.g., "IN order_id integer, IN amount numeric") // This is necessary for proper procedure identification in PostgreSQL func formatProcedureParametersForDrop(procedure *ir.Procedure) string { - // First, try to use the structured Parameters array if available - if len(procedure.Parameters) > 0 { - var paramParts []string - for _, param := range procedure.Parameters { - // Use helper function with includeDefault=false for DROP statements - paramParts = append(paramParts, formatParameterString(param, false)) - } - return strings.Join(paramParts, ", ") - } - - // Fallback to Signature field if Parameters not available - if procedure.Signature != "" { - // Signature should already have the mode information - // Just need to remove DEFAULT clauses - var paramParts []string - params := strings.Split(procedure.Signature, ",") - for _, param := range params { - param = strings.TrimSpace(param) - // Remove DEFAULT clauses - if idx := strings.Index(param, " DEFAULT "); idx != -1 { - param = param[:idx] - } - paramParts = append(paramParts, param) - } - return strings.Join(paramParts, ", ") + // Use the structured Parameters array + var paramParts []string + for _, param := range procedure.Parameters { + // Use helper function with includeDefault=false for DROP statements + paramParts = append(paramParts, formatParameterString(param, false)) } - - return "" + return strings.Join(paramParts, ", ") } diff --git a/ir/inspector.go b/ir/inspector.go index 69622873..8688f4ca 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -1156,7 +1156,6 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema Definition: definition, ReturnType: i.safeInterfaceToString(fn.DataType), Language: i.safeInterfaceToString(fn.ExternalLanguage), - Signature: signature, Comment: comment, Parameters: parameters, Volatility: volatility, @@ -1357,7 +1356,6 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem Name: procedureName, Definition: definition, Language: i.safeInterfaceToString(proc.ExternalLanguage), - Signature: signature, Comment: comment, Parameters: parameters, } @@ -1381,7 +1379,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem if agg.AggregateComment.Valid { comment = agg.AggregateComment.String } - signature := i.safeInterfaceToString(agg.AggregateSignature) returnType := i.safeInterfaceToString(agg.AggregateReturnType) transitionFunction := i.safeInterfaceToString(agg.TransitionFunction) transitionFunctionSchema := i.safeInterfaceToString(agg.TransitionFunctionSchema) @@ -1395,7 +1392,6 @@ func (i *Inspector) buildAggregates(ctx context.Context, schema *IR, targetSchem aggregate := &Aggregate{ Schema: schemaName, Name: aggregateName, - Signature: signature, ReturnType: returnType, TransitionFunction: transitionFunction, TransitionFunctionSchema: transitionFunctionSchema, diff --git a/ir/ir.go b/ir/ir.go index b29e4dee..ac33bd6a 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -128,7 +128,6 @@ type Function struct { Definition string `json:"definition"` ReturnType string `json:"return_type"` Language string `json:"language"` - Signature string `json:"signature,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Comment string `json:"comment,omitempty"` Volatility string `json:"volatility,omitempty"` // IMMUTABLE, STABLE, VOLATILE @@ -358,7 +357,6 @@ type Type struct { type Aggregate struct { Schema string `json:"schema"` Name string `json:"name"` - Signature string `json:"signature,omitempty"` ReturnType string `json:"return_type"` TransitionFunction string `json:"transition_function"` TransitionFunctionSchema string `json:"transition_function_schema,omitempty"` @@ -375,7 +373,6 @@ type Procedure struct { Name string `json:"name"` Definition string `json:"definition"` Language string `json:"language"` - Signature string `json:"signature,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Comment string `json:"comment,omitempty"` } diff --git a/ir/normalize.go b/ir/normalize.go index fbb6d76d..10e06700 100644 --- a/ir/normalize.go +++ b/ir/normalize.go @@ -235,15 +235,20 @@ func normalizeFunction(function *Function) { return } - function.Signature = normalizeFunctionSignature(function.Signature) // lowercase LANGUAGE plpgsql is more common in modern usage function.Language = strings.ToLower(function.Language) // Normalize return type to handle PostgreSQL-specific formats function.ReturnType = normalizeFunctionReturnType(function.ReturnType) - // Normalize parameter types and default values + // Normalize parameter types, modes, and default values for _, param := range function.Parameters { if param != nil { param.DataType = normalizePostgreSQLType(param.DataType) + // Normalize mode: empty string → "IN" for functions (PostgreSQL default) + // Functions: IN is default, only OUT/INOUT/VARIADIC need explicit mode + // But for consistent comparison, normalize empty to "IN" + if param.Mode == "" { + param.Mode = "IN" + } // Normalize default values if param.DefaultValue != nil { normalized := normalizeDefaultValue(*param.DefaultValue) @@ -283,10 +288,14 @@ func normalizeProcedure(procedure *Procedure) { // Normalize language to lowercase (PLPGSQL → plpgsql) procedure.Language = strings.ToLower(procedure.Language) - // Normalize parameter types and default values + // Normalize parameter types, modes, and default values for _, param := range procedure.Parameters { if param != nil { param.DataType = normalizePostgreSQLType(param.DataType) + // Normalize mode: empty string → "IN" for procedures (PostgreSQL default) + if param.Mode == "" { + param.Mode = "IN" + } // Normalize default values if param.DefaultValue != nil { normalized := normalizeDefaultValue(*param.DefaultValue) @@ -296,24 +305,6 @@ func normalizeProcedure(procedure *Procedure) { } } -// normalizeFunctionSignature normalizes function signatures for consistent comparison -func normalizeFunctionSignature(signature string) string { - if signature == "" { - return signature - } - - // Remove extra whitespace - signature = strings.TrimSpace(signature) - signature = regexp.MustCompile(`\s+`).ReplaceAllString(signature, " ") - - // Normalize parameter formatting - signature = regexp.MustCompile(`\(\s*`).ReplaceAllString(signature, "(") - signature = regexp.MustCompile(`\s*\)`).ReplaceAllString(signature, ")") - signature = regexp.MustCompile(`\s*,\s*`).ReplaceAllString(signature, ", ") - - return signature -} - // normalizeFunctionReturnType normalizes function return types, especially TABLE types func normalizeFunctionReturnType(returnType string) string { if returnType == "" { diff --git a/ir/parser.go b/ir/parser.go index 77b3eab9..b6dc660e 100644 --- a/ir/parser.go +++ b/ir/parser.go @@ -1773,31 +1773,6 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro definition := p.extractFunctionDefinitionFromAST(funcStmt) parameters := p.extractFunctionParametersFromAST(funcStmt) - // Build Signature string from parameters - var sigParts []string - - for _, param := range parameters { - // Only include input parameters (IN, INOUT, VARIADIC) in function signature - // OUT and TABLE parameters are part of RETURNS TABLE(...) and should not be in the signature - if param.Mode == "OUT" || param.Mode == "TABLE" { - continue - } - - // Signature string (for CREATE statement) - names and types - if param.Name != "" { - sigPart := fmt.Sprintf("%s %s", param.Name, param.DataType) - // Add DEFAULT value if present - if param.DefaultValue != nil { - sigPart += fmt.Sprintf(" DEFAULT %s", *param.DefaultValue) - } - sigParts = append(sigParts, sigPart) - } else { - sigParts = append(sigParts, param.DataType) - } - } - - signature := strings.Join(sigParts, ", ") - // Extract function options (volatility, security, strict) volatility := p.extractFunctionVolatilityFromAST(funcStmt) isSecurityDefiner := p.extractFunctionSecurityFromAST(funcStmt) @@ -1810,7 +1785,6 @@ func (p *Parser) parseCreateFunction(funcStmt *pg_query.CreateFunctionStmt) erro Definition: definition, ReturnType: returnType, Language: language, - Signature: signature, Parameters: parameters, Volatility: volatility, IsSecurityDefiner: isSecurityDefiner, @@ -1860,40 +1834,11 @@ func (p *Parser) parseCreateProcedure(funcStmt *pg_query.CreateFunctionStmt) err definition := p.extractFunctionDefinitionFromAST(funcStmt) parameters := p.extractFunctionParametersFromAST(funcStmt) - // Build signature with explicit modes - var signature string - if len(parameters) > 0 { - var sigParts []string - for _, param := range parameters { - // For Signature field (with explicit mode) - var sigPart string - if param.Mode != "" && param.Mode != "IN" { - // Include non-default modes (OUT, INOUT, VARIADIC) - sigPart = param.Mode + " " - } else { - // Include IN mode (default) explicitly to match pg_dump behavior - sigPart = "IN " - } - if param.Name != "" { - sigPart += param.Name + " " + param.DataType - } else { - sigPart += param.DataType - } - // Add DEFAULT value if present - if param.DefaultValue != nil { - sigPart += " DEFAULT " + *param.DefaultValue - } - sigParts = append(sigParts, sigPart) - } - signature = strings.Join(sigParts, ", ") - } - // Create procedure procedure := &Procedure{ Schema: schemaName, Name: procName, Language: language, - Signature: signature, Parameters: parameters, Definition: definition, } diff --git a/testdata/diff/create_function/alter_function_different_signature/diff.sql b/testdata/diff/create_function/alter_function_different_signature/diff.sql index 0a425ece..c65d0503 100644 --- a/testdata/diff/create_function/alter_function_different_signature/diff.sql +++ b/testdata/diff/create_function/alter_function_different_signature/diff.sql @@ -2,7 +2,7 @@ DROP FUNCTION IF EXISTS process_order(integer, numeric); CREATE OR REPLACE FUNCTION process_order( customer_email text, - priority boolean DEFAULT false + priority boolean ) RETURNS TABLE(status text, processed_at timestamp) LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_different_signature/new.sql b/testdata/diff/create_function/alter_function_different_signature/new.sql index dcac0b19..c76ba855 100644 --- a/testdata/diff/create_function/alter_function_different_signature/new.sql +++ b/testdata/diff/create_function/alter_function_different_signature/new.sql @@ -1,6 +1,6 @@ CREATE FUNCTION process_order( customer_email text, - priority boolean DEFAULT false + priority boolean ) RETURNS TABLE(status text, processed_at timestamp) LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_different_signature/old.sql b/testdata/diff/create_function/alter_function_different_signature/old.sql index 367cae29..b909d237 100644 --- a/testdata/diff/create_function/alter_function_different_signature/old.sql +++ b/testdata/diff/create_function/alter_function_different_signature/old.sql @@ -1,6 +1,6 @@ CREATE FUNCTION process_order( order_id integer, - discount_percent numeric DEFAULT 0 + discount_percent numeric ) RETURNS numeric LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_different_signature/plan.json b/testdata/diff/create_function/alter_function_different_signature/plan.json index a5777634..63da2859 100644 --- a/testdata/diff/create_function/alter_function_different_signature/plan.json +++ b/testdata/diff/create_function/alter_function_different_signature/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "c849c80ad70c026d11e0d262b5e1a1ae96a98435b116346aed0fc2521ca0f510" + "hash": "6999cab9e41f75143c1f09a16bb229452a4a06cd1171eece2a7466ca8d1323d6" }, "groups": [ { @@ -15,7 +15,7 @@ "path": "public.process_order" }, { - "sql": "CREATE OR REPLACE FUNCTION process_order(\n customer_email text,\n priority boolean DEFAULT false\n)\nRETURNS TABLE(status text, processed_at timestamp)\nLANGUAGE plpgsql\nSECURITY DEFINER\nSTABLE\nAS $$\nBEGIN\n RETURN QUERY\n SELECT 'completed'::text, NOW()\n WHERE priority = true;\nEND;\n$$;", + "sql": "CREATE OR REPLACE FUNCTION process_order(\n customer_email text,\n priority boolean\n)\nRETURNS TABLE(status text, processed_at timestamp)\nLANGUAGE plpgsql\nSECURITY DEFINER\nSTABLE\nAS $$\nBEGIN\n RETURN QUERY\n SELECT 'completed'::text, NOW()\n WHERE priority = true;\nEND;\n$$;", "type": "function", "operation": "create", "path": "public.process_order" diff --git a/testdata/diff/create_function/alter_function_different_signature/plan.sql b/testdata/diff/create_function/alter_function_different_signature/plan.sql index 0a425ece..c65d0503 100644 --- a/testdata/diff/create_function/alter_function_different_signature/plan.sql +++ b/testdata/diff/create_function/alter_function_different_signature/plan.sql @@ -2,7 +2,7 @@ DROP FUNCTION IF EXISTS process_order(integer, numeric); CREATE OR REPLACE FUNCTION process_order( customer_email text, - priority boolean DEFAULT false + priority boolean ) RETURNS TABLE(status text, processed_at timestamp) LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_different_signature/plan.txt b/testdata/diff/create_function/alter_function_different_signature/plan.txt index 4e312af7..c4caf7b2 100644 --- a/testdata/diff/create_function/alter_function_different_signature/plan.txt +++ b/testdata/diff/create_function/alter_function_different_signature/plan.txt @@ -14,7 +14,7 @@ DROP FUNCTION IF EXISTS process_order(integer, numeric); CREATE OR REPLACE FUNCTION process_order( customer_email text, - priority boolean DEFAULT false + priority boolean ) RETURNS TABLE(status text, processed_at timestamp) LANGUAGE plpgsql diff --git a/testdata/diff/create_function/alter_function_same_signature/plan.json b/testdata/diff/create_function/alter_function_same_signature/plan.json index 72b9915d..b2bc0d5e 100644 --- a/testdata/diff/create_function/alter_function_same_signature/plan.json +++ b/testdata/diff/create_function/alter_function_same_signature/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "c849c80ad70c026d11e0d262b5e1a1ae96a98435b116346aed0fc2521ca0f510" + "hash": "60f55cc9364a4c46fec3d3c3b819f5507eea75c2667985f2a9c1a2954f786e4e" }, "groups": [ { diff --git a/testdata/diff/create_function/drop_function/plan.json b/testdata/diff/create_function/drop_function/plan.json index 87ebf0bc..5cc8fb2d 100644 --- a/testdata/diff/create_function/drop_function/plan.json +++ b/testdata/diff/create_function/drop_function/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "e886d70cfd1f1ba9b9d04415434fd246e8d2683ede276ecc59daa1e5f92f447b" + "hash": "c34208400ed55b8e5d1ee74c1200d4e126ed191b12f9005e80d878e34885968e" }, "groups": [ { diff --git a/testdata/diff/create_procedure/alter_procedure/plan.json b/testdata/diff/create_procedure/alter_procedure/plan.json index a39626be..f185d0ac 100644 --- a/testdata/diff/create_procedure/alter_procedure/plan.json +++ b/testdata/diff/create_procedure/alter_procedure/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "bfc3ecb23e756b22a9efa8a07704b4ec5542c334b76a3338374dc25ea1817d29" + "hash": "08df93dce9ea1ef62c3edfbfdd06ee868ec8976c33d28242bc2618ba4e9688d8" }, "groups": [ { diff --git a/testdata/diff/create_procedure/drop_procedure/plan.json b/testdata/diff/create_procedure/drop_procedure/plan.json index 0a48407f..ff07130c 100644 --- a/testdata/diff/create_procedure/drop_procedure/plan.json +++ b/testdata/diff/create_procedure/drop_procedure/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "53ad909d589a920e0f96183fa15eefd1ad2889fdcc64d7486cb100eb507a9055" + "hash": "9f8e6e768179535ba16d74d4ebbb96c3c72439b06bc367c26cc58bfefdc6b17f" }, "groups": [ { diff --git a/testdata/diff/migrate/v5/plan.json b/testdata/diff/migrate/v5/plan.json index af365a85..652aea01 100644 --- a/testdata/diff/migrate/v5/plan.json +++ b/testdata/diff/migrate/v5/plan.json @@ -3,7 +3,7 @@ "pgschema_version": "1.4.0", "created_at": "1970-01-01T00:00:00Z", "source_fingerprint": { - "hash": "e3d1814aca3de8f3c5c86812161fabcc15764d3d41ef4f141f892a3cef6c2d38" + "hash": "00cecda254bb0731ef1d20915ed24ba4cdfd550430d43bfb9e0d0815c4f1b738" }, "groups": [ {