From 68d0e4b78d680dc9badb8d171f777a18e9a1e76c Mon Sep 17 00:00:00 2001 From: tianzhou Date: Thu, 25 Sep 2025 15:24:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20GENERATED=20ALWAYS=20AS=20(=E2=80=A6)?= =?UTF-8?q?=20STORED?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/diff/table.go | 25 +++++--- ir/inspector.go | 16 ++++- ir/ir.go | 22 ++++--- ir/parser.go | 44 +++++++++++++ ir/queries/queries.sql | 56 ++++++++++------ ir/queries/queries.sql.go | 64 +++++++++++++------ .../add_column_generated/diff.sql | 1 + .../create_table/add_column_generated/new.sql | 4 ++ .../create_table/add_column_generated/old.sql | 3 + .../add_column_generated/plan.json | 20 ++++++ .../add_column_generated/plan.sql | 1 + .../add_column_generated/plan.txt | 13 ++++ 12 files changed, 208 insertions(+), 61 deletions(-) create mode 100644 testdata/diff/create_table/add_column_generated/diff.sql create mode 100644 testdata/diff/create_table/add_column_generated/new.sql create mode 100644 testdata/diff/create_table/add_column_generated/old.sql create mode 100644 testdata/diff/create_table/add_column_generated/plan.json create mode 100644 testdata/diff/create_table/add_column_generated/plan.sql create mode 100644 testdata/diff/create_table/add_column_generated/plan.txt diff --git a/internal/diff/table.go b/internal/diff/table.go index c2c7e724..2a817dd1 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -546,6 +546,18 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector stmt += generateForeignKeyClause(fkConstraint, targetSchema, true) } + // Don't add DEFAULT for SERIAL columns, identity columns, or generated columns + if column.DefaultValue != nil && column.Identity == nil && !column.IsGenerated && !isSerialColumn(column) { + stmt += fmt.Sprintf(" DEFAULT %s", *column.DefaultValue) + } + + // Don't add NOT NULL for identity columns or SERIAL columns as they are implicitly NOT NULL + // Also skip NOT NULL if we're adding PRIMARY KEY inline (PRIMARY KEY implies NOT NULL) + // For generated columns, include NOT NULL if explicitly specified (but before GENERATED clause) + if !column.IsNullable && column.Identity == nil && !isSerialColumn(column) && pkConstraint == nil { + stmt += " NOT NULL" + } + // Add identity column syntax if column.Identity != nil { switch column.Identity.Generation { @@ -556,15 +568,10 @@ func (td *tableDiff) generateAlterTableStatements(targetSchema string, collector } } - // Don't add DEFAULT for SERIAL columns or if identity is present - if column.DefaultValue != nil && column.Identity == nil && !isSerialColumn(column) { - stmt += fmt.Sprintf(" DEFAULT %s", *column.DefaultValue) - } - - // Don't add NOT NULL for identity columns or SERIAL columns as they are implicitly NOT NULL - // Also skip NOT NULL if we're adding PRIMARY KEY inline (PRIMARY KEY implies NOT NULL) - if !column.IsNullable && column.Identity == nil && !isSerialColumn(column) && pkConstraint == nil { - stmt += " NOT NULL" + // Add generated column syntax + if column.IsGenerated && column.GeneratedExpr != nil { + // TODO: Add support for GENERATED ALWAYS AS (...) VIRTUAL when PostgreSQL 18 is supported + stmt += fmt.Sprintf(" GENERATED ALWAYS AS (%s) STORED", *column.GeneratedExpr) } // Add PRIMARY KEY inline if present diff --git a/ir/inspector.go b/ir/inspector.go index af877585..1c485e1d 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -270,9 +270,21 @@ func (i *Inspector) buildColumns(ctx context.Context, schema *IR, targetSchema s Comment: comment, } + // Handle generated columns first + isGeneratedColumn := i.safeInterfaceToString(col.Attgenerated) == "s" + if isGeneratedColumn { + column.IsGenerated = true + if generatedExpr := i.safeInterfaceToString(col.GeneratedExpr); generatedExpr != "" { + column.GeneratedExpr = &generatedExpr + } + } + // Handle default value - keep original value as stored in database - if defaultVal := i.safeInterfaceToString(col.ColumnDefault); defaultVal != "" && defaultVal != "" { - column.DefaultValue = &defaultVal + // Don't set default values for generated columns + if !isGeneratedColumn { + if defaultVal := i.safeInterfaceToString(col.ColumnDefault); defaultVal != "" && defaultVal != "" { + column.DefaultValue = &defaultVal + } } // Handle max length diff --git a/ir/ir.go b/ir/ir.go index 9a732658..a9811aba 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -57,16 +57,18 @@ type Table struct { // Column represents a table column type Column struct { - Name string `json:"name"` - Position int `json:"position"` // ordinal_position - DataType string `json:"data_type"` - IsNullable bool `json:"is_nullable"` - DefaultValue *string `json:"default_value,omitempty"` - MaxLength *int `json:"max_length,omitempty"` - Precision *int `json:"precision,omitempty"` - Scale *int `json:"scale,omitempty"` - Comment string `json:"comment,omitempty"` - Identity *Identity `json:"identity,omitempty"` + Name string `json:"name"` + Position int `json:"position"` // ordinal_position + DataType string `json:"data_type"` + IsNullable bool `json:"is_nullable"` + DefaultValue *string `json:"default_value,omitempty"` + MaxLength *int `json:"max_length,omitempty"` + Precision *int `json:"precision,omitempty"` + Scale *int `json:"scale,omitempty"` + Comment string `json:"comment,omitempty"` + Identity *Identity `json:"identity,omitempty"` + GeneratedExpr *string `json:"generated_expr,omitempty"` // Expression for generated columns + IsGenerated bool `json:"is_generated,omitempty"` // True if this is a generated column } // Identity represents PostgreSQL identity column configuration diff --git a/ir/parser.go b/ir/parser.go index 130679f6..7490af30 100644 --- a/ir/parser.go +++ b/ir/parser.go @@ -752,6 +752,17 @@ func (p *Parser) parseColumnDef(colDef *pg_query.ColumnDef, position int, schema if checkConstraint := p.parseInlineCheckConstraint(cons, colDef.Colname, schemaName, tableName); checkConstraint != nil { inlineConstraints = append(inlineConstraints, checkConstraint) } + case pg_query.ConstrType_CONSTR_GENERATED: + // Handle generated column constraints (GENERATED ALWAYS AS ... STORED) + if cons.RawExpr != nil { + generatedExpr := p.extractGeneratedExpression(cons.RawExpr) + if generatedExpr != "" { + column.GeneratedExpr = &generatedExpr + column.IsGenerated = true + // Generated columns are implicitly NOT NULL + column.IsNullable = false + } + } } } } @@ -1030,6 +1041,39 @@ func (p *Parser) extractDefaultValue(expr *pg_query.Node) string { return "" } +// extractGeneratedExpression extracts the expression from a generated column constraint +// Uses pg_query deparse to properly extract complex expressions +func (p *Parser) extractGeneratedExpression(expr *pg_query.Node) string { + if expr == nil { + return "" + } + + // Create a temporary SELECT statement with just this expression to deparse it + tempSelect := &pg_query.SelectStmt{ + TargetList: []*pg_query.Node{{ + Node: &pg_query.Node_ResTarget{ + ResTarget: &pg_query.ResTarget{Val: expr}, + }, + }}, + } + tempResult := &pg_query.ParseResult{ + Stmts: []*pg_query.RawStmt{{ + Stmt: &pg_query.Node{ + Node: &pg_query.Node_SelectStmt{SelectStmt: tempSelect}, + }, + }}, + } + + if deparsed, err := pg_query.Deparse(tempResult); err == nil { + // Extract just the expression part from "SELECT expression" + if expr, found := strings.CutPrefix(deparsed, "SELECT "); found { + return strings.TrimSpace(expr) + } + } + + return "" +} + // parseConstraint parses table constraints func (p *Parser) parseConstraint(constraint *pg_query.Constraint, schemaName, tableName string) *Constraint { var constraintType ConstraintType diff --git a/ir/queries/queries.sql b/ir/queries/queries.sql index 4a033a1d..f66a2ce3 100644 --- a/ir/queries/queries.sql +++ b/ir/queries/queries.sql @@ -56,12 +56,15 @@ ORDER BY t.table_name; -- GetColumns retrieves all columns for all tables -- name: GetColumns :many -SELECT +SELECT c.table_schema, c.table_name, c.column_name, c.ordinal_position, - COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default, + CASE + WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults + ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) + END AS column_default, c.is_nullable, c.data_type, c.character_maximum_length, @@ -69,14 +72,14 @@ SELECT c.numeric_scale, c.udt_name, COALESCE(d.description, '') AS column_comment, - CASE - WHEN dt.typtype = 'd' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + CASE + WHEN dt.typtype = 'd' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END - WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END ELSE c.udt_name END AS resolved_type, @@ -86,7 +89,12 @@ SELECT c.identity_increment, c.identity_maximum, c.identity_minimum, - c.identity_cycle + c.identity_cycle, + a.attgenerated, + CASE + WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid) + ELSE NULL + END AS generated_expr FROM information_schema.columns c LEFT JOIN pg_class cl ON cl.relname = c.table_name LEFT JOIN pg_namespace n ON cl.relnamespace = n.oid AND n.nspname = c.table_schema @@ -103,12 +111,15 @@ ORDER BY c.table_schema, c.table_name, c.ordinal_position; -- GetColumnsForSchema retrieves all columns for tables in a specific schema -- name: GetColumnsForSchema :many -SELECT +SELECT c.table_schema, c.table_name, c.column_name, c.ordinal_position, - COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default, + CASE + WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults + ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) + END AS column_default, c.is_nullable, c.data_type, c.character_maximum_length, @@ -116,14 +127,14 @@ SELECT c.numeric_scale, c.udt_name, COALESCE(d.description, '') AS column_comment, - CASE - WHEN dt.typtype = 'd' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + CASE + WHEN dt.typtype = 'd' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END - WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END ELSE c.udt_name END AS resolved_type, @@ -133,7 +144,12 @@ SELECT c.identity_increment, c.identity_maximum, c.identity_minimum, - c.identity_cycle + c.identity_cycle, + a.attgenerated, + CASE + WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid) + ELSE NULL + END AS generated_expr FROM information_schema.columns c LEFT JOIN pg_namespace n ON n.nspname = c.table_schema LEFT JOIN pg_class cl ON cl.relname = c.table_name AND cl.relnamespace = n.oid diff --git a/ir/queries/queries.sql.go b/ir/queries/queries.sql.go index f11e65ae..7ff4643f 100644 --- a/ir/queries/queries.sql.go +++ b/ir/queries/queries.sql.go @@ -191,12 +191,15 @@ func (q *Queries) GetAggregatesForSchema(ctx context.Context, dollar_1 sql.NullS } const getColumns = `-- name: GetColumns :many -SELECT +SELECT c.table_schema, c.table_name, c.column_name, c.ordinal_position, - COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default, + CASE + WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults + ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) + END AS column_default, c.is_nullable, c.data_type, c.character_maximum_length, @@ -204,14 +207,14 @@ SELECT c.numeric_scale, c.udt_name, COALESCE(d.description, '') AS column_comment, - CASE - WHEN dt.typtype = 'd' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + CASE + WHEN dt.typtype = 'd' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END - WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END ELSE c.udt_name END AS resolved_type, @@ -221,7 +224,12 @@ SELECT c.identity_increment, c.identity_maximum, c.identity_minimum, - c.identity_cycle + c.identity_cycle, + a.attgenerated, + CASE + WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid) + ELSE NULL + END AS generated_expr FROM information_schema.columns c LEFT JOIN pg_class cl ON cl.relname = c.table_name LEFT JOIN pg_namespace n ON cl.relnamespace = n.oid AND n.nspname = c.table_schema @@ -258,6 +266,8 @@ type GetColumnsRow struct { IdentityMaximum interface{} `db:"identity_maximum" json:"identity_maximum"` IdentityMinimum interface{} `db:"identity_minimum" json:"identity_minimum"` IdentityCycle interface{} `db:"identity_cycle" json:"identity_cycle"` + Attgenerated interface{} `db:"attgenerated" json:"attgenerated"` + GeneratedExpr sql.NullString `db:"generated_expr" json:"generated_expr"` } // GetColumns retrieves all columns for all tables @@ -291,6 +301,8 @@ func (q *Queries) GetColumns(ctx context.Context) ([]GetColumnsRow, error) { &i.IdentityMaximum, &i.IdentityMinimum, &i.IdentityCycle, + &i.Attgenerated, + &i.GeneratedExpr, ); err != nil { return nil, err } @@ -306,12 +318,15 @@ func (q *Queries) GetColumns(ctx context.Context) ([]GetColumnsRow, error) { } const getColumnsForSchema = `-- name: GetColumnsForSchema :many -SELECT +SELECT c.table_schema, c.table_name, c.column_name, c.ordinal_position, - COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) AS column_default, + CASE + WHEN a.attgenerated = 's' THEN NULL -- Generated columns don't have defaults + ELSE COALESCE(pg_get_expr(ad.adbin, ad.adrelid), c.column_default) + END AS column_default, c.is_nullable, c.data_type, c.character_maximum_length, @@ -319,14 +334,14 @@ SELECT c.numeric_scale, c.udt_name, COALESCE(d.description, '') AS column_comment, - CASE - WHEN dt.typtype = 'd' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + CASE + WHEN dt.typtype = 'd' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END - WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN - CASE WHEN dn.nspname = c.table_schema THEN dt.typname - ELSE dn.nspname || '.' || dt.typname + WHEN dt.typtype = 'e' OR dt.typtype = 'c' THEN + CASE WHEN dn.nspname = c.table_schema THEN dt.typname + ELSE dn.nspname || '.' || dt.typname END ELSE c.udt_name END AS resolved_type, @@ -336,7 +351,12 @@ SELECT c.identity_increment, c.identity_maximum, c.identity_minimum, - c.identity_cycle + c.identity_cycle, + a.attgenerated, + CASE + WHEN a.attgenerated = 's' THEN pg_get_expr(ad.adbin, ad.adrelid) + ELSE NULL + END AS generated_expr FROM information_schema.columns c LEFT JOIN pg_namespace n ON n.nspname = c.table_schema LEFT JOIN pg_class cl ON cl.relname = c.table_name AND cl.relnamespace = n.oid @@ -371,6 +391,8 @@ type GetColumnsForSchemaRow struct { IdentityMaximum interface{} `db:"identity_maximum" json:"identity_maximum"` IdentityMinimum interface{} `db:"identity_minimum" json:"identity_minimum"` IdentityCycle interface{} `db:"identity_cycle" json:"identity_cycle"` + Attgenerated interface{} `db:"attgenerated" json:"attgenerated"` + GeneratedExpr sql.NullString `db:"generated_expr" json:"generated_expr"` } // GetColumnsForSchema retrieves all columns for tables in a specific schema @@ -404,6 +426,8 @@ func (q *Queries) GetColumnsForSchema(ctx context.Context, dollar_1 sql.NullStri &i.IdentityMaximum, &i.IdentityMinimum, &i.IdentityCycle, + &i.Attgenerated, + &i.GeneratedExpr, ); err != nil { return nil, err } diff --git a/testdata/diff/create_table/add_column_generated/diff.sql b/testdata/diff/create_table/add_column_generated/diff.sql new file mode 100644 index 00000000..1bf85d91 --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/diff.sql @@ -0,0 +1 @@ +ALTER TABLE merge_request ADD COLUMN iid integer NOT NULL GENERATED ALWAYS AS (CAST(data ->> 'iid' AS int)) STORED; \ No newline at end of file diff --git a/testdata/diff/create_table/add_column_generated/new.sql b/testdata/diff/create_table/add_column_generated/new.sql new file mode 100644 index 00000000..1e9e885f --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/new.sql @@ -0,0 +1,4 @@ +CREATE TABLE public.merge_request ( + data jsonb NOT NULL, + iid integer NOT NULL GENERATED ALWAYS AS ((data ->> 'iid')::integer) STORED +); \ No newline at end of file diff --git a/testdata/diff/create_table/add_column_generated/old.sql b/testdata/diff/create_table/add_column_generated/old.sql new file mode 100644 index 00000000..0bd3d3d1 --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/old.sql @@ -0,0 +1,3 @@ +CREATE TABLE public.merge_request ( + data jsonb NOT NULL +); \ No newline at end of file diff --git a/testdata/diff/create_table/add_column_generated/plan.json b/testdata/diff/create_table/add_column_generated/plan.json new file mode 100644 index 00000000..17586e0c --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/plan.json @@ -0,0 +1,20 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.1.1", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "67e51d7bd6b2b020d7a95dbe7f8e5f9bd9ec2e3f4068b9ed09bb63b79e656354" + }, + "groups": [ + { + "steps": [ + { + "sql": "ALTER TABLE merge_request ADD COLUMN iid integer NOT NULL GENERATED ALWAYS AS (CAST(data ->> 'iid' AS int)) STORED;", + "type": "table.column", + "operation": "create", + "path": "public.merge_request.iid" + } + ] + } + ] +} diff --git a/testdata/diff/create_table/add_column_generated/plan.sql b/testdata/diff/create_table/add_column_generated/plan.sql new file mode 100644 index 00000000..53e7aa12 --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/plan.sql @@ -0,0 +1 @@ +ALTER TABLE merge_request ADD COLUMN iid integer NOT NULL GENERATED ALWAYS AS (CAST(data ->> 'iid' AS int)) STORED; diff --git a/testdata/diff/create_table/add_column_generated/plan.txt b/testdata/diff/create_table/add_column_generated/plan.txt new file mode 100644 index 00000000..a0705884 --- /dev/null +++ b/testdata/diff/create_table/add_column_generated/plan.txt @@ -0,0 +1,13 @@ +Plan: 1 to modify. + +Summary by type: + tables: 1 to modify + +Tables: + ~ merge_request + + iid (column) + +DDL to be executed: +-------------------------------------------------- + +ALTER TABLE merge_request ADD COLUMN iid integer NOT NULL GENERATED ALWAYS AS (CAST(data ->> 'iid' AS int)) STORED;