Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 162 additions & 25 deletions ir/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,68 @@ func (f *postgreSQLFormatter) formatAConst(constant *pg_query.A_Const) {

// formatAExpr formats an A_Expr (binary/unary expressions)
func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) {
// Special case: Detect "column = ANY (ARRAY[...])" pattern and convert to "column IN (...)"
// Handle AEXPR_OP_ANY and AEXPR_OP_ALL (e.g., "value > ANY(ARRAY[...])")
if expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY || expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ALL {
// Check if this is "= ANY" which can be converted to IN
isEqualityAny := expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY &&
len(expr.Name) == 1 &&
expr.Name[0].GetString_() != nil &&
expr.Name[0].GetString_().Sval == "="

if isEqualityAny && expr.Rexpr != nil {
if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil {
// Convert "column = ANY(ARRAY[...])" to "column IN (...)"
f.formatExpressionStripCast(expr.Lexpr)
f.buffer.WriteString(" IN (")
for i, elem := range aArrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
// Strip type casts from constants in IN list
f.formatExpressionStripCast(elem)
}
f.buffer.WriteString(")")
return
}
}

// Format other ANY/ALL operations (>, <, <>, etc.)
// Format: <left> <op> <ANY|ALL> (<right>)
if expr.Lexpr != nil {
f.formatExpression(expr.Lexpr)
}

// Format operator
if len(expr.Name) > 0 {
f.buffer.WriteString(" ")
for i, nameNode := range expr.Name {
if i > 0 {
f.buffer.WriteString(".")
}
if str := nameNode.GetString_(); str != nil {
f.buffer.WriteString(str.Sval)
}
}
f.buffer.WriteString(" ")
}

// Format ANY or ALL
if expr.Kind == pg_query.A_Expr_Kind_AEXPR_OP_ANY {
f.buffer.WriteString("ANY (")
} else {
f.buffer.WriteString("ALL (")
}

// Format right operand
if expr.Rexpr != nil {
f.formatExpression(expr.Rexpr)
}

f.buffer.WriteString(")")
return
}

// Special case: Detect "column = ARRAY[...]" pattern and convert to "column IN (...)"
// This pattern appears when parsing view definitions from pg_get_viewdef()
if len(expr.Name) == 1 && expr.Rexpr != nil {
if str := expr.Name[0].GetString_(); str != nil && str.Sval == "=" {
Expand Down Expand Up @@ -507,7 +568,7 @@ func (f *postgreSQLFormatter) formatCaseExpr(caseExpr *pg_query.CaseExpr) {
f.buffer.WriteString(" WHEN ")
f.formatExpression(when.Expr)
f.buffer.WriteString(" THEN ")
f.formatExpression(when.Result)
f.formatExpressionStripCast(when.Result)
}
}

Expand Down Expand Up @@ -633,38 +694,114 @@ func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr)
func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) {
// Check if this is a simple = ANY pattern that can be converted to IN
// UseOr means ANY (disjunction), !UseOr means ALL (conjunction)
isEqualAny := arrayOp.UseOr && len(arrayOp.Args) == 2
// IMPORTANT: We must also verify the operator is equality (=), not other operators like >, <, <>

if len(arrayOp.Args) != 2 {
// Malformed expression, use deparse fallback
if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil {
f.buffer.WriteString(deparseResult)
}
return
}

if isEqualAny {
// Get the operator name by deparsing
deparsed, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}})
if err != nil {
// If deparse fails, just return empty (shouldn't happen in practice)
return
}

// Extract the operator once to avoid redundant string parsing
opName := extractOperator(deparsed)

// Check if operator is = (equality)
isEqualityOp := opName == "="

