From 48b98eb2ed24d7a9f0d1c8ed5059e6ba7c9c2e11 Mon Sep 17 00:00:00 2001 From: Adam Mustafa Date: Thu, 27 Nov 2025 10:52:41 -0500 Subject: [PATCH 1/2] Adds a `--quote-all` option to the command line --- cmd/dump/dump.go | 16 ++- cmd/dump/dump_integration_test.go | 181 ++++++++++++++++++++++++ cmd/plan/plan.go | 16 ++- cmd/root.go | 7 + internal/diff/diff.go | 81 ++++++++++- internal/diff/table.go | 19 +-- ir/quote.go | 23 +++ ir/quote_test.go | 56 +++++++- testdata/dump/quote_all_test/pgdump.sql | 130 +++++++++++++++++ 9 files changed, 512 insertions(+), 17 deletions(-) create mode 100644 testdata/dump/quote_all_test/pgdump.sql diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index dbce9052..1971e86c 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -2,6 +2,7 @@ package dump import ( "fmt" + "log" "os" "github.com/pgschema/pgschema/cmd/util" @@ -32,6 +33,7 @@ type DumpConfig struct { Schema string MultiFile bool File string + QuoteAll bool } var DumpCmd = &cobra.Command{ @@ -79,7 +81,7 @@ func ExecuteDump(config *DumpConfig) (string, error) { emptyIR := ir.NewIR() // Generate diff between empty schema and target schema (this represents a complete dump) - diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema) + diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema, diff.QuoteAll(config.QuoteAll)) // Create dump formatter formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, config.Schema) @@ -107,6 +109,17 @@ func runDump(cmd *cobra.Command, args []string) error { } } + // Get quote-all flag from root command + var quoteAll bool + if cmd != nil { + q, err := cmd.Root().PersistentFlags().GetBool("quote-all") + if err == nil { + quoteAll = q + } else { + log.Printf("Failed to get quote-all flag: %v\n", err) + } + } + // Create config from command-line flags config := &DumpConfig{ Host: host, @@ -117,6 +130,7 @@ func runDump(cmd *cobra.Command, args []string) error { Schema: schema, MultiFile: multiFile, File: file, + QuoteAll: quoteAll, } // Execute dump diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index b36af55c..40e4ce39 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "os" + "regexp" "strings" "testing" @@ -303,3 +304,183 @@ func compareSchemaOutputs(t *testing.T, actualOutput, expectedOutput string, tes } } } + +// TestDumpCommand_QuoteAll validates the --quote-all flag behavior +func TestDumpCommand_QuoteAll(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + runQuoteAllTest(t, "quote_all_test") +} + +// runQuoteAllTest validates that the --quote-all flag correctly quotes all identifiers +func runQuoteAllTest(t *testing.T, testDataDir string) { + // Setup PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + + // Connect to database + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Detect PostgreSQL version and skip tests if needed + majorVersion, err := testutil.GetMajorVersion(conn) + if err != nil { + t.Fatalf("Failed to detect PostgreSQL version: %v", err) + } + + // Check if this test should be skipped for this PostgreSQL version + testutil.ShouldSkipTest(t, t.Name(), majorVersion) + + // Read and execute the pgdump.sql file + pgdumpPath := fmt.Sprintf("../../testdata/dump/%s/pgdump.sql", testDataDir) + pgdumpContent, err := os.ReadFile(pgdumpPath) + if err != nil { + t.Fatalf("Failed to read %s: %v", pgdumpPath, err) + } + + // Execute the SQL to create the schema + _, err = conn.ExecContext(context.Background(), string(pgdumpContent)) + if err != nil { + t.Fatalf("Failed to execute pgdump.sql: %v", err) + } + + // Test 1: Dump without --quote-all (normal behavior) + configNormal := &DumpConfig{ + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, + Schema: "public", + MultiFile: false, + File: "", + QuoteAll: false, + } + + normalOutput, err := ExecuteDump(configNormal) + if err != nil { + t.Fatalf("Dump command failed without quote-all: %v", err) + } + + // Test 2: Dump with --quote-all (all identifiers quoted) + configQuoteAll := &DumpConfig{ + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, + Schema: "public", + MultiFile: false, + File: "", + QuoteAll: true, + } + + quoteAllOutput, err := ExecuteDump(configQuoteAll) + if err != nil { + t.Fatalf("Dump command failed with quote-all: %v", err) + } + + // Validate quote-all behavior + validateQuoteAllBehavior(t, normalOutput, quoteAllOutput, testDataDir) +} + +// validateQuoteAllBehavior verifies that --quote-all produces correctly quoted output +func validateQuoteAllBehavior(t *testing.T, normalOutput, quoteAllOutput, testName string) { + // Split outputs into lines for analysis + normalLines := strings.Split(normalOutput, "\n") + quoteAllLines := strings.Split(quoteAllOutput, "\n") + + // Both outputs should have the same number of lines + if len(normalLines) != len(quoteAllLines) { + t.Fatalf("Different number of lines - Normal: %d, QuoteAll: %d", len(normalLines), len(quoteAllLines)) + } + + // Track identifiers that should be quoted in normal mode vs quote-all mode + var normalQuotedIdentifiers []string + var quoteAllQuotedIdentifiers []string + + // Regular expression to find quoted identifiers + quotedIdentifierRegex := `"([^"]+)"` + + for i, normalLine := range normalLines { + quoteAllLine := quoteAllLines[i] + + // Skip comment lines and empty lines + if strings.HasPrefix(strings.TrimSpace(normalLine), "--") || strings.TrimSpace(normalLine) == "" { + continue + } + + // Extract quoted identifiers from both outputs + normalMatches := regexp.MustCompile(quotedIdentifierRegex).FindAllStringSubmatch(normalLine, -1) + quoteAllMatches := regexp.MustCompile(quotedIdentifierRegex).FindAllStringSubmatch(quoteAllLine, -1) + + for _, match := range normalMatches { + normalQuotedIdentifiers = append(normalQuotedIdentifiers, match[1]) + } + + for _, match := range quoteAllMatches { + quoteAllQuotedIdentifiers = append(quoteAllQuotedIdentifiers, match[1]) + } + } + + // Validate expectations: + // 1. Quote-all mode should have more quoted identifiers than normal mode + if len(quoteAllQuotedIdentifiers) <= len(normalQuotedIdentifiers) { + t.Errorf("Quote-all mode should have more quoted identifiers. Normal: %d, QuoteAll: %d", + len(normalQuotedIdentifiers), len(quoteAllQuotedIdentifiers)) + } + + // 2. All identifiers that were quoted in normal mode should also be quoted in quote-all mode + normalQuotedSet := make(map[string]bool) + for _, id := range normalQuotedIdentifiers { + normalQuotedSet[id] = true + } + + quoteAllQuotedSet := make(map[string]bool) + for _, id := range quoteAllQuotedIdentifiers { + quoteAllQuotedSet[id] = true + } + + for identifier := range normalQuotedSet { + if !quoteAllQuotedSet[identifier] { + t.Errorf("Identifier '%s' was quoted in normal mode but not in quote-all mode", identifier) + } + } + + // 3. Verify specific expected behaviors + // Note: Currently only table and column names support quote-all. Other objects (indexes, sequences, views, functions) are not yet implemented + expectedNormalQuoted := []string{"order", "MixedCase", "ID", "FirstName", "LastName", "SpecialColumn", "Index_Order_Status", "MixedCase_pkey"} + expectedQuoteAllOnly := []string{"users", "id", "first_name", "last_name", "email", "created_at", "user_id", "total_amount", "status"} + + // Check that expected identifiers are quoted in normal mode + for _, identifier := range expectedNormalQuoted { + if !normalQuotedSet[identifier] { + t.Errorf("Expected identifier '%s' to be quoted in normal mode, but it wasn't", identifier) + } + } + + // Check that additional identifiers are quoted only in quote-all mode + for _, identifier := range expectedQuoteAllOnly { + if normalQuotedSet[identifier] { + t.Errorf("Identifier '%s' should not be quoted in normal mode", identifier) + } + if !quoteAllQuotedSet[identifier] { + t.Errorf("Identifier '%s' should be quoted in quote-all mode", identifier) + } + } + + // Write outputs to files for debugging if test fails + if t.Failed() { + normalFilename := fmt.Sprintf("%s_normal.sql", testName) + os.WriteFile(normalFilename, []byte(normalOutput), 0644) + + quoteAllFilename := fmt.Sprintf("%s_quote_all.sql", testName) + os.WriteFile(quoteAllFilename, []byte(quoteAllOutput), 0644) + + t.Logf("Outputs written to %s and %s for debugging", normalFilename, quoteAllFilename) + t.Logf("Normal quoted identifiers: %v", normalQuotedIdentifiers) + t.Logf("Quote-all quoted identifiers: %v", quoteAllQuotedIdentifiers) + } +} diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 2aea9d08..131e8e06 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -3,6 +3,7 @@ package plan import ( "context" "fmt" + "log" "os" "path/filepath" "strings" @@ -100,6 +101,17 @@ func runPlan(cmd *cobra.Command, args []string) error { } } + // Get quote-all flag from root command + var quoteAll bool + if cmd != nil { + q, err := cmd.Root().PersistentFlags().GetBool("quote-all") + if err == nil { + quoteAll = q + } else { + log.Printf("Failed to get quote-all flag: %v\n", err) + } + } + // Create plan configuration config := &PlanConfig{ Host: planHost, @@ -110,6 +122,7 @@ func runPlan(cmd *cobra.Command, args []string) error { Schema: planSchema, File: planFile, ApplicationName: "pgschema", + QuoteAll: quoteAll, // Plan database configuration PlanDBHost: planDBHost, PlanDBPort: planDBPort, @@ -157,6 +170,7 @@ type PlanConfig struct { Schema string File string ApplicationName string + QuoteAll bool // Plan database configuration (optional - for external database) PlanDBHost string PlanDBPort int @@ -285,7 +299,7 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (* } // Generate diff (current -> desired) using IR directly - diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, config.Schema) + diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, config.Schema, diff.QuoteAll(config.QuoteAll)) // Create plan from diffs with fingerprint migrationPlan := plan.NewPlanWithFingerprint(diffs, sourceFingerprint) diff --git a/cmd/root.go b/cmd/root.go index 5ea599d0..a6435e2d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -15,6 +15,7 @@ import ( ) var Debug bool +var QuoteAll bool var logger *slog.Logger // Build-time variables set via ldflags @@ -45,6 +46,7 @@ Use "pgschema [command] --help" for more information about a command.`, func init() { RootCmd.PersistentFlags().BoolVar(&Debug, "debug", false, "Enable debug logging") + RootCmd.PersistentFlags().BoolVar(&QuoteAll, "quote-all", false, "Quote all identifiers regardless of whether they are reserved words") RootCmd.CompletionOptions.DisableDefaultCmd = true RootCmd.AddCommand(dump.DumpCmd) RootCmd.AddCommand(plan.PlanCmd) @@ -78,6 +80,11 @@ func IsDebug() bool { return Debug } +// IsQuoteAll returns whether quote-all mode is enabled +func IsQuoteAll() bool { + return QuoteAll +} + // platform returns the OS/architecture combination func platform() string { return runtime.GOOS + "/" + runtime.GOARCH diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 98d0d0c9..eacaa1fc 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -243,6 +243,7 @@ type ddlDiff struct { addedSequences []*ir.Sequence droppedSequences []*ir.Sequence modifiedSequences []*sequenceDiff + quoteAll bool } // schemaDiff represents changes to a schema @@ -348,8 +349,63 @@ type rlsChange struct { Enabled bool // true to enable, false to disable } -// GenerateMigration compares two IR schemas and returns the SQL differences -func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { +// Option represents a configuration option for migration generation +type Option func(*options) + +// options holds configuration for migration generation +type options struct { + quoteAll bool +} + +// QuoteAll configures whether all identifiers should be quoted, regardless of whether +// they are PostgreSQL reserved words. When enabled, all table names, column names, +// and other identifiers will be quoted with double quotes. +// +// Example: +// - QuoteAll(false): CREATE TABLE users (id int, name text) +// - QuoteAll(true): CREATE TABLE "users" ("id" int, "name" text) +// +// This is useful for: +// - Ensuring consistent quoting across all DDL statements +// - Avoiding potential conflicts with future PostgreSQL reserved words +// - Maintaining compatibility with case-sensitive identifier requirements +func QuoteAll(enabled bool) Option { + return func(opts *options) { + opts.quoteAll = enabled + } +} + +// GenerateMigration compares two IR schemas and returns the SQL differences. +// It accepts optional configuration through the Option pattern. +// +// Parameters: +// - oldIR: The current/source schema state +// - newIR: The desired/target schema state +// - targetSchema: The schema name to use in generated DDL +// - opts: Optional configuration (e.g., QuoteAll(true)) +// +// Returns a slice of Diff objects representing the migration steps needed +// to transform oldIR into newIR. +// +// Example usage: +// + +// tandard migration +// +// +//diffs := GenerateMigration(oldIR, newIR, "public") +// +// +// Migration with all identifiers quoted +// +// diffs := GenerateMigration(oldIR, newIR, "public", QuoteAll(true)) +func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string, opts ...Option) []Diff { + // Parse options + config := &options{} + for _, opt := range opts { + opt(config) + } + diff := &ddlDiff{ addedSchemas: []*ir.Schema{}, droppedSchemas: []*ir.Schema{}, @@ -372,6 +428,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { addedSequences: []*ir.Sequence{}, droppedSequences: []*ir.Sequence{}, modifiedSequences: []*sequenceDiff{}, + quoteAll: config.quoteAll, } // Compare schemas first in deterministic order @@ -905,6 +962,11 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { return collector.diffs } +// quoteIdentifier quotes an identifier according to the quoteAll setting +func (d *ddlDiff) quoteIdentifier(identifier string) string { + return ir.QuoteIdentifierWithForce(identifier, d.quoteAll) +} + // collectMigrationSQL populates the collector with SQL statements for the diff // The collector must not be nil func (d *ddlDiff) collectMigrationSQL(targetSchema string, collector *diffCollector) { @@ -947,7 +1009,6 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto // This ensures we create functions before tables that use them in defaults/checks tablesWithoutFunctionDeps := []*ir.Table{} tablesWithFunctionDeps := []*ir.Table{} - for _, table := range d.addedTables { if tableReferencesNewFunction(table, newFunctionLookup) { tablesWithFunctionDeps = append(tablesWithFunctionDeps, table) @@ -957,7 +1018,7 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto } // Create tables WITHOUT function dependencies first (functions may reference these) - deferredPolicies1, deferredConstraints1 := generateCreateTablesSQL(tablesWithoutFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy) + deferredPolicies1, deferredConstraints1 := generateCreateTablesSQL(tablesWithoutFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy, d.quoteAll) // Add deferred foreign key constraints from first batch generateDeferredConstraintsSQL(deferredConstraints1, targetSchema, collector) @@ -969,7 +1030,7 @@ func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollecto generateCreateProceduresSQL(d.addedProcedures, targetSchema, collector) // Create tables WITH function dependencies (now that functions exist) - deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy) + deferredPolicies2, deferredConstraints2 := generateCreateTablesSQL(tablesWithFunctionDeps, targetSchema, collector, existingTables, shouldDeferPolicy, d.quoteAll) // Add deferred foreign key constraints from second batch generateDeferredConstraintsSQL(deferredConstraints2, targetSchema, collector) @@ -1064,6 +1125,16 @@ func qualifyEntityName(entitySchema, entityName, targetSchema string) string { return fmt.Sprintf("%s.%s", quotedSchema, quotedName) } +// qualifyEntityNameWithForce quotes and qualifies an entity name with optional force quoting +func qualifyEntityNameWithForce(entitySchema, entityName, targetSchema string, quoteAll bool) string { + quotedName := ir.QuoteIdentifierWithForce(entityName, quoteAll) + if entitySchema == targetSchema { + return quotedName + } + quotedSchema := ir.QuoteIdentifierWithForce(entitySchema, quoteAll) + return fmt.Sprintf("%s.%s", quotedSchema, quotedName) +} + // quoteString properly quotes a string for SQL, handling single quotes func quoteString(s string) string { // Escape single quotes by doubling them diff --git a/internal/diff/table.go b/internal/diff/table.go index c68f0a20..7e825aea 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -351,6 +351,7 @@ func generateCreateTablesSQL( collector *diffCollector, existingTables map[string]bool, shouldDeferPolicy func(*ir.RLSPolicy) bool, + quoteAll bool, ) ([]*ir.RLSPolicy, []*deferredConstraint) { var deferredPolicies []*ir.RLSPolicy var deferredConstraints []*deferredConstraint @@ -359,7 +360,7 @@ func generateCreateTablesSQL( // Process tables in the provided order (already topologically sorted) for _, table := range tables { // Create the table, deferring FK constraints that reference not-yet-created tables - sql, tableDeferred := generateTableSQL(table, targetSchema, createdTables, existingTables) + sql, tableDeferred := generateTableSQL(table, targetSchema, createdTables, existingTables, quoteAll) deferredConstraints = append(deferredConstraints, tableDeferred...) // Create context for this statement @@ -375,7 +376,7 @@ func generateCreateTablesSQL( // Add table comment if table.Comment != "" { - tableName := qualifyEntityName(table.Schema, table.Name, targetSchema) + tableName := qualifyEntityNameWithForce(table.Schema, table.Name, targetSchema, quoteAll) sql := fmt.Sprintf("COMMENT ON TABLE %s IS %s;", tableName, quoteString(table.Comment)) // Create context for this statement @@ -393,8 +394,8 @@ func generateCreateTablesSQL( // Add column comments for _, column := range table.Columns { if column.Comment != "" { - tableName := qualifyEntityName(table.Schema, table.Name, targetSchema) - sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS %s;", tableName, column.Name, quoteString(column.Comment)) + tableName := qualifyEntityNameWithForce(table.Schema, table.Name, targetSchema, quoteAll) + sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS %s;", tableName, ir.QuoteIdentifierWithForce(column.Name, quoteAll), quoteString(column.Comment)) // Create context for this statement context := &diffContext{ @@ -509,9 +510,9 @@ func generateDropTablesSQL(tables []*ir.Table, targetSchema string, collector *d } // generateTableSQL generates CREATE TABLE statement and returns any deferred FK constraints -func generateTableSQL(table *ir.Table, targetSchema string, createdTables map[string]bool, existingTables map[string]bool) (string, []*deferredConstraint) { +func generateTableSQL(table *ir.Table, targetSchema string, createdTables map[string]bool, existingTables map[string]bool, quoteAll bool) (string, []*deferredConstraint) { // Only include table name without schema if it's in the target schema - tableName := ir.QualifyEntityNameWithQuotes(table.Schema, table.Name, targetSchema) + tableName := ir.QualifyEntityNameWithQuotesAndForce(table.Schema, table.Name, targetSchema, quoteAll) var parts []string parts = append(parts, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (", tableName)) @@ -521,7 +522,7 @@ func generateTableSQL(table *ir.Table, targetSchema string, createdTables map[st for _, column := range table.Columns { // Build column definition with SERIAL detection var builder strings.Builder - writeColumnDefinitionToBuilder(&builder, table, column, targetSchema) + writeColumnDefinitionToBuilder(&builder, table, column, targetSchema, quoteAll) columnParts = append(columnParts, fmt.Sprintf(" %s", builder.String())) } @@ -1133,8 +1134,8 @@ func ensureCheckClauseParens(s string) string { // writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting // This is moved from ir/table.go to consolidate SQL generation in the diff module -func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string) { - builder.WriteString(ir.QuoteIdentifier(column.Name)) +func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string, quoteAll bool) { + builder.WriteString(ir.QuoteIdentifierWithForce(column.Name, quoteAll)) builder.WriteString(" ") // Data type - handle array types and precision/scale for appropriate types diff --git a/ir/quote.go b/ir/quote.go index 0c5ac60d..1ded98e7 100644 --- a/ir/quote.go +++ b/ir/quote.go @@ -131,6 +131,17 @@ func QuoteIdentifier(identifier string) string { return identifier } +// QuoteIdentifierWithForce adds quotes to an identifier based on forceQuote flag +func QuoteIdentifierWithForce(identifier string, forceQuote bool) string { + if identifier == "" { + return "" + } + if forceQuote || NeedsQuoting(identifier) { + return `"` + identifier + `"` + } + return identifier +} + // QualifyEntityNameWithQuotes returns the properly qualified and quoted entity name func QualifyEntityNameWithQuotes(entitySchema, entityName, targetSchema string) string { quotedName := QuoteIdentifier(entityName) @@ -141,4 +152,16 @@ func QualifyEntityNameWithQuotes(entitySchema, entityName, targetSchema string) quotedSchema := QuoteIdentifier(entitySchema) return quotedSchema + "." + quotedName +} + +// QualifyEntityNameWithQuotesAndForce returns the properly qualified and quoted entity name with forceQuote option +func QualifyEntityNameWithQuotesAndForce(entitySchema, entityName, targetSchema string, forceQuote bool) string { + quotedName := QuoteIdentifierWithForce(entityName, forceQuote) + + if entitySchema == targetSchema { + return quotedName + } + + quotedSchema := QuoteIdentifierWithForce(entitySchema, forceQuote) + return quotedSchema + "." + quotedName } \ No newline at end of file diff --git a/ir/quote_test.go b/ir/quote_test.go index a64dfd23..65ab1d05 100644 --- a/ir/quote_test.go +++ b/ir/quote_test.go @@ -84,4 +84,58 @@ func TestQualifyEntityNameWithQuotes(t *testing.T) { } }) } -} \ No newline at end of file +} + +func TestQuoteIdentifierWithForce(t *testing.T) { + tests := []struct { + name string + identifier string + forceQuote bool + expected string + }{ + {"simple without force", "users", false, "users"}, + {"simple with force", "users", true, `"users"`}, + {"reserved word without force", "user", false, `"user"`}, + {"reserved word with force", "user", true, `"user"`}, + {"camelCase without force", "firstName", false, `"firstName"`}, + {"camelCase with force", "firstName", true, `"firstName"`}, + {"empty string without force", "", false, ""}, + {"empty string with force", "", true, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := QuoteIdentifierWithForce(tt.identifier, tt.forceQuote) + if result != tt.expected { + t.Errorf("QuoteIdentifierWithForce(%q, %v) = %q; want %q", tt.identifier, tt.forceQuote, result, tt.expected) + } + }) + } +} + +func TestQualifyEntityNameWithQuotesAndForce(t *testing.T) { + tests := []struct { + name string + entitySchema string + entityName string + targetSchema string + forceQuote bool + expected string + }{ + {"same schema without force", "public", "users", "public", false, "users"}, + {"same schema with force", "public", "users", "public", true, `"users"`}, + {"different schema without force", "tenant", "users", "public", false, "tenant.users"}, + {"different schema with force", "tenant", "users", "public", true, `"tenant"."users"`}, + {"reserved word schema with force", "user", "table", "public", true, `"user"."table"`}, + {"mixed case schema with force", "MyApp", "Orders", "public", true, `"MyApp"."Orders"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := QualifyEntityNameWithQuotesAndForce(tt.entitySchema, tt.entityName, tt.targetSchema, tt.forceQuote) + if result != tt.expected { + t.Errorf("QualifyEntityNameWithQuotesAndForce(%q, %q, %q, %v) = %q; want %q", tt.entitySchema, tt.entityName, tt.targetSchema, tt.forceQuote, result, tt.expected) + } + }) + } +} diff --git a/testdata/dump/quote_all_test/pgdump.sql b/testdata/dump/quote_all_test/pgdump.sql new file mode 100644 index 00000000..b4cbfad6 --- /dev/null +++ b/testdata/dump/quote_all_test/pgdump.sql @@ -0,0 +1,130 @@ +-- +-- PostgreSQL database dump for quote-all testing +-- + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +-- +-- Test table with normal identifiers (should not require quoting without --quote-all) +-- +CREATE TABLE public.users ( + id integer NOT NULL, + first_name text NOT NULL, + last_name text NOT NULL, + email text, + created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP +); + +-- +-- Test table with reserved word identifier (should always require quoting) +-- +CREATE TABLE public."order" ( + id integer NOT NULL, + user_id integer NOT NULL, + total_amount numeric(10,2), + status text DEFAULT 'pending' +); + +-- +-- Test table with mixed case identifiers (should require quoting) +-- +CREATE TABLE public."MixedCase" ( + "ID" integer NOT NULL, + "FirstName" text, + "LastName" text, + "SpecialColumn" text +); + +-- +-- Test sequence with normal name +-- +CREATE SEQUENCE public.users_id_seq + AS integer + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; + +-- +-- Test sequence with mixed case name +-- +CREATE SEQUENCE public."OrderSeq" + AS integer + START WITH 1000 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; + +-- +-- Test index with normal name +-- +CREATE INDEX idx_users_email ON public.users USING btree (email); + +-- +-- Test index with reserved word and mixed case +-- +CREATE INDEX "Index_Order_Status" ON public."order" USING btree (status); + +-- +-- Test view with normal name +-- +CREATE VIEW public.user_orders AS + SELECT u.id, + u.first_name, + u.last_name, + o.total_amount, + o.status + FROM (public.users u + LEFT JOIN public."order" o ON ((u.id = o.user_id))); + +-- +-- Test function with normal name +-- +CREATE FUNCTION public.get_user_count() RETURNS integer + LANGUAGE sql + AS $$ + SELECT COUNT(*) FROM users; +$$; + +-- +-- Add constraints +-- +ALTER TABLE ONLY public.users + ADD CONSTRAINT users_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY public."order" + ADD CONSTRAINT order_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY public."MixedCase" + ADD CONSTRAINT "MixedCase_pkey" PRIMARY KEY ("ID"); + +-- +-- Add foreign key constraint +-- +ALTER TABLE ONLY public."order" + ADD CONSTRAINT order_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id); + +-- +-- Set sequence ownership +-- +ALTER SEQUENCE public.users_id_seq OWNED BY public.users.id; +ALTER SEQUENCE public."OrderSeq" OWNED BY public."order".id; + +-- +-- Comments on objects +-- +COMMENT ON TABLE public.users IS 'Table storing user information'; +COMMENT ON COLUMN public.users.first_name IS 'User first name'; +COMMENT ON TABLE public."order" IS 'Table storing order information'; +COMMENT ON COLUMN public."MixedCase"."FirstName" IS 'Mixed case column comment'; \ No newline at end of file From c93b1120d585b86c71262e03dd472067aa369dc3 Mon Sep 17 00:00:00 2001 From: Adam Mustafa Date: Thu, 27 Nov 2025 14:37:56 -0500 Subject: [PATCH 2/2] update comments --- cmd/dump/dump.go | 4 ++-- cmd/plan/plan.go | 4 ++-- internal/diff/diff.go | 2 +- internal/diff/table.go | 28 +++++++++++++++++++++++++--- ir/quote.go | 21 +++++++++++++++++++-- 5 files changed, 49 insertions(+), 10 deletions(-) diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index 1971e86c..8ba273a2 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -2,7 +2,7 @@ package dump import ( "fmt" - "log" + "log/slog" "os" "github.com/pgschema/pgschema/cmd/util" @@ -116,7 +116,7 @@ func runDump(cmd *cobra.Command, args []string) error { if err == nil { quoteAll = q } else { - log.Printf("Failed to get quote-all flag: %v\n", err) + slog.Warn("Failed to get quote-all flag", "error", err) } } diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 131e8e06..9fea6730 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -3,7 +3,7 @@ package plan import ( "context" "fmt" - "log" + "log/slog" "os" "path/filepath" "strings" @@ -108,7 +108,7 @@ func runPlan(cmd *cobra.Command, args []string) error { if err == nil { quoteAll = q } else { - log.Printf("Failed to get quote-all flag: %v\n", err) + slog.Warn("Failed to get quote-all flag", "error", err) } } diff --git a/internal/diff/diff.go b/internal/diff/diff.go index eacaa1fc..0143ea2a 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -390,7 +390,7 @@ func QuoteAll(enabled bool) Option { // Example usage: // -// tandard migration +// Standard migration // // //diffs := GenerateMigration(oldIR, newIR, "public") diff --git a/internal/diff/table.go b/internal/diff/table.go index 7e825aea..57d40c7d 100644 --- a/internal/diff/table.go +++ b/internal/diff/table.go @@ -345,6 +345,14 @@ type deferredConstraint struct { // dependent functions/procedures have been created, while all other policies are emitted inline. // It returns deferred policies and foreign key constraints that should be applied after dependent objects exist. // Tables are assumed to be pre-sorted in topological order for dependency-aware creation. +// +// Parameters: +// - tables: Tables to create, assumed to be in topological order +// - targetSchema: The schema name to use in generated DDL +// - collector: Diff collector to accumulate generated SQL statements +// - existingTables: Map of table names that already exist +// - shouldDeferPolicy: Function to determine if a policy should be deferred +// - quoteAll: Whether to quote all identifiers regardless of whether they are reserved words func generateCreateTablesSQL( tables []*ir.Table, targetSchema string, @@ -509,7 +517,14 @@ func generateDropTablesSQL(tables []*ir.Table, targetSchema string, collector *d } } -// generateTableSQL generates CREATE TABLE statement and returns any deferred FK constraints +// generateTableSQL generates CREATE TABLE statement and returns any deferred FK constraints. +// +// Parameters: +// - table: The table definition to create +// - targetSchema: The schema name to use in generated DDL +// - createdTables: Map of table names that have been created in this migration +// - existingTables: Map of table names that already exist in the database +// - quoteAll: Whether to quote all identifiers regardless of whether they are reserved words func generateTableSQL(table *ir.Table, targetSchema string, createdTables map[string]bool, existingTables map[string]bool, quoteAll bool) (string, []*deferredConstraint) { // Only include table name without schema if it's in the target schema tableName := ir.QualifyEntityNameWithQuotesAndForce(table.Schema, table.Name, targetSchema, quoteAll) @@ -1132,8 +1147,15 @@ func ensureCheckClauseParens(s string) string { return "CHECK (" + expr + ")" } -// writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting -// This is moved from ir/table.go to consolidate SQL generation in the diff module +// writeColumnDefinitionToBuilder builds column definitions with SERIAL detection and proper formatting. +// This is moved from ir/table.go to consolidate SQL generation in the diff module. +// +// Parameters: +// - builder: String builder to write the column definition to +// - table: The table containing this column (used for SERIAL detection) +// - column: The column definition to format +// - targetSchema: The schema name to use for type qualification +// - quoteAll: Whether to quote all identifiers regardless of whether they are reserved words func writeColumnDefinitionToBuilder(builder *strings.Builder, table *ir.Table, column *ir.Column, targetSchema string, quoteAll bool) { builder.WriteString(ir.QuoteIdentifierWithForce(column.Name, quoteAll)) builder.WriteString(" ") diff --git a/ir/quote.go b/ir/quote.go index 1ded98e7..b3cd313d 100644 --- a/ir/quote.go +++ b/ir/quote.go @@ -131,7 +131,15 @@ func QuoteIdentifier(identifier string) string { return identifier } -// QuoteIdentifierWithForce adds quotes to an identifier based on forceQuote flag +// QuoteIdentifierWithForce adds quotes to an identifier based on forceQuote flag. +// When forceQuote is true, all identifiers are quoted regardless of whether they need it. +// When forceQuote is false, only identifiers that require quoting (reserved words, mixed case, etc.) are quoted. +// +// Parameters: +// - identifier: The identifier to potentially quote +// - forceQuote: Whether to force quoting of all identifiers, regardless of whether they are reserved words +// +// Returns the identifier with quotes added if necessary or forced. func QuoteIdentifierWithForce(identifier string, forceQuote bool) string { if identifier == "" { return "" @@ -154,7 +162,16 @@ func QualifyEntityNameWithQuotes(entitySchema, entityName, targetSchema string) return quotedSchema + "." + quotedName } -// QualifyEntityNameWithQuotesAndForce returns the properly qualified and quoted entity name with forceQuote option +// QualifyEntityNameWithQuotesAndForce returns the properly qualified and quoted entity name with forceQuote option. +// This function combines schema qualification logic with optional forced quoting of all identifiers. +// +// Parameters: +// - entitySchema: The schema of the entity +// - entityName: The name of the entity +// - targetSchema: The target schema for qualification (if same as entitySchema, schema prefix is omitted) +// - forceQuote: Whether to force quoting of all identifiers, regardless of whether they are reserved words +// +// Returns a properly qualified and quoted entity name (e.g., "schema"."table" or just "table" if in target schema). func QualifyEntityNameWithQuotesAndForce(entitySchema, entityName, targetSchema string, forceQuote bool) string { quotedName := QuoteIdentifierWithForce(entityName, forceQuote)