Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion cmd/dump/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dump

import (
"fmt"
"log/slog"
"os"

"github.com/pgschema/pgschema/cmd/util"
Expand Down Expand Up @@ -32,6 +33,7 @@ type DumpConfig struct {
Schema string
MultiFile bool
File string
QuoteAll bool
}

var DumpCmd = &cobra.Command{
Expand Down Expand Up @@ -79,7 +81,7 @@ func ExecuteDump(config *DumpConfig) (string, error) {
emptyIR := ir.NewIR()

// Generate diff between empty schema and target schema (this represents a complete dump)
diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema)
diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema, diff.QuoteAll(config.QuoteAll))

// Create dump formatter
formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, config.Schema)
Expand Down Expand Up @@ -107,6 +109,17 @@ func runDump(cmd *cobra.Command, args []string) error {
}
}

// Get quote-all flag from root command
var quoteAll bool
if cmd != nil {
q, err := cmd.Root().PersistentFlags().GetBool("quote-all")
if err == nil {
quoteAll = q
} else {
slog.Warn("Failed to get quote-all flag", "error", err)
}
}

// Create config from command-line flags
config := &DumpConfig{
Host: host,
Expand All @@ -117,6 +130,7 @@ func runDump(cmd *cobra.Command, args []string) error {
Schema: schema,
MultiFile: multiFile,
File: file,
QuoteAll: quoteAll,
}

// Execute dump
Expand Down
181 changes: 181 additions & 0 deletions cmd/dump/dump_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"context"
"fmt"
"os"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -303,3 +304,183 @@ func compareSchemaOutputs(t *testing.T, actualOutput, expectedOutput string, tes
}
}
}

// TestDumpCommand_QuoteAll validates the --quote-all flag behavior
func TestDumpCommand_QuoteAll(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}

runQuoteAllTest(t, "quote_all_test")
}

// runQuoteAllTest validates that the --quote-all flag correctly quotes all identifiers
func runQuoteAllTest(t *testing.T, testDataDir string) {
// Setup PostgreSQL
embeddedPG := testutil.SetupPostgres(t)
defer embeddedPG.Stop()

// Connect to database
conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG)
defer conn.Close()

// Detect PostgreSQL version and skip tests if needed
majorVersion, err := testutil.GetMajorVersion(conn)
if err != nil {
t.Fatalf("Failed to detect PostgreSQL version: %v", err)
}

// Check if this test should be skipped for this PostgreSQL version
testutil.ShouldSkipTest(t, t.Name(), majorVersion)

// Read and execute the pgdump.sql file
pgdumpPath := fmt.Sprintf("../../testdata/dump/%s/pgdump.sql", testDataDir)
pgdumpContent, err := os.ReadFile(pgdumpPath)
if err != nil {
t.Fatalf("Failed to read %s: %v", pgdumpPath, err)
}

// Execute the SQL to create the schema
_, err = conn.ExecContext(context.Background(), string(pgdumpContent))
if err != nil {
t.Fatalf("Failed to execute pgdump.sql: %v", err)
}

// Test 1: Dump without --quote-all (normal behavior)
configNormal := &DumpConfig{
Host: host,
Port: port,
DB: dbname,
User: user,
Password: password,
Schema: "public",
MultiFile: false,
File: "",
QuoteAll: false,
}

normalOutput, err := ExecuteDump(configNormal)
if err != nil {
t.Fatalf("Dump command failed without quote-all: %v", err)
}

// Test 2: Dump with --quote-all (all identifiers quoted)
configQuoteAll := &DumpConfig{
Host: host,
Port: port,
DB: dbname,
User: user,
Password: password,
Schema: "public",
MultiFile: false,
File: "",
QuoteAll: true,
}

quoteAllOutput, err := ExecuteDump(configQuoteAll)
if err != nil {
t.Fatalf("Dump command failed with quote-all: %v", err)
}

// Validate quote-all behavior
validateQuoteAllBehavior(t, normalOutput, quoteAllOutput, testDataDir)
}

