diff --git a/ir/inspector.go b/ir/inspector.go index c269b99e..7da2d2f4 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -13,6 +13,25 @@ import ( "golang.org/x/sync/errgroup" ) +// PostgreSQL trigger type bitmask constants from pg_trigger.tgtype +// Reference: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_trigger.h +const ( + triggerTypeRow = 1 << 0 // TRIGGER_TYPE_ROW - row-level trigger + triggerTypeBefore = 1 << 1 // TRIGGER_TYPE_BEFORE - BEFORE timing + triggerTypeInsert = 1 << 2 // TRIGGER_TYPE_INSERT - INSERT event + triggerTypeDelete = 1 << 3 // TRIGGER_TYPE_DELETE - DELETE event + triggerTypeUpdate = 1 << 4 // TRIGGER_TYPE_UPDATE - UPDATE event + triggerTypeTruncate = 1 << 5 // TRIGGER_TYPE_TRUNCATE - TRUNCATE event + triggerTypeInstead = 1 << 6 // TRIGGER_TYPE_INSTEAD - INSTEAD OF timing +) + +// PostgreSQL sequence type maximum value constants +const ( + smallintMaxValue = 32767 // Maximum value for smallint sequences + integerMaxValue = 2147483647 // Maximum value for integer sequences + bigintMaxValue = 9223372036854775807 // Maximum value for bigint sequences (math.MaxInt64) +) + // Inspector builds IR from database queries type Inspector struct { db *sql.DB @@ -29,7 +48,6 @@ func NewInspector(db *sql.DB, ignoreConfig *IgnoreConfig) *Inspector { } } -// queryGroup represents a group of queries that can be executed concurrently // BuildIR builds the schema IR from the database for a specific schema func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, error) { schema := NewIR() @@ -117,6 +135,7 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro return schema, nil } +// queryGroup represents a group of queries that can be executed concurrently type queryGroup struct { name string funcs []func(context.Context, *IR, string) error @@ -177,8 +196,7 @@ func (i *Inspector) buildSchemas(ctx context.Context, schema *IR, targetSchema s return err } - name := fmt.Sprintf("%s", schemaName) - schema.getOrCreateSchema(name) + schema.getOrCreateSchema(i.safeInterfaceToString(schemaName)) return nil } @@ -190,9 +208,9 @@ func (i *Inspector) buildTables(ctx context.Context, schema *IR, targetSchema st } for _, table := range tables { - schemaName := fmt.Sprintf("%s", table.TableSchema) - tableName := fmt.Sprintf("%s", table.TableName) - tableType := fmt.Sprintf("%s", table.TableType) + schemaName := i.safeInterfaceToString(table.TableSchema) + tableName := i.safeInterfaceToString(table.TableName) + tableType := i.safeInterfaceToString(table.TableType) comment := "" if table.TableComment.Valid { comment = table.TableComment.String @@ -245,9 +263,9 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s } for _, col := range columns { - schemaName := fmt.Sprintf("%s", col.TableSchema) - tableName := fmt.Sprintf("%s", col.TableName) - columnName := fmt.Sprintf("%s", col.ColumnName) + schemaName := i.safeInterfaceToString(col.TableSchema) + tableName := i.safeInterfaceToString(col.TableName) + columnName := i.safeInterfaceToString(col.ColumnName) comment := "" if col.ColumnComment.Valid { comment = col.ColumnComment.String @@ -270,7 +288,7 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s Name: columnName, Position: i.safeInterfaceToInt(col.OrdinalPosition, 0), DataType: dataType, - IsNullable: fmt.Sprintf("%s", col.IsNullable) == "YES", + IsNullable: i.safeInterfaceToString(col.IsNullable) == "YES", Comment: comment, } @@ -309,10 +327,10 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s } // Handle identity columns - if fmt.Sprintf("%s", col.IsIdentity) == "YES" { + if i.safeInterfaceToString(col.IsIdentity) == "YES" { identity := &Identity{ Generation: i.safeInterfaceToString(col.IdentityGeneration), - Cycle: fmt.Sprintf("%s", col.IdentityCycle) == "YES", + Cycle: i.safeInterfaceToString(col.IdentityCycle) == "YES", } if start := i.safeInterfaceToInt64(col.IdentityStart, -1); start >= 0 { @@ -778,12 +796,12 @@ func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema dbSchema := schema.getOrCreateSchema(schemaName) // Set empty DataType for sequences that use PostgreSQL's implicit bigint default - dataType := fmt.Sprintf("%s", seq.DataType) + dataType := i.safeInterfaceToString(seq.DataType) if dataType == "bigint" { // Check if this is a default bigint by looking at min/max values - // Default bigint sequences have min_value=1 and max_value=9223372036854775807 + // Default bigint sequences have min_value=1 and max_value=bigintMaxValue if seq.MinimumValue.Valid && seq.MinimumValue.Int64 == 1 && - seq.MaximumValue.Valid && seq.MaximumValue.Int64 == 9223372036854775807 { + seq.MaximumValue.Valid && seq.MaximumValue.Int64 == bigintMaxValue { dataType = "" // This means it was not explicitly specified } } @@ -816,17 +834,7 @@ func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema if seq.MaximumValue.Valid { maxVal := seq.MaximumValue.Int64 - var defaultMax int64 - switch dataType { - case "smallint": - defaultMax = 32767 // smallint max - case "integer": - defaultMax = 2147483647 // integer max - case "bigint", "": - defaultMax = 9223372036854775807 // bigint max (math.MaxInt64) - default: - defaultMax = 9223372036854775807 // bigint max (math.MaxInt64) - } + defaultMax := getSequenceMaxValueForType(dataType) // Only set if not the default for this data type if maxVal != defaultMax { sequence.MaxValue = &maxVal @@ -888,8 +896,8 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema } for _, fn := range functions { - schemaName := fmt.Sprintf("%s", fn.RoutineSchema) - functionName := fmt.Sprintf("%s", fn.RoutineName) + schemaName := i.safeInterfaceToString(fn.RoutineSchema) + functionName := i.safeInterfaceToString(fn.RoutineName) comment := "" if fn.FunctionComment.Valid { comment = fn.FunctionComment.String @@ -1139,24 +1147,25 @@ func (i *Inspector) stripSameSchemaPrefix(typeName, routineSchema string) string return typeName } +// oidToTypeName maps PostgreSQL type OIDs to standard SQL type names. +// Reference: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat +var oidToTypeName = map[int64]string{ + 16: "boolean", + 20: "bigint", + 21: "smallint", + 23: "integer", + 25: "text", + 1043: "character varying", + 1082: "date", + 1114: "timestamp without time zone", // Will be normalized later + 1184: "timestamp with time zone", + 1700: "numeric", + 2950: "uuid", +} + // lookupTypeNameFromOID converts PostgreSQL type OID to type name func (i *Inspector) lookupTypeNameFromOID(oid int64) string { - // Common type OID mappings (can be extended as needed) - typeMap := map[int64]string{ - 16: "boolean", - 20: "bigint", - 21: "smallint", - 23: "integer", - 25: "text", - 1043: "character varying", - 1082: "date", - 1114: "timestamp without time zone", // Will be normalized later - 1184: "timestamp with time zone", - 1700: "numeric", - 2950: "uuid", - } - - if typeName, exists := typeMap[oid]; exists { + if typeName, exists := oidToTypeName[oid]; exists { return typeName } @@ -1172,8 +1181,8 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem } for _, proc := range procedures { - schemaName := fmt.Sprintf("%s", proc.RoutineSchema) - procedureName := fmt.Sprintf("%s", proc.RoutineName) + schemaName := i.safeInterfaceToString(proc.RoutineSchema) + procedureName := i.safeInterfaceToString(proc.RoutineName) comment := "" if proc.ProcedureComment.Valid { comment = proc.ProcedureComment.String @@ -1475,14 +1484,10 @@ func (i *Inspector) buildTriggers(ctx context.Context, schema *IR, targetSchema // decodeTriggerTiming decodes trigger timing from pg_trigger.tgtype bitmask func (i *Inspector) decodeTriggerTiming(tgtype int16) TriggerTiming { - // PostgreSQL tgtype encoding for timing: - // TRIGGER_TYPE_BEFORE = 1 << 1 (2) - // TRIGGER_TYPE_INSTEAD = 1 << 6 (64) - // AFTER is represented by the absence of both BEFORE and INSTEAD bits - if tgtype&(1<<6) != 0 { + if tgtype&triggerTypeInstead != 0 { return TriggerTimingInsteadOf } - if tgtype&(1<<1) != 0 { + if tgtype&triggerTypeBefore != 0 { return TriggerTimingBefore } // If neither BEFORE nor INSTEAD, then it's AFTER @@ -1491,23 +1496,18 @@ func (i *Inspector) decodeTriggerTiming(tgtype int16) TriggerTiming { // decodeTriggerEvents decodes trigger events from pg_trigger.tgtype bitmask func (i *Inspector) decodeTriggerEvents(tgtype int16) []TriggerEvent { - // PostgreSQL tgtype encoding for events: - // TRIGGER_TYPE_INSERT = 1 << 2 (4) - // TRIGGER_TYPE_DELETE = 1 << 3 (8) - // TRIGGER_TYPE_UPDATE = 1 << 4 (16) - // TRIGGER_TYPE_TRUNCATE = 1 << 5 (32) var events []TriggerEvent - if tgtype&(1<<2) != 0 { + if tgtype&triggerTypeInsert != 0 { events = append(events, TriggerEventInsert) } - if tgtype&(1<<4) != 0 { + if tgtype&triggerTypeUpdate != 0 { events = append(events, TriggerEventUpdate) } - if tgtype&(1<<3) != 0 { + if tgtype&triggerTypeDelete != 0 { events = append(events, TriggerEventDelete) } - if tgtype&(1<<5) != 0 { + if tgtype&triggerTypeTruncate != 0 { events = append(events, TriggerEventTruncate) } @@ -1516,10 +1516,7 @@ func (i *Inspector) decodeTriggerEvents(tgtype int16) []TriggerEvent { // decodeTriggerLevel decodes trigger level from pg_trigger.tgtype bitmask func (i *Inspector) decodeTriggerLevel(tgtype int16) TriggerLevel { - // PostgreSQL tgtype encoding for level: - // TRIGGER_TYPE_ROW = 1 << 0 (1) - // If bit 0 is set, it's a row-level trigger, otherwise statement-level - if tgtype&(1<<0) != 0 { + if tgtype&triggerTypeRow != 0 { return TriggerLevelRow } return TriggerLevelStatement @@ -2192,3 +2189,16 @@ func (i *Inspector) safeInterfaceToBool(val interface{}, defaultVal bool) bool { } return defaultVal } + +// getSequenceMaxValueForType returns the default maximum value for a sequence based on its data type +func getSequenceMaxValueForType(dataType string) int64 { + switch dataType { + case "smallint": + return smallintMaxValue + case "integer": + return integerMaxValue + default: + // bigint is the default for sequences when no type is specified + return bigintMaxValue + } +} diff --git a/ir/ir.go b/ir/ir.go index c2009df2..71a5858d 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -146,15 +146,21 @@ type Function struct { // This is built dynamically from the Parameters array to ensure it uses normalized types. // Per PostgreSQL DROP FUNCTION syntax, only input parameters are included (IN, INOUT, VARIADIC). func (f *Function) GetArguments() string { - if len(f.Parameters) == 0 { + return getInputParameterTypes(f.Parameters) +} + +// getInputParameterTypes extracts input parameter types from a parameter list. +// Per PostgreSQL DROP FUNCTION/PROCEDURE syntax, only input parameters are included +// (IN, INOUT, VARIADIC). OUT and TABLE mode parameters are excluded as they're part +// of the return signature. +func getInputParameterTypes(params []*Parameter) string { + if len(params) == 0 { return "" } var argTypes []string - for _, param := range f.Parameters { - // Include only input parameter modes for DROP FUNCTION compatibility - // Exclude OUT and TABLE mode parameters (they're part of return signature) - if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" { + for _, param := range params { + if isInputParameter(param.Mode) { argTypes = append(argTypes, param.DataType) } } @@ -162,6 +168,12 @@ func (f *Function) GetArguments() string { return strings.Join(argTypes, ", ") } +// isInputParameter returns true if the parameter mode represents an input parameter. +// PostgreSQL DROP FUNCTION/PROCEDURE syntax only includes input parameters. +func isInputParameter(mode string) bool { + return mode == "" || mode == "IN" || mode == "INOUT" || mode == "VARIADIC" +} + // Parameter represents a function parameter type Parameter struct { Name string `json:"name"` @@ -388,20 +400,7 @@ type Procedure struct { // This is built dynamically from the Parameters array to ensure it uses normalized types. // Per PostgreSQL DROP PROCEDURE syntax, only input parameters are included (IN, INOUT, VARIADIC). func (p *Procedure) GetArguments() string { - if len(p.Parameters) == 0 { - return "" - } - - var argTypes []string - for _, param := range p.Parameters { - // Include only input parameter modes for DROP PROCEDURE compatibility - // Exclude OUT and TABLE mode parameters (they're part of return signature) - if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" { - argTypes = append(argTypes, param.DataType) - } - } - - return strings.Join(argTypes, ", ") + return getInputParameterTypes(p.Parameters) } // DefaultPrivilegeObjectType represents the object type for default privileges diff --git a/ir/normalize.go b/ir/normalize.go index cccd8af9..2cc068bc 100644 --- a/ir/normalize.go +++ b/ir/normalize.go @@ -712,6 +712,111 @@ func normalizeDomainConstraint(constraint *DomainConstraint) { constraint.Definition = def } +// postgresTypeNormalization maps PostgreSQL internal type names to standard SQL types. +// This map is used by normalizePostgreSQLType to normalize type representations. +var postgresTypeNormalization = map[string]string{ + // Numeric types + "int2": "smallint", + "int4": "integer", + "int8": "bigint", + "float4": "real", + "float8": "double precision", + "bool": "boolean", + "pg_catalog.int2": "smallint", + "pg_catalog.int4": "integer", + "pg_catalog.int8": "bigint", + "pg_catalog.float4": "real", + "pg_catalog.float8": "double precision", + "pg_catalog.bool": "boolean", + "pg_catalog.numeric": "numeric", + + // Character types + "bpchar": "character", + "character varying": "varchar", // Prefer short form + "pg_catalog.text": "text", + "pg_catalog.varchar": "varchar", // Prefer short form + "pg_catalog.bpchar": "character", + + // Date/time types - convert verbose forms to canonical short forms + "timestamp with time zone": "timestamptz", + "timestamp without time zone": "timestamp", + "time with time zone": "timetz", + "timestamptz": "timestamptz", + "timetz": "timetz", + "pg_catalog.timestamptz": "timestamptz", + "pg_catalog.timestamp": "timestamp", + "pg_catalog.date": "date", + "pg_catalog.time": "time", + "pg_catalog.timetz": "timetz", + "pg_catalog.interval": "interval", + + // Array types (internal PostgreSQL array notation with underscore prefix) + "_text": "text[]", + "_int2": "smallint[]", + "_int4": "integer[]", + "_int8": "bigint[]", + "_float4": "real[]", + "_float8": "double precision[]", + "_bool": "boolean[]", + "_varchar": "varchar[]", // Prefer short form + "_char": "character[]", + "_bpchar": "character[]", + "_numeric": "numeric[]", + "_uuid": "uuid[]", + "_json": "json[]", + "_jsonb": "jsonb[]", + "_bytea": "bytea[]", + "_inet": "inet[]", + "_cidr": "cidr[]", + "_macaddr": "macaddr[]", + "_macaddr8": "macaddr8[]", + "_date": "date[]", + "_time": "time[]", + "_timetz": "timetz[]", + "_timestamp": "timestamp[]", + "_timestamptz": "timestamptz[]", + "_interval": "interval[]", + + // Array types (basetype[] format from SQL query) + "int2[]": "smallint[]", + "int4[]": "integer[]", + "int8[]": "bigint[]", + "float4[]": "real[]", + "float8[]": "double precision[]", + "bool[]": "boolean[]", + "varchar[]": "varchar[]", + "bpchar[]": "character[]", + "numeric[]": "numeric[]", + "uuid[]": "uuid[]", + "json[]": "json[]", + "jsonb[]": "jsonb[]", + "bytea[]": "bytea[]", + "inet[]": "inet[]", + "cidr[]": "cidr[]", + "macaddr[]": "macaddr[]", + "macaddr8[]": "macaddr8[]", + "date[]": "date[]", + "time[]": "time[]", + "timetz[]": "timetz[]", + "timestamp[]": "timestamp[]", + "timestamptz[]": "timestamptz[]", + "interval[]": "interval[]", + + // Other common types + "pg_catalog.uuid": "uuid", + "pg_catalog.json": "json", + "pg_catalog.jsonb": "jsonb", + "pg_catalog.bytea": "bytea", + "pg_catalog.inet": "inet", + "pg_catalog.cidr": "cidr", + "pg_catalog.macaddr": "macaddr", + + // Serial types + "serial": "serial", + "smallserial": "smallserial", + "bigserial": "bigserial", +} + // normalizePostgreSQLType normalizes PostgreSQL internal type names to standard SQL types. // This function handles both expressions (with type casts) and direct type names. func normalizePostgreSQLType(input string) string { @@ -719,117 +824,13 @@ func normalizePostgreSQLType(input string) string { return input } - // Map of PostgreSQL internal types to standard SQL types - typeMap := map[string]string{ - // Numeric types - "int2": "smallint", - "int4": "integer", - "int8": "bigint", - "float4": "real", - "float8": "double precision", - "bool": "boolean", - "pg_catalog.int2": "smallint", - "pg_catalog.int4": "integer", - "pg_catalog.int8": "bigint", - "pg_catalog.float4": "real", - "pg_catalog.float8": "double precision", - "pg_catalog.bool": "boolean", - "pg_catalog.numeric": "numeric", - - // Character types - "bpchar": "character", - "character varying": "varchar", // Prefer short form - "pg_catalog.text": "text", - "pg_catalog.varchar": "varchar", // Prefer short form - "pg_catalog.bpchar": "character", - - // Date/time types - convert verbose forms to canonical short forms - "timestamp with time zone": "timestamptz", - "timestamp without time zone": "timestamp", - "time with time zone": "timetz", - "timestamptz": "timestamptz", - "timetz": "timetz", - "pg_catalog.timestamptz": "timestamptz", - "pg_catalog.timestamp": "timestamp", - "pg_catalog.date": "date", - "pg_catalog.time": "time", - "pg_catalog.timetz": "timetz", - "pg_catalog.interval": "interval", - - // Array types (internal PostgreSQL array notation with underscore prefix) - "_text": "text[]", - "_int2": "smallint[]", - "_int4": "integer[]", - "_int8": "bigint[]", - "_float4": "real[]", - "_float8": "double precision[]", - "_bool": "boolean[]", - "_varchar": "varchar[]", // Prefer short form - "_char": "character[]", - "_bpchar": "character[]", - "_numeric": "numeric[]", - "_uuid": "uuid[]", - "_json": "json[]", - "_jsonb": "jsonb[]", - "_bytea": "bytea[]", - "_inet": "inet[]", - "_cidr": "cidr[]", - "_macaddr": "macaddr[]", - "_macaddr8": "macaddr8[]", - "_date": "date[]", - "_time": "time[]", - "_timetz": "timetz[]", - "_timestamp": "timestamp[]", - "_timestamptz": "timestamptz[]", - "_interval": "interval[]", - - // Array types (basetype[] format from SQL query) - "int2[]": "smallint[]", - "int4[]": "integer[]", - "int8[]": "bigint[]", - "float4[]": "real[]", - "float8[]": "double precision[]", - "bool[]": "boolean[]", - "varchar[]": "varchar[]", - "bpchar[]": "character[]", - "numeric[]": "numeric[]", - "uuid[]": "uuid[]", - "json[]": "json[]", - "jsonb[]": "jsonb[]", - "bytea[]": "bytea[]", - "inet[]": "inet[]", - "cidr[]": "cidr[]", - "macaddr[]": "macaddr[]", - "macaddr8[]": "macaddr8[]", - "date[]": "date[]", - "time[]": "time[]", - "timetz[]": "timetz[]", - "timestamp[]": "timestamp[]", - "timestamptz[]": "timestamptz[]", - "interval[]": "interval[]", - - // Other common types - "pg_catalog.uuid": "uuid", - "pg_catalog.json": "json", - "pg_catalog.jsonb": "jsonb", - "pg_catalog.bytea": "bytea", - "pg_catalog.inet": "inet", - "pg_catalog.cidr": "cidr", - "pg_catalog.macaddr": "macaddr", - - // Serial types - "serial": "serial", - "smallserial": "smallserial", - "bigserial": "bigserial", - } - // Check if this is an expression with type casts (contains "::") if strings.Contains(input, "::") { // Handle expressions with type casts expr := input // Replace PostgreSQL internal type names with standard SQL types in type casts - for pgType, sqlType := range typeMap { + for pgType, sqlType := range postgresTypeNormalization { expr = strings.ReplaceAll(expr, "::"+pgType, "::"+sqlType) } @@ -846,7 +847,7 @@ func normalizePostgreSQLType(input string) string { typeName := input // Check if we have a direct mapping - if normalized, exists := typeMap[typeName]; exists { + if normalized, exists := postgresTypeNormalization[typeName]; exists { return normalized }