From 3553498b1a9e8a6147310fb9418a2b66199b997a Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Thu, 16 Oct 2025 01:37:48 +0800 Subject: [PATCH] fix: view array operator --- ir/formatter.go | 187 +++++++++++++++--- .../add_view_array_operators/diff.sql | 9 + .../add_view_array_operators/new.sql | 21 ++ .../add_view_array_operators/old.sql | 5 + .../add_view_array_operators/plan.json | 20 ++ .../add_view_array_operators/plan.sql | 9 + .../add_view_array_operators/plan.txt | 20 ++ 7 files changed, 246 insertions(+), 25 deletions(-) create mode 100644 testdata/diff/create_view/add_view_array_operators/diff.sql create mode 100644 testdata/diff/create_view/add_view_array_operators/new.sql create mode 100644 testdata/diff/create_view/add_view_array_operators/old.sql create mode 100644 testdata/diff/create_view/add_view_array_operators/plan.json create mode 100644 testdata/diff/create_view/add_view_array_operators/plan.sql create mode 100644 testdata/diff/create_view/add_view_array_operators/plan.txt diff --git a/ir/formatter.go b/ir/formatter.go index 752c4695..49f191bf 100644 --- a/ir/formatter.go +++ b/ir/formatter.go @@ -301,7 +301,68 @@ func (f *postgreSQLFormatter) formatAConst(constant *pg_query.A_Const) { // formatAExpr formats an A_Expr (binary/unary expressions) func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) { - // Special case: Detect "column = ANY (ARRAY[...])" pattern and convert to "column IN (...)" + // Handle AEXPR_OP_ANY and AEXPR_OP_ALL (e.g., "value > ANY(ARRAY[...])") + if expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY || expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ALL { + // Check if this is "= ANY" which can be converted to IN + isEqualityAny := expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY && + len(expr.Name) == 1 && + expr.Name[0].GetString_() != nil && + expr.Name[0].GetString_().Sval == "=" + + if isEqualityAny && expr.Rexpr != nil { + if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil { + // Convert "column = ANY(ARRAY[...])" to "column IN (...)" + f.formatExpressionStripCast(expr.Lexpr) + f.buffer.WriteString(" IN (") + for i, elem := range aArrayExpr.Elements { + if i > 0 { + f.buffer.WriteString(", ") + } + // Strip type casts from constants in IN list + f.formatExpressionStripCast(elem) + } + f.buffer.WriteString(")") + return + } + } + + // Format other ANY/ALL operations (>, <, <>, etc.) + // Format: () + if expr.Lexpr != nil { + f.formatExpression(expr.Lexpr) + } + + // Format operator + if len(expr.Name) > 0 { + f.buffer.WriteString(" ") + for i, nameNode := range expr.Name { + if i > 0 { + f.buffer.WriteString(".") + } + if str := nameNode.GetString_(); str != nil { + f.buffer.WriteString(str.Sval) + } + } + f.buffer.WriteString(" ") + } + + // Format ANY or ALL + if expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY { + f.buffer.WriteString("ANY (") + } else { + f.buffer.WriteString("ALL (") + } + + // Format right operand + if expr.Rexpr != nil { + f.formatExpression(expr.Rexpr) + } + + f.buffer.WriteString(")") + return + } + + // Special case: Detect "column = ARRAY[...]" pattern and convert to "column IN (...)" // This pattern appears when parsing view definitions from pg_get_viewdef() if len(expr.Name) == 1 && expr.Rexpr != nil { if str := expr.Name[0].GetString_(); str != nil && str.Sval == "=" { @@ -507,7 +568,7 @@ func (f *postgreSQLFormatter) formatCaseExpr(caseExpr *pg_query.CaseExpr) { f.buffer.WriteString(" WHEN ") f.formatExpression(when.Expr) f.buffer.WriteString(" THEN ") - f.formatExpression(when.Result) + f.formatExpressionStripCast(when.Result) } } @@ -633,38 +694,114 @@ func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr) func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) { // Check if this is a simple = ANY pattern that can be converted to IN // UseOr means ANY (disjunction), !UseOr means ALL (conjunction) - isEqualAny := arrayOp.UseOr && len(arrayOp.Args) == 2 + // IMPORTANT: We must also verify the operator is equality (=), not other operators like >, <, <> + + if len(arrayOp.Args) != 2 { + // Malformed expression, use deparse fallback + if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil { + f.buffer.WriteString(deparseResult) + } + return + } - if isEqualAny { + // Get the operator name by deparsing + deparsed, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}) + if err != nil { + // If deparse fails, just return empty (shouldn't happen in practice) + return + } + + // Extract the operator once to avoid redundant string parsing + opName := extractOperator(deparsed) + + // Check if operator is = (equality) + isEqualityOp := opName == "=" + + // Only convert to IN syntax if it's "= ANY" + if arrayOp.UseOr && isEqualityOp { // Args[0] is the left side (column), Args[1] is the right side (array) // Format as "column IN (values)" - if len(arrayOp.Args) >= 2 { - // Format left side (the column) - f.formatExpression(arrayOp.Args[0]) + // Format left side (the column) + f.formatExpression(arrayOp.Args[0]) - f.buffer.WriteString(" IN (") + f.buffer.WriteString(" IN (") - // Extract values from the array - if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil { - // Format array elements as comma-separated list - for i, elem := range arrayExpr.Elements { - if i > 0 { - f.buffer.WriteString(", ") - } - f.formatExpression(elem) + // Extract values from the array + if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil { + // Format array elements as comma-separated list + for i, elem := range arrayExpr.Elements { + if i > 0 { + f.buffer.WriteString(", ") } - } else { - // Fallback: format the right expression as-is - f.formatExpression(arrayOp.Args[1]) + f.formatExpression(elem) } - - f.buffer.WriteString(")") - return + } else { + // Fallback: format the right expression as-is + f.formatExpression(arrayOp.Args[1]) } + + f.buffer.WriteString(")") + return } - // For other operations (like <> ALL) or malformed expressions, use deparse fallback - if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil { - f.buffer.WriteString(deparseResult) + // For other operations (like <> ANY, > ANY, = ALL), format manually + // Format: () + + // Format left side + f.formatExpression(arrayOp.Args[0]) + + // Use the already-extracted operator + if opName != "" { + f.buffer.WriteString(" ") + f.buffer.WriteString(opName) + f.buffer.WriteString(" ") + } else { + f.buffer.WriteString(" ") + } + + // Format ANY or ALL + if arrayOp.UseOr { + f.buffer.WriteString("ANY (") + } else { + f.buffer.WriteString("ALL (") } + + // Format right side (the array) + f.formatExpression(arrayOp.Args[1]) + + f.buffer.WriteString(")") +} + +// extractOperator extracts the operator from a deparsed ScalarArrayOpExpr string +// e.g., "value > ANY (ARRAY[...])" -> ">" +func extractOperator(deparsed string) string { + // Look for pattern: ANY/ALL + anyIdx := strings.Index(deparsed, " ANY") + allIdx := strings.Index(deparsed, " ALL") + + var cutoff int + if anyIdx >= 0 && (allIdx < 0 || anyIdx < allIdx) { + cutoff = anyIdx + } else if allIdx >= 0 { + cutoff = allIdx + } else { + return "" + } + + // Work backwards from cutoff to find the operator + // Operators can be: =, <>, !=, <, >, <=, >=, etc. + substr := deparsed[:cutoff] + + // Common operators in reverse order of length (to match longest first) + operators := []string{"<>", "!=", "<=", ">=", "=", "<", ">", "~", "!~", "~~", "!~~"} + + for _, op := range operators { + // Look for " " pattern + searchPattern := " " + op + " " + if idx := strings.LastIndex(substr, searchPattern); idx >= 0 { + return op + } + } + + return "" } diff --git a/testdata/diff/create_view/add_view_array_operators/diff.sql b/testdata/diff/create_view/add_view_array_operators/diff.sql new file mode 100644 index 00000000..1b4fa9fa --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/diff.sql @@ -0,0 +1,9 @@ +CREATE OR REPLACE VIEW test_array_operators AS + SELECT + id, + value, + CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test, + CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test, + CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test, + CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test + FROM test_data; diff --git a/testdata/diff/create_view/add_view_array_operators/new.sql b/testdata/diff/create_view/add_view_array_operators/new.sql new file mode 100644 index 00000000..c45b242f --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/new.sql @@ -0,0 +1,21 @@ +CREATE TABLE public.test_data ( + id SERIAL PRIMARY KEY, + value INTEGER, + priority INTEGER +); + +-- View with various array operators to test ScalarArrayOpExpr formatting +-- The fix ensures that: +-- 1. Only "= ANY" is converted to "IN" syntax +-- 2. Other operators (>, <, <>) preserve "ANY" syntax +CREATE VIEW public.test_array_operators AS +SELECT + id, + value, + -- This SHOULD be converted to IN syntax (= ANY -> IN) + CASE WHEN value = ANY(ARRAY[10, 20, 30]) THEN 'matched' ELSE 'not_matched' END AS equal_any_test, + -- These should NOT be converted to IN syntax - they must preserve ANY + CASE WHEN value > ANY(ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test, + CASE WHEN value < ANY(ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test, + CASE WHEN priority <> ANY(ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test +FROM test_data; diff --git a/testdata/diff/create_view/add_view_array_operators/old.sql b/testdata/diff/create_view/add_view_array_operators/old.sql new file mode 100644 index 00000000..bbbae4c7 --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/old.sql @@ -0,0 +1,5 @@ +CREATE TABLE public.test_data ( + id SERIAL PRIMARY KEY, + value INTEGER, + priority INTEGER +); diff --git a/testdata/diff/create_view/add_view_array_operators/plan.json b/testdata/diff/create_view/add_view_array_operators/plan.json new file mode 100644 index 00000000..b263f0eb --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/plan.json @@ -0,0 +1,20 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.4.0", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "423516aca5eb1f6a8f040d2cf4f1dee6e0dc441f91fc54489812bcdedb29bc28" + }, + "groups": [ + { + "steps": [ + { + "sql": "CREATE OR REPLACE VIEW test_array_operators AS\n SELECT\n id,\n value,\n CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,\n CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,\n CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,\n CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test\n FROM test_data;", + "type": "view", + "operation": "create", + "path": "public.test_array_operators" + } + ] + } + ] +} diff --git a/testdata/diff/create_view/add_view_array_operators/plan.sql b/testdata/diff/create_view/add_view_array_operators/plan.sql new file mode 100644 index 00000000..1b4fa9fa --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/plan.sql @@ -0,0 +1,9 @@ +CREATE OR REPLACE VIEW test_array_operators AS + SELECT + id, + value, + CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test, + CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test, + CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test, + CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test + FROM test_data; diff --git a/testdata/diff/create_view/add_view_array_operators/plan.txt b/testdata/diff/create_view/add_view_array_operators/plan.txt new file mode 100644 index 00000000..cfc15a12 --- /dev/null +++ b/testdata/diff/create_view/add_view_array_operators/plan.txt @@ -0,0 +1,20 @@ +Plan: 1 to add. + +Summary by type: + views: 1 to add + +Views: + + test_array_operators + +DDL to be executed: +-------------------------------------------------- + +CREATE OR REPLACE VIEW test_array_operators AS + SELECT + id, + value, + CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test, + CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test, + CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test, + CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test + FROM test_data;