Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 75 additions & 65 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,25 @@ import (
"golang.org/x/sync/errgroup"
)

// Bit flags found in the pg_trigger.tgtype bitmask. Every flag occupies a
// single bit, and the bits are consecutive starting at bit 0, so the values
// are generated with iota; the order of the names below is therefore
// significant and must not be changed.
// Reference: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_trigger.h
const (
	triggerTypeRow      = 1 << iota // 1  - TRIGGER_TYPE_ROW: row-level trigger
	triggerTypeBefore               // 2  - TRIGGER_TYPE_BEFORE: BEFORE timing
	triggerTypeInsert               // 4  - TRIGGER_TYPE_INSERT: INSERT event
	triggerTypeDelete               // 8  - TRIGGER_TYPE_DELETE: DELETE event
	triggerTypeUpdate               // 16 - TRIGGER_TYPE_UPDATE: UPDATE event
	triggerTypeTruncate             // 32 - TRIGGER_TYPE_TRUNCATE: TRUNCATE event
	triggerTypeInstead              // 64 - TRIGGER_TYPE_INSTEAD: INSTEAD OF timing
)

// Largest value representable by each integer type a PostgreSQL sequence can
// be declared with. Each limit is written as 2^(bits-1) - 1 for a signed
// integer of that width; the constants are left untyped so they can be
// compared with or returned as int64 without conversion.
const (
	smallintMaxValue = 1<<15 - 1 // 32767: upper bound for smallint sequences
	integerMaxValue  = 1<<31 - 1 // 2147483647: upper bound for integer sequences
	bigintMaxValue   = 1<<63 - 1 // 9223372036854775807: upper bound for bigint sequences (math.MaxInt64)
)

// Inspector builds IR from database queries
type Inspector struct {
db *sql.DB
Expand All @@ -29,7 +48,6 @@ func NewInspector(db *sql.DB, ignoreConfig *IgnoreConfig) *Inspector {
}
}

// queryGroup represents a group of queries that can be executed concurrently
// BuildIR builds the schema IR from the database for a specific schema
func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, error) {
schema := NewIR()
Expand Down Expand Up @@ -117,6 +135,7 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro
return schema, nil
}

// queryGroup represents a group of queries that can be executed concurrently
type queryGroup struct {
name string
funcs []func(context.Context, *IR, string) error
Expand Down Expand Up @@ -177,8 +196,7 @@ func (i *Inspector) buildSchemas(ctx context.Context, schema *IR, targetSchema s
return err
}

name := fmt.Sprintf("%s", schemaName)
schema.getOrCreateSchema(name)
schema.getOrCreateSchema(i.safeInterfaceToString(schemaName))

return nil
}
Expand All @@ -190,9 +208,9 @@ func (i *Inspector) buildTables(ctx context.Context, schema *IR, targetSchema st
}

