Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions internal/diff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,18 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
stmt += generateForeignKeyClause(fkConstraint, targetSchema, true)
}

// Don't add DEFAULT for SERIAL columns, identity columns, or generated columns
if column.DefaultValue != nil && column.Identity == nil && !column.IsGenerated && !isSerialColumn(column) {
stmt += fmt.Sprintf(" DEFAULT %s", *column.DefaultValue)
}

// Don't add NOT NULL for identity columns or SERIAL columns as they are implicitly NOT NULL
// Also skip NOT NULL if we're adding PRIMARY KEY inline (PRIMARY KEY implies NOT NULL)
// For generated columns, include NOT NULL if explicitly specified (but before GENERATED clause)
if !column.IsNullable && column.Identity == nil && !isSerialColumn(column) && pkConstraint == nil {
stmt += " NOT NULL"
}
Comment on lines +549 to +559
Copy link

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The logic for determining when to add DEFAULT and NOT NULL clauses has become complex and duplicated. Consider extracting this into separate helper functions like shouldAddDefault(column) and shouldAddNotNull(column, pkConstraint) to improve readability and maintainability.

Copilot uses AI. Check for mistakes.

// Add identity column syntax
if column.Identity != nil {
switch column.Identity.Generation {
Expand All @@ -556,15 +568,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector
}
}

// Don't add DEFAULT for SERIAL columns or if identity is present
if column.DefaultValue != nil && column.Identity == nil && !isSerialColumn(column) {
stmt += fmt.Sprintf(" DEFAULT %s", *column.DefaultValue)
}

// Don't add NOT NULL for identity columns or SERIAL columns as they are implicitly NOT NULL
// Also skip NOT NULL if we're adding PRIMARY KEY inline (PRIMARY KEY implies NOT NULL)
if !column.IsNullable && column.Identity == nil && !isSerialColumn(column) && pkConstraint == nil {
stmt += " NOT NULL"
// Add generated column syntax
if column.IsGenerated && column.GeneratedExpr != nil {
// TODO: Add support for GENERATED ALWAYS AS (...) VIRTUAL when PostgreSQL 18 is supported
stmt += fmt.Sprintf(" GENERATED ALWAYS AS (%s) STORED", *column.GeneratedExpr)
}

// Add PRIMARY KEY inline if present
Expand Down
16 changes: 14 additions & 2 deletions ir/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,21 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s
Comment: comment,
}

// Handle generated columns first
isGeneratedColumn := i.safeInterfaceToString(col.Attgenerated) == "s"
if isGeneratedColumn {
column.IsGenerated = true
if generatedExpr := i.safeInterfaceToString(col.GeneratedExpr); generatedExpr != "" {
column.GeneratedExpr = &generatedExpr
}
}

// Handle default value - keep original value as stored in database
if defaultVal := i.safeInterfaceToString(col.ColumnDefault); defaultVal != "" && defaultVal != "<nil>" {
column.DefaultValue = &defaultVal
// Don't set default values for generated columns
if !isGeneratedColumn {
if defaultVal := i.safeInterfaceToString(col.ColumnDefault); defaultVal != "" && defaultVal != "<nil>" {
column.DefaultValue = &defaultVal
}
}

// Handle max length
Expand Down
22 changes: 12 additions & 10 deletions ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ type Table struct {

// Column represents a table column
type Column struct {
Name string `json:"name"`
Position int `json:"position"` // ordinal_position
DataType string `json:"data_type"`
IsNullable bool `json:"is_nullable"`
DefaultValue *string `json:"default_value,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
Precision *int `json:"precision,omitempty"`
Scale *int `json:"scale,omitempty"`
Comment string `json:"comment,omitempty"`
Identity *Identity `json:"identity,omitempty"`
Name string `json:"name"`
Position int `json:"position"` // ordinal_position
DataType string `json:"data_type"`
IsNullable bool `json:"is_nullable"`
DefaultValue *string `json:"default_value,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
Precision *int `json:"precision,omitempty"`
Scale *int `json:"scale,omitempty"`
Comment string `json:"comment,omitempty"`
Identity *Identity `json:"identity,omitempty"`
GeneratedExpr *string `json:"generated_expr,omitempty"` // Expression for generated columns
IsGenerated bool `json:"is_generated,omitempty"` // True if this is a generated column
}

// Identity represents PostgreSQL identity column configuration
Expand Down
44 changes: 44 additions & 0 deletions ir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,17 @@ func (p *Parser) parseColumnDef(colDef *pg_query.ColumnDef, position int, schema
if checkConstraint := p.parseInlineCheckConstraint(cons, colDef.Colname, schemaName, tableName); checkConstraint != nil {
inlineConstraints = append(inlineConstraints, checkConstraint)
}
case pg_query.ConstrType_CONSTR_GENERATED:
// Handle generated column constraints (GENERATED ALWAYS AS ... STORED)
if cons.RawExpr != nil {
generatedExpr := p.extractGeneratedExpression(cons.RawExpr)
if generatedExpr != "" {
column.GeneratedExpr = &generatedExpr
column.IsGenerated = true
// Generated columns are implicitly NOT NULL
column.IsNullable = false
}
}
}
}
}
Expand Down Expand Up @@ -1030,6 +1041,39 @@ func (p *Parser) extractDefaultValue(expr *pg_query.Node) string {
return ""
}

