From 08c979fff26a29d6ba2cec628efe413938b79ac1 Mon Sep 17 00:00:00 2001 From: tianzhou Date: Fri, 31 Oct 2025 16:31:44 +0800 Subject: [PATCH] fix: index operators --- internal/diff/index.go | 6 +++ ir/inspector.go | 31 ++++++++------- ir/queries/queries.sql | 10 ++++- ir/queries/queries.sql.go | 23 +++++------ testdata/diff/create_index/add_index/diff.sql | 12 ++++++ testdata/diff/create_index/add_index/new.sql | 10 +++++ testdata/diff/create_index/add_index/old.sql | 1 + .../diff/create_index/add_index/plan.json | 38 +++++++++++++++++++ testdata/diff/create_index/add_index/plan.sql | 12 ++++++ testdata/diff/create_index/add_index/plan.txt | 26 +++++++++++++ 10 files changed, 144 insertions(+), 25 deletions(-) create mode 100644 testdata/diff/create_index/add_index/diff.sql create mode 100644 testdata/diff/create_index/add_index/new.sql create mode 100644 testdata/diff/create_index/add_index/old.sql create mode 100644 testdata/diff/create_index/add_index/plan.json create mode 100644 testdata/diff/create_index/add_index/plan.sql create mode 100644 testdata/diff/create_index/add_index/plan.txt diff --git a/internal/diff/index.go b/internal/diff/index.go index 5c87ee28..8887d322 100644 --- a/internal/diff/index.go +++ b/internal/diff/index.go @@ -107,6 +107,12 @@ func generateIndexSQLWithName(index *ir.Index, indexName string, targetSchema st // - Expressions: ((expression)) builder.WriteString(col.Name) + // Add operator class if specified (non-default operator class) + if col.Operator != "" { + builder.WriteString(" ") + builder.WriteString(col.Operator) + } + // Add direction if specified if col.Direction != "" && col.Direction != "ASC" { builder.WriteString(" ") diff --git a/ir/inspector.go b/ir/inspector.go index 178df44e..5ff4b158 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -385,7 +385,6 @@ func (i *Inspector) buildPartitions(ctx context.Context, schema *IR, targetSchem return nil } - func (i *Inspector) buildConstraints(ctx context.Context, schema *IR, targetSchema string) error { constraints, err := i.queries.GetConstraintsForSchema(ctx, sql.NullString{String: targetSchema, Valid: true}) if err != nil { @@ -555,7 +554,7 @@ func (i *Inspector) buildConstraints(ctx context.Context, schema *IR, targetSche sort.Slice(constraint.Columns, func(i, j int) bool { return constraint.Columns[i].Position < constraint.Columns[j].Position }) - + // Also sort referenced columns for foreign keys if constraint.Type == ConstraintTypeForeignKey && len(constraint.ReferencedColumns) > 0 { sort.Slice(constraint.ReferencedColumns, func(i, j int) bool { @@ -563,7 +562,7 @@ func (i *Inspector) buildConstraints(ctx context.Context, schema *IR, targetSche }) } } - + table.Constraints[key.name] = constraint // For partitioned tables, ensure primary key columns are ordered with partition key first @@ -717,9 +716,10 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s index.Where = indexRow.PartialPredicate.String } - // Extract columns directly from query results (no parsing needed!) + // Extract columns directly from query results // The query uses pg_get_indexdef(indexrelid, column_position, true) for each column // and extracts ASC/DESC from the indoption array + // and operator class names from pg_index.indclass joined with pg_opclass for idx := 0; idx < len(indexRow.ColumnDefinitions); idx++ { columnName := indexRow.ColumnDefinitions[idx] direction := "ASC" // Default @@ -727,10 +727,17 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s direction = indexRow.ColumnDirections[idx] } + // Get operator class from the ColumnOpclasses array + operatorClass := "" + if idx < len(indexRow.ColumnOpclasses) { + operatorClass = indexRow.ColumnOpclasses[idx] + } + indexColumn := &IndexColumn{ Name: columnName, Position: idx + 1, Direction: direction, + Operator: operatorClass, } index.Columns = append(index.Columns, indexColumn) @@ -751,7 +758,6 @@ func (i *Inspector) buildIndexes(ctx context.Context, schema *IR, targetSchema s return nil } - func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema string) error { sequences, err := i.queries.GetSequencesForSchema(ctx, sql.NullString{String: targetSchema, Valid: true}) if err != nil { @@ -774,8 +780,8 @@ func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema if dataType == "bigint" { // Check if this is a default bigint by looking at min/max values // Default bigint sequences have min_value=1 and max_value=9223372036854775807 - if seq.MinimumValue.Valid && seq.MinimumValue.Int64 == 1 && - seq.MaximumValue.Valid && seq.MaximumValue.Int64 == 9223372036854775807 { + if seq.MinimumValue.Valid && seq.MinimumValue.Int64 == 1 && + seq.MaximumValue.Valid && seq.MaximumValue.Int64 == 9223372036854775807 { dataType = "" // This means it was not explicitly specified } } @@ -863,13 +869,13 @@ func (i *Inspector) isIdentityColumn(ctx context.Context, schemaName, tableName, WHERE table_schema = $1 AND table_name = $2 AND column_name = $3` - + var isIdentity string err := i.db.QueryRowContext(ctx, query, schemaName, tableName, columnName).Scan(&isIdentity) if err != nil { return false } - + return isIdentity == "YES" } @@ -935,8 +941,8 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema func splitParameterString(signature string) []string { var params []string var current strings.Builder - depth := 0 // Track nesting depth of (), [], {} - inQuote := false // Track if we're inside a string literal + depth := 0 // Track nesting depth of (), [], {} + inQuote := false // Track if we're inside a string literal i := 0 for i < len(signature) { @@ -1061,7 +1067,6 @@ func (i *Inspector) parseParametersFromSignature(signature string) []*Parameter return parameters } - // lookupTypeNameFromOID converts PostgreSQL type OID to type name func (i *Inspector) lookupTypeNameFromOID(oid int64) string { // Common type OID mappings (can be extended as needed) @@ -1273,7 +1278,7 @@ func extractFunctionCallFromTriggerDef(triggerDef string) string { // Start after "EXECUTE FUNCTION " or "EXECUTE PROCEDURE " start := strings.Index(triggerDef[executeIdx:], " ") + executeIdx + 1 // Skip "EXECUTE" - start = strings.Index(triggerDef[start:], " ") + start + 1 // Skip "FUNCTION"/"PROCEDURE" + start = strings.Index(triggerDef[start:], " ") + start + 1 // Skip "FUNCTION"/"PROCEDURE" // The function call extends to the end of the definition (or a semicolon if present) end := len(triggerDef) diff --git a/ir/queries/queries.sql b/ir/queries/queries.sql index 09743aa5..dc988264 100644 --- a/ir/queries/queries.sql +++ b/ir/queries/queries.sql @@ -282,7 +282,15 @@ SELECT ELSE 'ASC' END FROM generate_series(1, idx.indnatts) k - ) as column_directions + ) as column_directions, + ARRAY( + SELECT CASE + WHEN opc.opcdefault THEN '' -- Omit default operator classes + ELSE COALESCE(opc.opcname, '') + END + FROM generate_series(1, idx.indnatts) k + LEFT JOIN pg_opclass opc ON opc.oid = idx.indclass[k-1] + ) as column_opclasses FROM pg_index idx JOIN pg_class i ON i.oid = idx.indexrelid JOIN pg_class t ON t.oid = idx.indrelid diff --git a/ir/queries/queries.sql.go b/ir/queries/queries.sql.go index acc22ec9..838890c2 100644 --- a/ir/queries/queries.sql.go +++ b/ir/queries/queries.sql.go @@ -1204,10 +1204,7 @@ SELECT ELSE NULL END AS volatility, p.proisstrict AS is_strict, - p.prosecdef AS is_security_definer, - p.proargmodes::text[] as proargmodes, - p.proargnames, - p.proallargtypes::oid[]::text[] as proallargtypes + p.prosecdef AS is_security_definer FROM information_schema.routines r LEFT JOIN pg_proc p ON p.proname = r.routine_name AND p.pronamespace = (SELECT oid FROM pg_namespace WHERE nspname = r.routine_schema) @@ -1232,9 +1229,6 @@ type GetFunctionsForSchemaRow struct { Volatility sql.NullString `db:"volatility" json:"volatility"` IsStrict bool `db:"is_strict" json:"is_strict"` IsSecurityDefiner bool `db:"is_security_definer" json:"is_security_definer"` - Proargmodes []string `db:"proargmodes" json:"proargmodes"` - Proargnames []string `db:"proargnames" json:"proargnames"` - Proallargtypes []string `db:"proallargtypes" json:"proallargtypes"` } // GetFunctionsForSchema retrieves all user-defined functions for a specific schema @@ -1260,9 +1254,6 @@ func (q *Queries) GetFunctionsForSchema(ctx context.Context, dollar_1 sql.NullSt &i.Volatility, &i.IsStrict, &i.IsSecurityDefiner, - pq.Array(&i.Proargmodes), - pq.Array(&i.Proargnames), - pq.Array(&i.Proallargtypes), ); err != nil { return nil, err } @@ -1392,7 +1383,15 @@ SELECT ELSE 'ASC' END FROM generate_series(1, idx.indnatts) k - ) as column_directions + ) as column_directions, + ARRAY( + SELECT CASE + WHEN opc.opcdefault THEN '' -- Omit default operator classes + ELSE COALESCE(opc.opcname, '') + END + FROM generate_series(1, idx.indnatts) k + LEFT JOIN pg_opclass opc ON opc.oid = idx.indclass[k-1] + ) as column_opclasses FROM pg_index idx JOIN pg_class i ON i.oid = idx.indexrelid JOIN pg_class t ON t.oid = idx.indrelid @@ -1425,6 +1424,7 @@ type GetIndexesForSchemaRow struct { NumColumns int16 `db:"num_columns" json:"num_columns"` ColumnDefinitions []string `db:"column_definitions" json:"column_definitions"` ColumnDirections []string `db:"column_directions" json:"column_directions"` + ColumnOpclasses []string `db:"column_opclasses" json:"column_opclasses"` } // GetIndexesForSchema retrieves all indexes for a specific schema @@ -1452,6 +1452,7 @@ func (q *Queries) GetIndexesForSchema(ctx context.Context, dollar_1 sql.NullStri &i.NumColumns, pq.Array(&i.ColumnDefinitions), pq.Array(&i.ColumnDirections), + pq.Array(&i.ColumnOpclasses), ); err != nil { return nil, err } diff --git a/testdata/diff/create_index/add_index/diff.sql b/testdata/diff/create_index/add_index/diff.sql new file mode 100644 index 00000000..347c6ce3 --- /dev/null +++ b/testdata/diff/create_index/add_index/diff.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS users ( + id integer, + email varchar(255) NOT NULL, + name varchar(100), + CONSTRAINT users_pkey PRIMARY KEY (id) +); + +CREATE INDEX IF NOT EXISTS idx_users_email ON users (email varchar_pattern_ops); + +CREATE INDEX IF NOT EXISTS idx_users_id ON users (id); + +CREATE INDEX IF NOT EXISTS idx_users_name ON users (name); diff --git a/testdata/diff/create_index/add_index/new.sql b/testdata/diff/create_index/add_index/new.sql new file mode 100644 index 00000000..0bdd0af6 --- /dev/null +++ b/testdata/diff/create_index/add_index/new.sql @@ -0,0 +1,10 @@ +-- Create a new table with a simple index +CREATE TABLE public.users ( + id INTEGER PRIMARY KEY, + email VARCHAR(255) NOT NULL, + name VARCHAR(100) +); + +CREATE INDEX idx_users_name ON public.users (name); +CREATE INDEX idx_users_email ON public.users (email varchar_pattern_ops); +CREATE INDEX idx_users_id ON public.users (id); diff --git a/testdata/diff/create_index/add_index/old.sql b/testdata/diff/create_index/add_index/old.sql new file mode 100644 index 00000000..47f493b8 --- /dev/null +++ b/testdata/diff/create_index/add_index/old.sql @@ -0,0 +1 @@ +-- Empty schema (starting state) diff --git a/testdata/diff/create_index/add_index/plan.json b/testdata/diff/create_index/add_index/plan.json new file mode 100644 index 00000000..05a1e7b0 --- /dev/null +++ b/testdata/diff/create_index/add_index/plan.json @@ -0,0 +1,38 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.4.0", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085" + }, + "groups": [ + { + "steps": [ + { + "sql": "CREATE TABLE IF NOT EXISTS users (\n id integer,\n email varchar(255) NOT NULL,\n name varchar(100),\n CONSTRAINT users_pkey PRIMARY KEY (id)\n);", + "type": "table", + "operation": "create", + "path": "public.users" + }, + { + "sql": "CREATE INDEX IF NOT EXISTS idx_users_email ON users (email varchar_pattern_ops);", + "type": "table.index", + "operation": "create", + "path": "public.users.idx_users_email" + }, + { + "sql": "CREATE INDEX IF NOT EXISTS idx_users_id ON users (id);", + "type": "table.index", + "operation": "create", + "path": "public.users.idx_users_id" + }, + { + "sql": "CREATE INDEX IF NOT EXISTS idx_users_name ON users (name);", + "type": "table.index", + "operation": "create", + "path": "public.users.idx_users_name" + } + ] + } + ] +} diff --git a/testdata/diff/create_index/add_index/plan.sql b/testdata/diff/create_index/add_index/plan.sql new file mode 100644 index 00000000..347c6ce3 --- /dev/null +++ b/testdata/diff/create_index/add_index/plan.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS users ( + id integer, + email varchar(255) NOT NULL, + name varchar(100), + CONSTRAINT users_pkey PRIMARY KEY (id) +); + +CREATE INDEX IF NOT EXISTS idx_users_email ON users (email varchar_pattern_ops); + +CREATE INDEX IF NOT EXISTS idx_users_id ON users (id); + +CREATE INDEX IF NOT EXISTS idx_users_name ON users (name); diff --git a/testdata/diff/create_index/add_index/plan.txt b/testdata/diff/create_index/add_index/plan.txt new file mode 100644 index 00000000..bd7fc091 --- /dev/null +++ b/testdata/diff/create_index/add_index/plan.txt @@ -0,0 +1,26 @@ +Plan: 1 to add. + +Summary by type: + tables: 1 to add + +Tables: + + users + + idx_users_email (index) + + idx_users_id (index) + + idx_users_name (index) + +DDL to be executed: +-------------------------------------------------- + +CREATE TABLE IF NOT EXISTS users ( + id integer, + email varchar(255) NOT NULL, + name varchar(100), + CONSTRAINT users_pkey PRIMARY KEY (id) +); + +CREATE INDEX IF NOT EXISTS idx_users_email ON users (email varchar_pattern_ops); + +CREATE INDEX IF NOT EXISTS idx_users_id ON users (id); + +CREATE INDEX IF NOT EXISTS idx_users_name ON users (name);