// Only convert to IN syntax if it's "= ANY"
if arrayOp.UseOr && isEqualityOp {
// Args[0] is the left side (column), Args[1] is the right side (array)
// Format as "column IN (values)"
if len(arrayOp.Args) >= 2 {
// Format left side (the column)
f.formatExpression(arrayOp.Args[0])
// Format left side (the column)
f.formatExpression(arrayOp.Args[0])

f.buffer.WriteString(" IN (")
f.buffer.WriteString(" IN (")

// Extract values from the array
if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil {
// Format array elements as comma-separated list
for i, elem := range arrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
f.formatExpression(elem)
// Extract values from the array
if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil {
// Format array elements as comma-separated list
for i, elem := range arrayExpr.Elements {
if i > 0 {
f.buffer.WriteString(", ")
}
} else {
// Fallback: format the right expression as-is
f.formatExpression(arrayOp.Args[1])
f.formatExpression(elem)
}

f.buffer.WriteString(")")
return
} else {
// Fallback: format the right expression as-is
f.formatExpression(arrayOp.Args[1])
}

f.buffer.WriteString(")")
return
}

// For other operations (like <> ALL) or malformed expressions, use deparse fallback
if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil {
f.buffer.WriteString(deparseResult)
// For other operations (like <> ANY, > ANY, = ALL), format manually
// Format: <left_expr> <op> <ANY|ALL> (<array_expr>)

// Format left side
f.formatExpression(arrayOp.Args[0])

// Use the already-extracted operator
if opName != "" {
f.buffer.WriteString(" ")
f.buffer.WriteString(opName)
f.buffer.WriteString(" ")
} else {
f.buffer.WriteString(" <unknown> ")
}

// Format ANY or ALL
if arrayOp.UseOr {
f.buffer.WriteString("ANY (")
} else {
f.buffer.WriteString("ALL (")
}

// Format right side (the array)
f.formatExpression(arrayOp.Args[1])

f.buffer.WriteString(")")
}

// extractOperator extracts the operator from a deparsed ScalarArrayOpExpr string
// e.g., "value > ANY (ARRAY[...])" -> ">"
func extractOperator(deparsed string) string {
// Look for pattern: <something> <operator> ANY/ALL
anyIdx := strings.Index(deparsed, " ANY")
allIdx := strings.Index(deparsed, " ALL")

var cutoff int
if anyIdx >= 0 && (allIdx < 0 || anyIdx < allIdx) {
cutoff = anyIdx
} else if allIdx >= 0 {
cutoff = allIdx
} else {
return ""
}

// Work backwards from cutoff to find the operator
// Operators can be: =, <>, !=, <, >, <=, >=, etc.
substr := deparsed[:cutoff]

// Common operators in reverse order of length (to match longest first)
operators := []string{"<>", "!=", "<=", ">=", "=", "<", ">", "~", "!~", "~~", "!~~"}

for _, op := range operators {
// Look for " <op> " pattern
searchPattern := " " + op + " "
if idx := strings.LastIndex(substr, searchPattern); idx >= 0 {
return op
}
}

return ""
}
9 changes: 9 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/diff.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE OR REPLACE VIEW test_array_operators AS
SELECT
id,
value,
CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,
CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,
CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,
CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test
FROM test_data;
21 changes: 21 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/new.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
CREATE TABLE public.test_data (
id SERIAL PRIMARY KEY,
value INTEGER,
priority INTEGER
);

-- View with various array operators to test ScalarArrayOpExpr formatting
-- The fix ensures that:
-- 1. Only "= ANY" is converted to "IN" syntax
-- 2. Other operators (>, <, <>) preserve "ANY" syntax
CREATE VIEW public.test_array_operators AS
SELECT
id,
value,
-- This SHOULD be converted to IN syntax (= ANY -> IN)
CASE WHEN value = ANY(ARRAY[10, 20, 30]) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,
-- These should NOT be converted to IN syntax - they must preserve ANY
CASE WHEN value > ANY(ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,
CASE WHEN value < ANY(ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,
CASE WHEN priority <> ANY(ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test
FROM test_data;
5 changes: 5 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/old.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE public.test_data (
id SERIAL PRIMARY KEY,
value INTEGER,
priority INTEGER
);
20 changes: 20 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/plan.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"version": "1.0.0",
"pgschema_version": "1.4.0",
"created_at": "1970-01-01T00:00:00Z",
"source_fingerprint": {
"hash": "423516aca5eb1f6a8f040d2cf4f1dee6e0dc441f91fc54489812bcdedb29bc28"
},
"groups": [
{
"steps": [
{
"sql": "CREATE OR REPLACE VIEW test_array_operators AS\n SELECT\n id,\n value,\n CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,\n CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,\n CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,\n CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test\n FROM test_data;",
"type": "view",
"operation": "create",
"path": "public.test_array_operators"
}
]
}
]
}
9 changes: 9 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/plan.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE OR REPLACE VIEW test_array_operators AS
SELECT
id,
value,
CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,
CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,
CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,
CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test
FROM test_data;
20 changes: 20 additions & 0 deletions testdata/diff/create_view/add_view_array_operators/plan.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Plan: 1 to add.

Summary by type:
views: 1 to add

Views:
+ test_array_operators

DDL to be executed:
--------------------------------------------------

CREATE OR REPLACE VIEW test_array_operators AS
SELECT
id,
value,
CASE WHEN value IN (10, 20, 30) THEN 'matched' ELSE 'not_matched' END AS equal_any_test,
CASE WHEN value > ANY (ARRAY[10, 20, 30]) THEN 'high' ELSE 'low' END AS greater_any_test,
CASE WHEN value < ANY (ARRAY[5, 15, 25]) THEN 'found_lower' ELSE 'all_higher' END AS less_any_test,
CASE WHEN priority <> ANY (ARRAY[1, 2, 3]) THEN 'different' ELSE 'same' END AS not_equal_any_test
FROM test_data;