diff --git a/ir/formatter.go b/ir/formatter.go index 49f191bf..d3a49382 100644 --- a/ir/formatter.go +++ b/ir/formatter.go @@ -220,6 +220,15 @@ func (f *postgreSQLFormatter) formatRangeSubselect(subselect *pg_query.RangeSubs } // formatExpression formats a general expression +// +// NOTE: Two important expression types for array operations: +// 1. A_Expr: Appears when parsing SQL files directly (e.g., "value = ANY(ARRAY[...])") +// 2. ScalarArrayOpExpr: Appears when fetching view definitions from PostgreSQL via pg_get_viewdef() +// +// PostgreSQL internally converts "IN (...)" to "= ANY(ARRAY[...])" when storing views. +// When we fetch the view definition back via pg_get_viewdef(), it returns ScalarArrayOpExpr nodes. +// Both formatAExpr and formatScalarArrayOpExpr convert "= ANY" back to the cleaner "IN" syntax, +// while preserving other operators (>, <, <>) with ANY/ALL syntax. func (f *postgreSQLFormatter) formatExpression(expr *pg_query.Node) { switch { case expr.GetColumnRef() != nil: @@ -312,16 +321,7 @@ func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) { if isEqualityAny && expr.Rexpr != nil { if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil { // Convert "column = ANY(ARRAY[...])" to "column IN (...)" - f.formatExpressionStripCast(expr.Lexpr) - f.buffer.WriteString(" IN (") - for i, elem := range aArrayExpr.Elements { - if i > 0 { - f.buffer.WriteString(", ") - } - // Strip type casts from constants in IN list - f.formatExpressionStripCast(elem) - } - f.buffer.WriteString(")") + f.formatArrayAsIN(expr.Lexpr, aArrayExpr.Elements) return } } @@ -367,18 +367,8 @@ func (f *postgreSQLFormatter) formatAExpr(expr *pg_query.A_Expr) { if len(expr.Name) == 1 && expr.Rexpr != nil { if str := expr.Name[0].GetString_(); str != nil && str.Sval == "=" { if aArrayExpr := expr.Rexpr.GetAArrayExpr(); aArrayExpr != nil { - // Direct array comparison: column = ARRAY[...] - // Convert to IN syntax, stripping unnecessary type casts from constants - f.formatExpressionStripCast(expr.Lexpr) - f.buffer.WriteString(" IN (") - for i, elem := range aArrayExpr.Elements { - if i > 0 { - f.buffer.WriteString(", ") - } - // Strip type casts from constants in IN list - f.formatExpressionStripCast(elem) - } - f.buffer.WriteString(")") + // Direct array comparison: column = ARRAY[...] → column IN (...) + f.formatArrayAsIN(expr.Lexpr, aArrayExpr.Elements) return } } @@ -689,13 +679,40 @@ func (f *postgreSQLFormatter) formatAArrayExpr(arrayExpr *pg_query.A_ArrayExpr) f.buffer.WriteString("]") } -// formatScalarArrayOpExpr formats scalar array operations like "column = ANY (ARRAY[...])" -// and converts them to the simpler "column IN (...)" syntax -func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) { - // Check if this is a simple = ANY pattern that can be converted to IN - // UseOr means ANY (disjunction), !UseOr means ALL (conjunction) - // IMPORTANT: We must also verify the operator is equality (=), not other operators like >, <, <> +// formatArrayAsIN is a helper to format "column IN (values)" syntax +// Used by both formatAExpr and formatScalarArrayOpExpr to convert "= ANY(ARRAY[...])" to "IN (...)" +func (f *postgreSQLFormatter) formatArrayAsIN(leftExpr *pg_query.Node, arrayElements []*pg_query.Node) { + // Format left side (the column/expression) + f.formatExpressionStripCast(leftExpr) + + f.buffer.WriteString(" IN (") + // Format array elements as comma-separated list, stripping unnecessary type casts + for i, elem := range arrayElements { + if i > 0 { + f.buffer.WriteString(", ") + } + f.formatExpressionStripCast(elem) + } + + f.buffer.WriteString(")") +} + +// formatScalarArrayOpExpr formats ScalarArrayOpExpr nodes (PostgreSQL's internal array operation representation). +// +// CONTEXT: This function handles a narrow case - formatting view definitions fetched from PostgreSQL +// via pg_get_viewdef(). When PostgreSQL stores views, it converts "IN (...)" to "= ANY(ARRAY[...])" +// internally. When we fetch views back, we get ScalarArrayOpExpr nodes instead of the original A_Expr. +// +// This function converts "= ANY" back to the cleaner "IN (...)" syntax, while preserving +// other operators (>, <, <>, etc.) with their original ANY/ALL syntax. +// +// Example transformations: +// - "value = ANY (ARRAY[1, 2, 3])" → "value IN (1, 2, 3)" (converted) +// - "value > ANY (ARRAY[1, 2, 3])" → "value > ANY (ARRAY[1, 2, 3])" (preserved) +// - "value = ALL (ARRAY[1, 2, 3])" → "value = ALL (ARRAY[1, 2, 3])" (preserved) +func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarArrayOpExpr) { + // Validate Args structure if len(arrayOp.Args) != 2 { // Malformed expression, use deparse fallback if deparseResult, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}); err == nil { @@ -704,69 +721,54 @@ func (f *postgreSQLFormatter) formatScalarArrayOpExpr(arrayOp *pg_query.ScalarAr return } - // Get the operator name by deparsing + // Deparse once to extract the operator name + // We need to deparse because ScalarArrayOpExpr doesn't directly expose the operator name deparsed, err := f.deparseNode(&pg_query.Node{Node: &pg_query.Node_ScalarArrayOpExpr{ScalarArrayOpExpr: arrayOp}}) if err != nil { - // If deparse fails, just return empty (shouldn't happen in practice) + // If deparse fails, silently return (shouldn't happen in practice) return } - // Extract the operator once to avoid redundant string parsing + // Extract operator from deparsed string (e.g., "value > ANY (...)" → ">") opName := extractOperator(deparsed) - // Check if operator is = (equality) - isEqualityOp := opName == "=" - - // Only convert to IN syntax if it's "= ANY" - if arrayOp.UseOr && isEqualityOp { - // Args[0] is the left side (column), Args[1] is the right side (array) - // Format as "column IN (values)" - // Format left side (the column) - f.formatExpression(arrayOp.Args[0]) - - f.buffer.WriteString(" IN (") - - // Extract values from the array + // Check if this is "= ANY" which can be converted to cleaner "IN" syntax + // - UseOr == true means ANY (disjunction/OR semantics) + // - UseOr == false means ALL (conjunction/AND semantics) + // - Only convert equality with ANY, not other operators or ALL + if arrayOp.UseOr && opName == "=" { + // Convert "column = ANY (ARRAY[...])" → "column IN (...)" if arrayExpr := arrayOp.Args[1].GetArrayExpr(); arrayExpr != nil { - // Format array elements as comma-separated list - for i, elem := range arrayExpr.Elements { - if i > 0 { - f.buffer.WriteString(", ") - } - f.formatExpression(elem) - } - } else { - // Fallback: format the right expression as-is - f.formatExpression(arrayOp.Args[1]) + // Use the shared helper to format as IN syntax + f.formatArrayAsIN(arrayOp.Args[0], arrayExpr.Elements) + return } - - f.buffer.WriteString(")") - return } - // For other operations (like <> ANY, > ANY, = ALL), format manually - // Format: () + // For all other operations (<> ANY, > ANY, < ANY, = ALL, etc.), preserve original syntax + // Format: () - // Format left side + // Format left side (the column/expression) f.formatExpression(arrayOp.Args[0]) - // Use the already-extracted operator + // Format operator if opName != "" { f.buffer.WriteString(" ") f.buffer.WriteString(opName) f.buffer.WriteString(" ") } else { + // Shouldn't happen, but provide fallback f.buffer.WriteString(" ") } - // Format ANY or ALL + // Format ANY or ALL keyword if arrayOp.UseOr { f.buffer.WriteString("ANY (") } else { f.buffer.WriteString("ALL (") } - // Format right side (the array) + // Format right side (the array expression) f.formatExpression(arrayOp.Args[1]) f.buffer.WriteString(")")