Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions internal/diff/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,11 @@ func generateIndexSQLWithName(index *ir.Index, indexName string, targetSchema st
builder.WriteString(", ")
}

// Handle JSON expressions with proper parentheses
if strings.Contains(col.Name, "->>") || strings.Contains(col.Name, "->") {
// Use double parentheses for JSON expressions for clean format
builder.WriteString(fmt.Sprintf("((%s))", col.Name))
} else {
builder.WriteString(col.Name)
}
// Use column name as-is from pg_get_indexdef()
// pg_get_indexdef() already handles parenthesization correctly:
// - Regular columns: column_name
// - Expressions: ((expression))
builder.WriteString(col.Name)

// Add direction if specified
if col.Direction != "" && col.Direction != "ASC" {
Expand Down
268 changes: 17 additions & 251 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"
"sync"

pg_query "github.com/pganalyze/pg_query_go/v6"
"github.com/pgschema/pgschema/ir/queries"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -684,10 +683,6 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s
isPartial := indexRow.IsPartial.Valid && indexRow.IsPartial.Bool
hasExpressions := indexRow.HasExpressions.Valid && indexRow.HasExpressions.Bool
method := indexRow.Method
definition := ""
if indexRow.Indexdef.Valid {
definition = indexRow.Indexdef.String
}

// Determine index type based on properties
indexType := IndexTypeRegular
Expand Down Expand Up @@ -722,15 +717,24 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s
index.Where = indexRow.PartialPredicate.String
}

// Parse index definition to extract columns
if err := i.parseIndexDefinition(index, definition); err != nil {
// If parsing fails, just continue with empty columns
// This ensures backward compatibility
continue
}
// Extract columns directly from query results (no parsing needed!)
// The query uses pg_get_indexdef(indexrelid, column_position, true) for each column
// and extracts ASC/DESC from the indoption array
for idx := 0; idx < len(indexRow.ColumnDefinitions); idx++ {
columnName := indexRow.ColumnDefinitions[idx]
direction := "ASC" // Default
if idx < len(indexRow.ColumnDirections) {
direction = indexRow.ColumnDirections[idx]
}

indexColumn := &IndexColumn{
Name: columnName,
Position: idx + 1,
Direction: direction,
}

// Store the original definition - simplification will be done during read time in diff module
// Definition is now generated on demand, not stored
index.Columns = append(index.Columns, indexColumn)
}

// Add index to table or materialized view
if table, exists := dbSchema.Tables[tableName]; exists {
Expand All @@ -747,244 +751,6 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s
return nil
}

// parseIndexDefinition parses an index definition string to extract method and columns using pg_query_go
// Expected format: "CREATE [UNIQUE] INDEX index_name ON [schema.]table USING method (column1 [ASC|DESC], column2, ...)"
func (i *Inspector) parseIndexDefinition(index *Index, definition string) error {
if definition == "" {
return fmt.Errorf("empty index definition")
}

// Parse the definition string using pg_query
result, err := pg_query.Parse(definition)
if err != nil {
return fmt.Errorf("failed to parse index definition: %w", err)
}

// Find the CREATE INDEX statement in the parsed result
var indexStmt *pg_query.IndexStmt
for _, stmt := range result.Stmts {
if node := stmt.GetStmt(); node != nil {
if indexNode := node.GetIndexStmt(); indexNode != nil {
indexStmt = indexNode
break
}
}
}

if indexStmt == nil {
return fmt.Errorf("no CREATE INDEX statement found in definition")
}

// Extract index method
if indexStmt.AccessMethod != "" {
index.Method = indexStmt.AccessMethod
} else {
// Default to btree if not specified
index.Method = "btree"
}

// Parse index columns from IndexParams
for idx, indexElem := range indexStmt.IndexParams {
if elem := indexElem.GetIndexElem(); elem != nil {
var columnName string
var direction string

// Extract column name or expression directly from AST
if elem.Name != "" {
// Simple column name
columnName = elem.Name
} else if elem.Expr != nil {
// Expression column - extract directly from AST
columnName = i.extractExpressionFromAST(elem.Expr)
}

// Extract sort direction directly from AST
switch elem.Ordering {
case pg_query.SortByDir_SORTBY_ASC:
direction = "ASC"
case pg_query.SortByDir_SORTBY_DESC:
direction = "DESC"
default:
direction = "ASC" // Default
}

if columnName != "" {
indexColumn := &IndexColumn{
Name: columnName,
Position: idx + 1,
Direction: direction,
}

index.Columns = append(index.Columns, indexColumn)
}
}
}

return nil
}

// extractExpressionFromAST extracts a string representation of an expression node for index definitions
func (i *Inspector) extractExpressionFromAST(expr *pg_query.Node) string {
if expr == nil {
return ""
}

switch n := expr.Node.(type) {
case *pg_query.Node_ColumnRef:
return i.extractColumnNameFromAST(expr)
case *pg_query.Node_AExpr:
// Handle binary expressions like JSON operators
return i.extractBinaryExpressionFromAST(n.AExpr)
case *pg_query.Node_FuncCall:
// Handle function calls in expressions
return i.extractFunctionCallFromAST(n.FuncCall)
case *pg_query.Node_AConst:
// Handle constants
return i.extractConstantValueFromAST(expr)
case *pg_query.Node_TypeCast:
// Handle type casting expressions like 'method'::text
return i.extractTypeCastFromAST(n.TypeCast)
default:
// For unhandled cases, return a placeholder
return "(expression)"
}
}