for _, table := range tables {
schemaName := fmt.Sprintf("%s", table.TableSchema)
tableName := fmt.Sprintf("%s", table.TableName)
tableType := fmt.Sprintf("%s", table.TableType)
schemaName := i.safeInterfaceToString(table.TableSchema)
tableName := i.safeInterfaceToString(table.TableName)
tableType := i.safeInterfaceToString(table.TableType)
comment := ""
if table.TableComment.Valid {
comment = table.TableComment.String
Expand Down Expand Up @@ -245,9 +263,9 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s
}

for _, col := range columns {
schemaName := fmt.Sprintf("%s", col.TableSchema)
tableName := fmt.Sprintf("%s", col.TableName)
columnName := fmt.Sprintf("%s", col.ColumnName)
schemaName := i.safeInterfaceToString(col.TableSchema)
tableName := i.safeInterfaceToString(col.TableName)
columnName := i.safeInterfaceToString(col.ColumnName)
comment := ""
if col.ColumnComment.Valid {
comment = col.ColumnComment.String
Expand All @@ -270,7 +288,7 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s
Name: columnName,
Position: i.safeInterfaceToInt(col.OrdinalPosition, 0),
DataType: dataType,
IsNullable: fmt.Sprintf("%s", col.IsNullable) == "YES",
IsNullable: i.safeInterfaceToString(col.IsNullable) == "YES",
Comment: comment,
}

Expand Down Expand Up @@ -309,10 +327,10 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s
}

// Handle identity columns
if fmt.Sprintf("%s", col.IsIdentity) == "YES" {
if i.safeInterfaceToString(col.IsIdentity) == "YES" {
identity := &Identity{
Generation: i.safeInterfaceToString(col.IdentityGeneration),
Cycle: fmt.Sprintf("%s", col.IdentityCycle) == "YES",
Cycle: i.safeInterfaceToString(col.IdentityCycle) == "YES",
}

if start := i.safeInterfaceToInt64(col.IdentityStart, -1); start >= 0 {
Expand Down Expand Up @@ -778,12 +796,12 @@ func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema
dbSchema := schema.getOrCreateSchema(schemaName)

// Set empty DataType for sequences that use PostgreSQL's implicit bigint default
dataType := fmt.Sprintf("%s", seq.DataType)
dataType := i.safeInterfaceToString(seq.DataType)
if dataType == "bigint" {
// Check if this is a default bigint by looking at min/max values
// Default bigint sequences have min_value=1 and max_value=9223372036854775807
// Default bigint sequences have min_value=1 and max_value=bigintMaxValue
if seq.MinimumValue.Valid && seq.MinimumValue.Int64 == 1 &&
seq.MaximumValue.Valid && seq.MaximumValue.Int64 == 9223372036854775807 {
seq.MaximumValue.Valid && seq.MaximumValue.Int64 == bigintMaxValue {
dataType = "" // This means it was not explicitly specified
}
}
Expand Down Expand Up @@ -816,17 +834,7 @@ func (i *Inspector) buildSequences(ctx context.Context, schema *IR, targetSchema

if seq.MaximumValue.Valid {
maxVal := seq.MaximumValue.Int64
var defaultMax int64
switch dataType {
case "smallint":
defaultMax = 32767 // smallint max
case "integer":
defaultMax = 2147483647 // integer max
case "bigint", "":
defaultMax = 9223372036854775807 // bigint max (math.MaxInt64)
default:
defaultMax = 9223372036854775807 // bigint max (math.MaxInt64)
}
defaultMax := getSequenceMaxValueForType(dataType)
// Only set if not the default for this data type
if maxVal != defaultMax {
sequence.MaxValue = &maxVal
Expand Down Expand Up @@ -888,8 +896,8 @@ func (i *Inspector) buildFunctions(ctx context.Context, schema *IR, targetSchema
}

for _, fn := range functions {
schemaName := fmt.Sprintf("%s", fn.RoutineSchema)
functionName := fmt.Sprintf("%s", fn.RoutineName)
schemaName := i.safeInterfaceToString(fn.RoutineSchema)
functionName := i.safeInterfaceToString(fn.RoutineName)
comment := ""
if fn.FunctionComment.Valid {
comment = fn.FunctionComment.String
Expand Down Expand Up @@ -1139,24 +1147,25 @@ func (i *Inspector) stripSameSchemaPrefix(typeName, routineSchema string) string
return typeName
}

// oidToTypeName is a lookup table from a PostgreSQL type OID to the standard
// SQL spelling of that type. Only the OIDs listed here are known; callers must
// handle a missing key themselves.
// Reference: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
var oidToTypeName = map[int64]string{
	// Boolean
	16: "boolean",

	// Numeric types
	21:   "smallint",
	23:   "integer",
	20:   "bigint",
	1700: "numeric",

	// Character types
	25:   "text",
	1043: "character varying",

	// Date and time types
	1082: "date",
	1114: "timestamp without time zone", // Will be normalized later
	1184: "timestamp with time zone",

	// Identifier types
	2950: "uuid",
}

// lookupTypeNameFromOID converts PostgreSQL type OID to type name
func (i *Inspector) lookupTypeNameFromOID(oid int64) string {
// Common type OID mappings (can be extended as needed)
typeMap := map[int64]string{
16: "boolean",
20: "bigint",
21: "smallint",
23: "integer",
25: "text",
1043: "character varying",
1082: "date",
1114: "timestamp without time zone", // Will be normalized later
1184: "timestamp with time zone",
1700: "numeric",
2950: "uuid",
}

if typeName, exists := typeMap[oid]; exists {
if typeName, exists := oidToTypeName[oid]; exists {
return typeName
}

Expand All @@ -1172,8 +1181,8 @@ func (i *Inspector) buildProcedures(ctx context.Context, schema *IR, targetSchem
}

for _, proc := range procedures {
schemaName := fmt.Sprintf("%s", proc.RoutineSchema)
procedureName := fmt.Sprintf("%s", proc.RoutineName)
schemaName := i.safeInterfaceToString(proc.RoutineSchema)
procedureName := i.safeInterfaceToString(proc.RoutineName)
comment := ""
if proc.ProcedureComment.Valid {
comment = proc.ProcedureComment.String
Expand Down Expand Up @@ -1475,14 +1484,10 @@ func (i *Inspector) buildTriggers(ctx context.Context, schema *IR, targetSchema

// decodeTriggerTiming decodes trigger timing from pg_trigger.tgtype bitmask
func (i *Inspector) decodeTriggerTiming(tgtype int16) TriggerTiming {
// PostgreSQL tgtype encoding for timing:
// TRIGGER_TYPE_BEFORE = 1 << 1 (2)
// TRIGGER_TYPE_INSTEAD = 1 << 6 (64)
// AFTER is represented by the absence of both BEFORE and INSTEAD bits
if tgtype&(1<<6) != 0 {
if tgtype&triggerTypeInstead != 0 {
return TriggerTimingInsteadOf
}
if tgtype&(1<<1) != 0 {
if tgtype&triggerTypeBefore != 0 {
return TriggerTimingBefore
}
// If neither BEFORE nor INSTEAD, then it's AFTER
Expand All @@ -1491,23 +1496,18 @@ func (i *Inspector) decodeTriggerTiming(tgtype int16) TriggerTiming {

// decodeTriggerEvents decodes trigger events from pg_trigger.tgtype bitmask
func (i *Inspector) decodeTriggerEvents(tgtype int16) []TriggerEvent {
// PostgreSQL tgtype encoding for events:
// TRIGGER_TYPE_INSERT = 1 << 2 (4)
// TRIGGER_TYPE_DELETE = 1 << 3 (8)
// TRIGGER_TYPE_UPDATE = 1 << 4 (16)
// TRIGGER_TYPE_TRUNCATE = 1 << 5 (32)
var events []TriggerEvent

if tgtype&(1<<2) != 0 {
if tgtype&triggerTypeInsert != 0 {
events = append(events, TriggerEventInsert)
}
if tgtype&(1<<4) != 0 {
if tgtype&triggerTypeUpdate != 0 {
events = append(events, TriggerEventUpdate)
}
if tgtype&(1<<3) != 0 {
if tgtype&triggerTypeDelete != 0 {
events = append(events, TriggerEventDelete)
}
if tgtype&(1<<5) != 0 {
if tgtype&triggerTypeTruncate != 0 {
events = append(events, TriggerEventTruncate)
}

Expand All @@ -1516,10 +1516,7 @@ func (i *Inspector) decodeTriggerEvents(tgtype int16) []TriggerEvent {

// decodeTriggerLevel decodes trigger level from pg_trigger.tgtype bitmask
func (i *Inspector) decodeTriggerLevel(tgtype int16) TriggerLevel {
// PostgreSQL tgtype encoding for level:
// TRIGGER_TYPE_ROW = 1 << 0 (1)
// If bit 0 is set, it's a row-level trigger, otherwise statement-level
if tgtype&(1<<0) != 0 {
if tgtype&triggerTypeRow != 0 {
return TriggerLevelRow
}
return TriggerLevelStatement
Expand Down Expand Up @@ -2192,3 +2189,16 @@ func (i *Inspector) safeInterfaceToBool(val interface{}, defaultVal bool) bool {
}
return defaultVal
}

// getSequenceMaxValueForType reports the maximum value a sequence of the given
// data type has when no explicit maximum is set.
//
// "smallint" and "integer" map to their own limits. Every other value,
// including the empty string used when no type was specified, is treated as
// bigint.
func getSequenceMaxValueForType(dataType string) int64 {
	if dataType == "smallint" {
		return smallintMaxValue
	}
	if dataType == "integer" {
		return integerMaxValue
	}
	// bigint is the default for sequences when no type is specified
	return bigintMaxValue
}
37 changes: 18 additions & 19 deletions ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,34 @@ type Function struct {
// This is built dynamically from the Parameters array to ensure it uses normalized types.
// Per PostgreSQL DROP FUNCTION syntax, only input parameters are included (IN, INOUT, VARIADIC).
func (f *Function) GetArguments() string {
if len(f.Parameters) == 0 {
return getInputParameterTypes(f.Parameters)
}

// getInputParameterTypes extracts input parameter types from a parameter list.
// Per PostgreSQL DROP FUNCTION/PROCEDURE syntax, only input parameters are included
// (IN, INOUT, VARIADIC). OUT and TABLE mode parameters are excluded as they're part
// of the return signature.
func getInputParameterTypes(params []*Parameter) string {
if len(params) == 0 {
return ""
}

var argTypes []string
for _, param := range f.Parameters {
// Include only input parameter modes for DROP FUNCTION compatibility
// Exclude OUT and TABLE mode parameters (they're part of return signature)
if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" {
for _, param := range params {
if isInputParameter(param.Mode) {
argTypes = append(argTypes, param.DataType)
}
}

return strings.Join(argTypes, ", ")
}

// isInputParameter reports whether mode denotes an input parameter.
// The accepted modes are "IN", "INOUT" and "VARIADIC", plus the empty string
// (no mode given); the match is case-sensitive. Any other mode yields false.
// PostgreSQL DROP FUNCTION/PROCEDURE syntax only includes input parameters.
func isInputParameter(mode string) bool {
	switch mode {
	case "", "IN", "INOUT", "VARIADIC":
		return true
	default:
		return false
	}
}

// Parameter represents a function parameter
type Parameter struct {
Name string `json:"name"`
Expand Down Expand Up @@ -388,20 +400,7 @@ type Procedure struct {
// This is built dynamically from the Parameters array to ensure it uses normalized types.
// Per PostgreSQL DROP PROCEDURE syntax, only input parameters are included (IN, INOUT, VARIADIC).
func (p *Procedure) GetArguments() string {
if len(p.Parameters) == 0 {
return ""
}

var argTypes []string
for _, param := range p.Parameters {
// Include only input parameter modes for DROP PROCEDURE compatibility
// Exclude OUT and TABLE mode parameters (they're part of return signature)
if param.Mode == "" || param.Mode == "IN" || param.Mode == "INOUT" || param.Mode == "VARIADIC" {
argTypes = append(argTypes, param.DataType)
}
}

return strings.Join(argTypes, ", ")
return getInputParameterTypes(p.Parameters)
}

// DefaultPrivilegeObjectType represents the object type for default privileges
Expand Down
Loading
Loading