diff --git a/internal/diff/function.go b/internal/diff/function.go index a4961f99..48130cd1 100644 --- a/internal/diff/function.go +++ b/internal/diff/function.go @@ -232,11 +232,15 @@ func generateFunctionSQL(function *ir.Function, targetSchema string) string { // Add the function body if function.Definition != "" { - // Check if this uses RETURN clause syntax (PG14+) - // pg_get_function_sqlbody returns "RETURN expression" which should not be wrapped - // Use case-insensitive comparison to handle all variations + // Check if this uses SQL-standard body syntax (PG14+) + // pg_get_function_sqlbody returns: + // - "RETURN expression" for simple SQL-standard bodies + // - "BEGIN ATOMIC ... END" for multi-statement SQL-standard bodies + // These should not be wrapped with AS $$ ... $$ trimmedDef := strings.TrimSpace(function.Definition) - if len(trimmedDef) >= 7 && strings.EqualFold(trimmedDef[:7], "RETURN ") { + isSQLStandardBody := (len(trimmedDef) >= 7 && strings.EqualFold(trimmedDef[:7], "RETURN ")) || + (len(trimmedDef) >= 12 && strings.EqualFold(trimmedDef[:12], "BEGIN ATOMIC")) + if isSQLStandardBody { stmt.WriteString(fmt.Sprintf("\n%s;", trimmedDef)) } else { // Traditional AS $$ ... $$ syntax diff --git a/internal/diff/procedure.go b/internal/diff/procedure.go index 1a385bc8..ac7ba17f 100644 --- a/internal/diff/procedure.go +++ b/internal/diff/procedure.go @@ -183,11 +183,15 @@ func generateProcedureSQL(procedure *ir.Procedure, targetSchema string) string { // Add the procedure body if procedure.Definition != "" { - // Check if this uses RETURN clause syntax (PG14+) - // pg_get_function_sqlbody returns "RETURN expression" which should not be wrapped - // Use case-insensitive comparison to handle all variations + // Check if this uses SQL-standard body syntax (PG14+) + // pg_get_function_sqlbody returns "BEGIN ATOMIC ... END" for SQL-standard procedure bodies + // These should not be wrapped with AS $$ ... $$ + // Note: The RETURN check is kept for consistency with function handling, + // though procedures don't support value-returning RETURN statements trimmedDef := strings.TrimSpace(procedure.Definition) - if len(trimmedDef) >= 7 && strings.EqualFold(trimmedDef[:7], "RETURN ") { + isSQLStandardBody := (len(trimmedDef) >= 7 && strings.EqualFold(trimmedDef[:7], "RETURN ")) || + (len(trimmedDef) >= 12 && strings.EqualFold(trimmedDef[:12], "BEGIN ATOMIC")) + if isSQLStandardBody { stmt.WriteString(fmt.Sprintf("\n%s;", trimmedDef)) } else { // Traditional AS $$ ... $$ syntax diff --git a/testdata/diff/create_function/add_function/diff.sql b/testdata/diff/create_function/add_function/diff.sql index b7c08968..c484af6b 100644 --- a/testdata/diff/create_function/add_function/diff.sql +++ b/testdata/diff/create_function/add_function/diff.sql @@ -1,3 +1,14 @@ +CREATE OR REPLACE FUNCTION add_with_tax( + amount numeric, + tax_rate numeric DEFAULT 0.1 +) +RETURNS numeric +LANGUAGE sql +VOLATILE +BEGIN ATOMIC + SELECT (amount + (amount * tax_rate)); +END; + CREATE OR REPLACE FUNCTION calculate_tax( amount numeric, rate numeric diff --git a/testdata/diff/create_function/add_function/new.sql b/testdata/diff/create_function/add_function/new.sql index 8b5dd545..4daf4732 100644 --- a/testdata/diff/create_function/add_function/new.sql +++ b/testdata/diff/create_function/add_function/new.sql @@ -48,4 +48,13 @@ STABLE LEAKPROOF AS $$ SELECT '***' || substring(input from 4); -$$; \ No newline at end of file +$$; + +-- Function testing BEGIN ATOMIC syntax (SQL-standard multi-statement body, PG14+) +-- Reproduces issue #241 +CREATE FUNCTION add_with_tax(amount numeric, tax_rate numeric DEFAULT 0.1) +RETURNS numeric +LANGUAGE SQL +BEGIN ATOMIC + SELECT amount + (amount * tax_rate); +END; diff --git a/testdata/diff/create_function/add_function/plan.json b/testdata/diff/create_function/add_function/plan.json index bca7b6f5..2ebc4995 100644 --- a/testdata/diff/create_function/add_function/plan.json +++ b/testdata/diff/create_function/add_function/plan.json @@ -8,6 +8,12 @@ "groups": [ { "steps": [ + { + "sql": "CREATE OR REPLACE FUNCTION add_with_tax(\n amount numeric,\n tax_rate numeric DEFAULT 0.1\n)\nRETURNS numeric\nLANGUAGE sql\nVOLATILE\nBEGIN ATOMIC\n SELECT (amount + (amount * tax_rate));\nEND;", + "type": "function", + "operation": "create", + "path": "public.add_with_tax" + }, { "sql": "CREATE OR REPLACE FUNCTION calculate_tax(\n amount numeric,\n rate numeric\n)\nRETURNS numeric\nLANGUAGE sql\nIMMUTABLE\nPARALLEL SAFE\nAS $$\n SELECT amount * rate;\n$$;", "type": "function", diff --git a/testdata/diff/create_function/add_function/plan.sql b/testdata/diff/create_function/add_function/plan.sql index b7c08968..c484af6b 100644 --- a/testdata/diff/create_function/add_function/plan.sql +++ b/testdata/diff/create_function/add_function/plan.sql @@ -1,3 +1,14 @@ +CREATE OR REPLACE FUNCTION add_with_tax( + amount numeric, + tax_rate numeric DEFAULT 0.1 +) +RETURNS numeric +LANGUAGE sql +VOLATILE +BEGIN ATOMIC + SELECT (amount + (amount * tax_rate)); +END; + CREATE OR REPLACE FUNCTION calculate_tax( amount numeric, rate numeric diff --git a/testdata/diff/create_function/add_function/plan.txt b/testdata/diff/create_function/add_function/plan.txt index 25da216f..8b1f3bac 100644 --- a/testdata/diff/create_function/add_function/plan.txt +++ b/testdata/diff/create_function/add_function/plan.txt @@ -1,9 +1,10 @@ -Plan: 3 to add. +Plan: 4 to add. Summary by type: - functions: 3 to add + functions: 4 to add Functions: + + add_with_tax + calculate_tax + mask_sensitive_data + process_order @@ -11,6 +12,17 @@ Functions: DDL to be executed: -------------------------------------------------- +CREATE OR REPLACE FUNCTION add_with_tax( + amount numeric, + tax_rate numeric DEFAULT 0.1 +) +RETURNS numeric +LANGUAGE sql +VOLATILE +BEGIN ATOMIC + SELECT (amount + (amount * tax_rate)); +END; + CREATE OR REPLACE FUNCTION calculate_tax( amount numeric, rate numeric diff --git a/testdata/diff/create_procedure/add_procedure/diff.sql b/testdata/diff/create_procedure/add_procedure/diff.sql index 5504914b..5d82c0ad 100644 --- a/testdata/diff/create_procedure/add_procedure/diff.sql +++ b/testdata/diff/create_procedure/add_procedure/diff.sql @@ -9,3 +9,11 @@ BEGIN output_value := input_value + 1; END; $$; + +CREATE OR REPLACE PROCEDURE validate_input( + IN input_value integer +) +LANGUAGE sql +BEGIN ATOMIC + SELECT (input_value * 2); +END; diff --git a/testdata/diff/create_procedure/add_procedure/new.sql b/testdata/diff/create_procedure/add_procedure/new.sql index 6bdfafdf..f48adda6 100644 --- a/testdata/diff/create_procedure/add_procedure/new.sql +++ b/testdata/diff/create_procedure/add_procedure/new.sql @@ -8,4 +8,12 @@ BEGIN RAISE NOTICE 'Input value is: %', input_value; output_value := input_value + 1; END; -$$; \ No newline at end of file +$$; + +-- Procedure testing BEGIN ATOMIC syntax (SQL-standard body, PG14+) +-- Reproduces issue #241 for procedures +CREATE PROCEDURE validate_input(input_value integer) +LANGUAGE SQL +BEGIN ATOMIC + SELECT input_value * 2; +END; \ No newline at end of file diff --git a/testdata/diff/create_procedure/add_procedure/plan.json b/testdata/diff/create_procedure/add_procedure/plan.json index 540e12c9..12f5d98b 100644 --- a/testdata/diff/create_procedure/add_procedure/plan.json +++ b/testdata/diff/create_procedure/add_procedure/plan.json @@ -13,6 +13,12 @@ "type": "procedure", "operation": "create", "path": "public.example_procedure" + }, + { + "sql": "CREATE OR REPLACE PROCEDURE validate_input(\n IN input_value integer\n)\nLANGUAGE sql\nBEGIN ATOMIC\n SELECT (input_value * 2);\nEND;", + "type": "procedure", + "operation": "create", + "path": "public.validate_input" } ] } diff --git a/testdata/diff/create_procedure/add_procedure/plan.sql b/testdata/diff/create_procedure/add_procedure/plan.sql index 5504914b..5d82c0ad 100644 --- a/testdata/diff/create_procedure/add_procedure/plan.sql +++ b/testdata/diff/create_procedure/add_procedure/plan.sql @@ -9,3 +9,11 @@ BEGIN output_value := input_value + 1; END; $$; + +CREATE OR REPLACE PROCEDURE validate_input( + IN input_value integer +) +LANGUAGE sql +BEGIN ATOMIC + SELECT (input_value * 2); +END; diff --git a/testdata/diff/create_procedure/add_procedure/plan.txt b/testdata/diff/create_procedure/add_procedure/plan.txt index 832401ca..ff03a051 100644 --- a/testdata/diff/create_procedure/add_procedure/plan.txt +++ b/testdata/diff/create_procedure/add_procedure/plan.txt @@ -1,10 +1,11 @@ -Plan: 1 to add. +Plan: 2 to add. Summary by type: - procedures: 1 to add + procedures: 2 to add Procedures: + example_procedure + + validate_input DDL to be executed: -------------------------------------------------- @@ -20,3 +21,11 @@ BEGIN output_value := input_value + 1; END; $$; + +CREATE OR REPLACE PROCEDURE validate_input( + IN input_value integer +) +LANGUAGE sql +BEGIN ATOMIC + SELECT (input_value * 2); +END;