// validateQuoteAllBehavior verifies that --quote-all produces correctly quoted output
func validateQuoteAllBehavior(t *testing.T, normalOutput, quoteAllOutput, testName string) {
// Split outputs into lines for analysis
normalLines := strings.Split(normalOutput, "\n")
quoteAllLines := strings.Split(quoteAllOutput, "\n")

// Both outputs should have the same number of lines
if len(normalLines) != len(quoteAllLines) {
t.Fatalf("Different number of lines - Normal: %d, QuoteAll: %d", len(normalLines), len(quoteAllLines))
}

// Track identifiers that should be quoted in normal mode vs quote-all mode
var normalQuotedIdentifiers []string
var quoteAllQuotedIdentifiers []string

// Regular expression to find quoted identifiers
quotedIdentifierRegex := `"([^"]+)"`

for i, normalLine := range normalLines {
quoteAllLine := quoteAllLines[i]

// Skip comment lines and empty lines
if strings.HasPrefix(strings.TrimSpace(normalLine), "--") || strings.TrimSpace(normalLine) == "" {
continue
}

// Extract quoted identifiers from both outputs
normalMatches := regexp.MustCompile(quotedIdentifierRegex).FindAllStringSubmatch(normalLine, -1)
quoteAllMatches := regexp.MustCompile(quotedIdentifierRegex).FindAllStringSubmatch(quoteAllLine, -1)

for _, match := range normalMatches {
normalQuotedIdentifiers = append(normalQuotedIdentifiers, match[1])
}

for _, match := range quoteAllMatches {
quoteAllQuotedIdentifiers = append(quoteAllQuotedIdentifiers, match[1])
}
}

// Validate expectations:
// 1. Quote-all mode should have more quoted identifiers than normal mode
if len(quoteAllQuotedIdentifiers) <= len(normalQuotedIdentifiers) {
t.Errorf("Quote-all mode should have more quoted identifiers. Normal: %d, QuoteAll: %d",
len(normalQuotedIdentifiers), len(quoteAllQuotedIdentifiers))
}

// 2. All identifiers that were quoted in normal mode should also be quoted in quote-all mode
normalQuotedSet := make(map[string]bool)
for _, id := range normalQuotedIdentifiers {
normalQuotedSet[id] = true
}

quoteAllQuotedSet := make(map[string]bool)
for _, id := range quoteAllQuotedIdentifiers {
quoteAllQuotedSet[id] = true
}

for identifier := range normalQuotedSet {
if !quoteAllQuotedSet[identifier] {
t.Errorf("Identifier '%s' was quoted in normal mode but not in quote-all mode", identifier)
}
}

// 3. Verify specific expected behaviors
// Note: Currently only table and column names support quote-all. Other objects (indexes, sequences, views, functions) are not yet implemented
expectedNormalQuoted := []string{"order", "MixedCase", "ID", "FirstName", "LastName", "SpecialColumn", "Index_Order_Status", "MixedCase_pkey"}
expectedQuoteAllOnly := []string{"users", "id", "first_name", "last_name", "email", "created_at", "user_id", "total_amount", "status"}

// Check that expected identifiers are quoted in normal mode
for _, identifier := range expectedNormalQuoted {
if !normalQuotedSet[identifier] {
t.Errorf("Expected identifier '%s' to be quoted in normal mode, but it wasn't", identifier)
}
}

// Check that additional identifiers are quoted only in quote-all mode
for _, identifier := range expectedQuoteAllOnly {
if normalQuotedSet[identifier] {
t.Errorf("Identifier '%s' should not be quoted in normal mode", identifier)
}
if !quoteAllQuotedSet[identifier] {
t.Errorf("Identifier '%s' should be quoted in quote-all mode", identifier)
}
}

// Write outputs to files for debugging if test fails
if t.Failed() {
normalFilename := fmt.Sprintf("%s_normal.sql", testName)
os.WriteFile(normalFilename, []byte(normalOutput), 0644)

quoteAllFilename := fmt.Sprintf("%s_quote_all.sql", testName)
os.WriteFile(quoteAllFilename, []byte(quoteAllOutput), 0644)

t.Logf("Outputs written to %s and %s for debugging", normalFilename, quoteAllFilename)
t.Logf("Normal quoted identifiers: %v", normalQuotedIdentifiers)
t.Logf("Quote-all quoted identifiers: %v", quoteAllQuotedIdentifiers)
}
}
16 changes: 15 additions & 1 deletion cmd/plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package plan
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -100,6 +101,17 @@ func runPlan(cmd *cobra.Command, args []string) error {
}
}

// Get quote-all flag from root command
var quoteAll bool
if cmd != nil {
q, err := cmd.Root().PersistentFlags().GetBool("quote-all")
if err == nil {
quoteAll = q
} else {
slog.Warn("Failed to get quote-all flag", "error", err)
}
}

// Create plan configuration
config := &PlanConfig{
Host: planHost,
Expand All @@ -110,6 +122,7 @@ func runPlan(cmd *cobra.Command, args []string) error {
Schema: planSchema,
File: planFile,
ApplicationName: "pgschema",
QuoteAll: quoteAll,
// Plan database configuration
PlanDBHost: planDBHost,
PlanDBPort: planDBPort,
Expand Down Expand Up @@ -157,6 +170,7 @@ type PlanConfig struct {
Schema string
File string
ApplicationName string
QuoteAll bool
// Plan database configuration (optional - for external database)
PlanDBHost string
PlanDBPort int
Expand Down Expand Up @@ -285,7 +299,7 @@ func GeneratePlan(config *PlanConfig, provider postgres.DesiredStateProvider) (*
}

// Generate diff (current -> desired) using IR directly
diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, config.Schema)
diffs := diff.GenerateMigration(currentStateIR, desiredStateIR, config.Schema, diff.QuoteAll(config.QuoteAll))

// Create plan from diffs with fingerprint
migrationPlan := plan.NewPlanWithFingerprint(diffs, sourceFingerprint)
Expand Down
7 changes: 7 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

var Debug bool
var QuoteAll bool
var logger *slog.Logger

// Build-time variables set via ldflags
Expand Down Expand Up @@ -45,6 +46,7 @@ Use "pgschema [command] --help" for more information about a command.`,

func init() {
RootCmd.PersistentFlags().BoolVar(&Debug, "debug", false, "Enable debug logging")
RootCmd.PersistentFlags().BoolVar(&QuoteAll, "quote-all", false, "Quote all identifiers regardless of whether they are reserved words")
RootCmd.CompletionOptions.DisableDefaultCmd = true
RootCmd.AddCommand(dump.DumpCmd)
RootCmd.AddCommand(plan.PlanCmd)
Expand Down Expand Up @@ -78,6 +80,11 @@ func IsDebug() bool {
return Debug
}

// IsQuoteAll returns whether quote-all mode is enabled
func IsQuoteAll() bool {
return QuoteAll
}

// platform returns the OS/architecture combination
func platform() string {
return runtime.GOOS + "/" + runtime.GOARCH
Expand Down
Loading