// extractGeneratedExpression extracts the expression from a generated column constraint
// Uses pg_query deparse to properly extract complex expressions
func (p *Parser) extractGeneratedExpression(expr *pg_query.Node) string {
if expr == nil {
return ""
}

// Create a temporary SELECT statement with just this expression to deparse it
tempSelect := &pg_query.SelectStmt{
TargetList: []*pg_query.Node{{
Node: &pg_query.Node_ResTarget{
ResTarget: &pg_query.ResTarget{Val: expr},
},
}},
}
tempResult := &pg_query.ParseResult{
Stmts: []*pg_query.RawStmt{{
Stmt: &pg_query.Node{
Node: &pg_query.Node_SelectStmt{SelectStmt: tempSelect},
},
}},
}

if deparsed, err := pg_query.Deparse(tempResult); err == nil {
// Extract just the expression part from "SELECT expression"
if expr, found := strings.CutPrefix(deparsed, "SELECT "); found {
return strings.TrimSpace(expr)
}
}

return ""
}

// parseConstraint parses table constraints
func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, tableName string) *Constraint {
var constraintType ConstraintType
Expand Down
56 changes: 36 additions & 20 deletions ir/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,30 @@ ORDER BY t.table_name;

-- GetColumns retrieves all columns for all tables
-- name: GetColumns :many
SELECT
SELECT
c.table_schema,
c.table_name,
c.column_name,
c.ordinal_position,
COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default,
CASE
WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults
ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default)
END AS column_default,
c.is_nullable,
c.data_type,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
c.udt_name,
COALESCE(d.description, '') AS column_comment,
CASE
WHEN dt.typtype = 'd' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
CASE
WHEN dt.typtype = 'd' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
END
WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
END
ELSE c.udt_name
END AS resolved_type,
Expand All @@ -86,7 +89,12 @@ SELECT
c.identity_increment,
c.identity_maximum,
c.identity_minimum,
c.identity_cycle
c.identity_cycle,
a.attgenerated,
CASE
WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid)
ELSE NULL
END AS generated_expr
FROM information_schema.columns c
LEFT JOIN pg_class cl ON cl.relname = c.table_name
LEFT JOIN pg_namespace n ON cl.relnamespace = n.oid AND n.nspname = c.table_schema
Expand All @@ -103,27 +111,30 @@ ORDER BY c.table_schema, c.table_name, c.ordinal_position;

-- GetColumnsForSchema retrieves all columns for tables in a specific schema
-- name: GetColumnsForSchema :many
SELECT
SELECT
c.table_schema,
c.table_name,
c.column_name,
c.ordinal_position,
COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default,
CASE
WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults
ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default)
END AS column_default,
c.is_nullable,
c.data_type,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
c.udt_name,
COALESCE(d.description, '') AS column_comment,
CASE
WHEN dt.typtype = 'd' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
CASE
WHEN dt.typtype = 'd' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
END
WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN
CASE WHEN dn.nspname = c.table_schema THEN dt.typname
ELSE dn.nspname || '.' || dt.typname
END
ELSE c.udt_name
END AS resolved_type,
Expand All @@ -133,7 +144,12 @@ SELECT
c.identity_increment,
c.identity_maximum,
c.identity_minimum,
c.identity_cycle
c.identity_cycle,
a.attgenerated,
CASE
WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid)
ELSE NULL
END AS generated_expr
FROM information_schema.columns c
LEFT JOIN pg_namespace n ON n.nspname = c.table_schema
LEFT JOIN pg_class cl ON cl.relname = c.table_name AND cl.relnamespace = n.oid
Expand Down
Loading