// extractColumnNameFromAST extracts column name from a ColumnRef node
func (i *Inspector) extractColumnNameFromAST(node *pg_query.Node) string {
if columnRef := node.GetColumnRef(); columnRef != nil {
if len(columnRef.Fields) > 0 {
var parts []string
for _, field := range columnRef.Fields {
if field != nil {
if str := field.GetString_(); str != nil {
parts = append(parts, str.Sval)
}
}
}
if len(parts) > 0 {
return strings.Join(parts, ".")
}
}
}
return ""
}

// extractBinaryExpressionFromAST extracts string representation of binary expressions
func (i *Inspector) extractBinaryExpressionFromAST(aExpr *pg_query.A_Expr) string {
if aExpr == nil {
return ""
}

left := ""
if aExpr.Lexpr != nil {
left = i.extractExpressionFromAST(aExpr.Lexpr)
}

right := ""
if aExpr.Rexpr != nil {
right = i.extractExpressionFromAST(aExpr.Rexpr)
}

operator := ""
if len(aExpr.Name) > 0 {
if opNode := aExpr.Name[0]; opNode != nil {
if str := opNode.GetString_(); str != nil {
operator = str.Sval
}
}
}

if left != "" && right != "" && operator != "" {
// Handle JSON operators specially - don't add extra parentheses
if operator == "->>" || operator == "->" {
return fmt.Sprintf("%s%s%s", left, operator, right)
}
return fmt.Sprintf("(%s %s %s)", left, operator, right)
}

return fmt.Sprintf("(%s)", left)
}

// extractFunctionCallFromAST extracts string representation of function calls
func (i *Inspector) extractFunctionCallFromAST(funcCall *pg_query.FuncCall) string {
if funcCall == nil {
return ""
}

// Extract function name
funcName := ""
if len(funcCall.Funcname) > 0 {
if nameNode := funcCall.Funcname[0]; nameNode != nil {
if str := nameNode.GetString_(); str != nil {
funcName = str.Sval
}
}
}

if funcName == "" {
return "function()"
}

// Extract function arguments
var args []string
if len(funcCall.Args) > 0 {
for _, argNode := range funcCall.Args {
if argNode != nil {
argStr := i.extractExpressionFromAST(argNode)
if argStr != "" {
args = append(args, argStr)
}
}
}
}

// Build function call with arguments
if len(args) > 0 {
return fmt.Sprintf("%s(%s)", funcName, strings.Join(args, ", "))
}
return fmt.Sprintf("%s()", funcName)
}

// extractConstantValueFromAST extracts string representation of constants
func (i *Inspector) extractConstantValueFromAST(node *pg_query.Node) string {
if aConst := node.GetAConst(); aConst != nil {
if aConst.Isnull {
return "NULL"
}
if aConst.Val != nil {
switch val := aConst.Val.(type) {
case *pg_query.A_Const_Sval:
return fmt.Sprintf("'%s'", val.Sval.Sval)
case *pg_query.A_Const_Ival:
return strconv.FormatInt(int64(val.Ival.Ival), 10)
case *pg_query.A_Const_Fval:
return val.Fval.Fval
case *pg_query.A_Const_Boolval:
if val.Boolval.Boolval {
return "true"
}
return "false"
}
}
}
return ""
}

// extractTypeCastFromAST extracts string representation of type cast expressions
func (i *Inspector) extractTypeCastFromAST(typeCast *pg_query.TypeCast) string {
if typeCast == nil {
return ""
}

// Extract the expression being cast
expr := ""
if typeCast.Arg != nil {
expr = i.extractExpressionFromAST(typeCast.Arg)
}

return expr
}

func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema string) error {
sequences, err := i.queries.GetSequencesForSchema(ctx, sql.NullString{String: targetSchema, Valid: true})
Expand Down
27 changes: 20 additions & 7 deletions ir/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ ORDER BY n.nspname, t.relname, i.relname;

-- GetIndexesForSchema retrieves all indexes for a specific schema
-- name: GetIndexesForSchema :many
SELECT
SELECT
n.nspname as schemaname,
t.relname as tablename,
i.relname as indexname,
Expand All @@ -261,26 +261,39 @@ SELECT
(idx.indpred IS NOT NULL) as is_partial,
am.amname as method,
pg_get_indexdef(idx.indexrelid) as indexdef,
CASE
CASE
WHEN idx.indpred IS NOT NULL THEN pg_get_expr(idx.indpred, idx.indrelid)
ELSE NULL
END as partial_predicate,
CASE
CASE
WHEN idx.indexprs IS NOT NULL THEN true
ELSE false
END as has_expressions,
COALESCE(d.description, '') AS index_comment
COALESCE(d.description, '') AS index_comment,
idx.indnatts as num_columns,
ARRAY(
SELECT pg_get_indexdef(idx.indexrelid, k::int, true)
FROM generate_series(1, idx.indnatts) k
) as column_definitions,
ARRAY(
SELECT
CASE
WHEN (idx.indoption[k-1] & 1) = 1 THEN 'DESC'
ELSE 'ASC'
END
FROM generate_series(1, idx.indnatts) k
) as column_directions
FROM pg_index idx
JOIN pg_class i ON i.oid = idx.indexrelid
JOIN pg_class t ON t.oid = idx.indrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
JOIN pg_am am ON am.oid = i.relam
LEFT JOIN pg_description d ON d.objoid = i.oid AND d.objsubid = 0
WHERE
WHERE
NOT idx.indisprimary
AND NOT EXISTS (
SELECT 1 FROM pg_constraint c
WHERE c.conindid = idx.indexrelid
SELECT 1 FROM pg_constraint c
WHERE c.conindid = idx.indexrelid
AND c.contype IN ('u', 'p')
)
AND n.nspname = $1
Expand Down
Loading