diff --git a/cmd/apply/apply.go b/cmd/apply/apply.go index 8bca60cc..6ac8915c 100644 --- a/cmd/apply/apply.go +++ b/cmd/apply/apply.go @@ -12,6 +12,7 @@ import ( "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/fingerprint" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/internal/version" "github.com/pgschema/pgschema/ir" "github.com/spf13/cobra" @@ -89,7 +90,7 @@ type ApplyConfig struct { // // If config.File is provided, embeddedPG is used to generate the plan. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func ApplyMigration(config *ApplyConfig, embeddedPG *util.EmbeddedPostgres) error { +func ApplyMigration(config *ApplyConfig, embeddedPG *postgres.EmbeddedPostgres) error { var migrationPlan *plan.Plan var err error @@ -253,7 +254,7 @@ func RunApply(cmd *cobra.Command, args []string) error { ApplicationName: applyApplicationName, } - var embeddedPG *util.EmbeddedPostgres + var embeddedPG *postgres.EmbeddedPostgres var err error // If using --plan flag, load plan from JSON file diff --git a/cmd/apply/apply_integration_test.go b/cmd/apply/apply_integration_test.go index 780d52d3..aebb2804 100644 --- a/cmd/apply/apply_integration_test.go +++ b/cmd/apply/apply_integration_test.go @@ -2,29 +2,29 @@ package apply import ( "context" + "database/sql" "os" "path/filepath" "strings" "testing" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" planCmd "github.com/pgschema/pgschema/cmd/plan" - "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/testutil" ) var ( // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *util.EmbeddedPostgres + sharedEmbeddedPG *postgres.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance by reusing the same instance - sharedEmbeddedPG = util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupPostgres(nil) defer sharedEmbeddedPG.Stop() // Run tests @@ -62,11 +62,29 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -138,9 +156,9 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -197,9 +215,9 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan with injected failure AutoApprove: true, @@ -319,11 +337,29 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -375,9 +411,9 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -414,9 +450,9 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, @@ -520,11 +556,29 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -569,9 +623,9 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -586,9 +640,9 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, @@ -689,11 +743,29 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -744,9 +816,9 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -798,9 +870,9 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan with old fingerprint AutoApprove: true, @@ -880,11 +952,29 @@ func TestApplyCommand_WaitDirective(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema and data - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -931,9 +1021,9 @@ func TestApplyCommand_WaitDirective(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: container.Host, Port: container.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -948,9 +1038,9 @@ func TestApplyCommand_WaitDirective(t *testing.T) { applyConfig := &ApplyConfig{ Host: container.Host, Port: container.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index fb237525..dbce9052 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -1,7 +1,6 @@ package dump import ( - "context" "fmt" "os" @@ -23,6 +22,17 @@ var ( file string ) +// DumpConfig holds configuration for dump execution +type DumpConfig struct { + Host string + Port int + DB string + User string + Password string + Schema string + MultiFile bool + File string +} var DumpCmd = &cobra.Command{ Use: "dump", @@ -30,7 +40,7 @@ var DumpCmd = &cobra.Command{ Long: "Dump and output database schema information for a specific schema. Uses the --schema flag to target a particular schema (defaults to 'public').", RunE: runDump, SilenceUsage: true, - PreRunE: util.PreRunEWithEnvVarsAndConnection(&db, &user, &host, &port), + PreRunE: util.PreRunEWithEnvVarsAndConnection(&db, &user, &host, &port), } func init() { @@ -44,72 +54,79 @@ func init() { DumpCmd.Flags().StringVar(&file, "file", "", "Output file path (required when --multi-file is used)") } -func runDump(cmd *cobra.Command, args []string) error { +// ExecuteDump executes the dump operation with the given configuration +func ExecuteDump(config *DumpConfig) (string, error) { // Validate flags - if multiFile && file == "" { + if config.MultiFile && config.File == "" { // When --multi-file is used but no --file specified, emit warning and use single-file mode fmt.Fprintf(os.Stderr, "Warning: --multi-file flag requires --file to be specified. Fallback to single-file mode.\n") - multiFile = false - } - - // Derive final password: use flag if provided, otherwise check environment variable - finalPassword := password - if finalPassword == "" { - if envPassword := os.Getenv("PGPASSWORD"); envPassword != "" { - finalPassword = envPassword - } - } - - // Build database connection - config := &util.ConnectionConfig{ - Host: host, - Port: port, - Database: db, - User: user, - Password: finalPassword, - SSLMode: "prefer", - ApplicationName: "pgschema", + config.MultiFile = false } - dbConn, err := util.Connect(config) - if err != nil { - return err - } - defer dbConn.Close() - - ctx := context.Background() - // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { - return fmt.Errorf("failed to load .pgschemaignore: %w", err) + return "", fmt.Errorf("failed to load .pgschemaignore: %w", err) } - // Build IR using the IR system - inspector := ir.NewInspector(dbConn, ignoreConfig) - schemaIR, err := inspector.BuildIR(ctx, schema) + // Get IR from database using the shared utility + schemaIR, err := util.GetIRFromDatabase(config.Host, config.Port, config.DB, config.User, config.Password, config.Schema, "pgschema", ignoreConfig) if err != nil { - return fmt.Errorf("failed to build IR: %w", err) + return "", fmt.Errorf("failed to get database schema: %w", err) } // Create an empty schema for comparison to generate a dump diff emptyIR := ir.NewIR() // Generate diff between empty schema and target schema (this represents a complete dump) - diffs := diff.GenerateMigration(emptyIR, schemaIR, schema) + diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema) // Create dump formatter - formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, schema) + formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, config.Schema) - if multiFile { + if config.MultiFile { // Multi-file mode - output to files - err := formatter.FormatMultiFile(diffs, file) + err := formatter.FormatMultiFile(diffs, config.File) if err != nil { - return fmt.Errorf("failed to create multi-file output: %w", err) + return "", fmt.Errorf("failed to create multi-file output: %w", err) } + return "", nil } else { - // Single file mode - output to stdout + // Single file mode - return output as string output := formatter.FormatSingleFile(diffs) + return output, nil + } +} + +func runDump(cmd *cobra.Command, args []string) error { + // Derive final password: use flag if provided, otherwise check environment variable + finalPassword := password + if finalPassword == "" { + if envPassword := os.Getenv("PGPASSWORD"); envPassword != "" { + finalPassword = envPassword + } + } + + // Create config from command-line flags + config := &DumpConfig{ + Host: host, + Port: port, + DB: db, + User: user, + Password: finalPassword, + Schema: schema, + MultiFile: multiFile, + File: file, + } + + // Execute dump + output, err := ExecuteDump(config) + if err != nil { + return err + } + + // Print output to stdout (only in single-file mode) + if output != "" { fmt.Print(output) } diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index 7b2b72f5..3207b612 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -10,7 +10,6 @@ package dump import ( "context" "fmt" - "io" "os" "strings" "testing" @@ -80,9 +79,13 @@ func runExactMatchTest(t *testing.T, testDataDir string) { } func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir string) { - // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainer(ctx, t) - defer containerInfo.Terminate(ctx, t) + // Setup PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + + // Connect to database + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Read and execute the pgdump.sql file pgdumpPath := fmt.Sprintf("../../testdata/dump/%s/pgdump.sql", testDataDir) @@ -92,36 +95,28 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir } // Execute the SQL to create the schema - _, err = containerInfo.Conn.ExecContext(ctx, string(pgdumpContent)) + _, err = conn.ExecContext(ctx, string(pgdumpContent)) if err != nil { t.Fatalf("Failed to execute pgdump.sql: %v", err) } - // Store original connection parameters and restore them later - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, + // Create dump configuration + config := &DumpConfig{ + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, + Schema: "public", + MultiFile: false, + File: "", + } + + // Execute pgschema dump + actualOutput, err := ExecuteDump(config) + if err != nil { + t.Fatalf("Dump command failed: %v", err) } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - - // Configure connection parameters - host = containerInfo.Host - port = containerInfo.Port - db = "testdb" - user = "testuser" - testutil.SetEnvPassword("testpass") - - // Execute pgschema dump and capture output - actualOutput := executePgSchemaDump(t, "") // Read expected output expectedPath := fmt.Sprintf("../../testdata/dump/%s/pgschema.sql", testDataDir) @@ -136,11 +131,13 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir } func runTenantSchemaTest(t *testing.T, testDataDir string) { - ctx := context.Background() + // Setup PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() - // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainer(ctx, t) - defer containerInfo.Terminate(ctx, t) + // Connect to database + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Load public schema types first publicSQL, err := os.ReadFile(fmt.Sprintf("../../testdata/dump/%s/public.sql", testDataDir)) @@ -148,7 +145,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { t.Fatalf("Failed to read public.sql: %v", err) } - _, err = containerInfo.Conn.Exec(string(publicSQL)) + _, err = conn.Exec(string(publicSQL)) if err != nil { t.Fatalf("Failed to load public types: %v", err) } @@ -156,7 +153,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { // Load utility functions (if util.sql exists) utilPath := fmt.Sprintf("../../testdata/dump/%s/util.sql", testDataDir) if utilSQL, err := os.ReadFile(utilPath); err == nil { - _, err = containerInfo.Conn.Exec(string(utilSQL)) + _, err = conn.Exec(string(utilSQL)) if err != nil { t.Fatalf("Failed to load utility functions from util.sql: %v", err) } @@ -167,7 +164,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { // Create two tenant schemas tenants := []string{"tenant1", "tenant2"} for _, tenant := range tenants { - _, err = containerInfo.Conn.Exec(fmt.Sprintf("CREATE SCHEMA %s", tenant)) + _, err = conn.Exec(fmt.Sprintf("CREATE SCHEMA %s", tenant)) if err != nil { t.Fatalf("Failed to create schema %s: %v", tenant, err) } @@ -183,47 +180,38 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { for _, tenant := range tenants { // Set search path to include public for the types, but target schema first quotedTenant := ir.QuoteIdentifier(tenant) - _, err = containerInfo.Conn.Exec(fmt.Sprintf("SET search_path TO %s, public", quotedTenant)) + _, err = conn.Exec(fmt.Sprintf("SET search_path TO %s, public", quotedTenant)) if err != nil { t.Fatalf("Failed to set search path to %s: %v", tenant, err) } // Execute the SQL - _, err = containerInfo.Conn.Exec(string(tenantSQL)) + _, err = conn.Exec(string(tenantSQL)) if err != nil { t.Fatalf("Failed to load SQL into schema %s: %v", tenant, err) } } - // Save original command variables - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, - } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - // Dump both tenant schemas using pgschema dump command var dumps []string for _, tenantName := range tenants { - // Set connection parameters for this specific tenant dump - host = containerInfo.Host - port = containerInfo.Port - db = "testdb" - user = "testuser" - testutil.SetEnvPassword("testpass") - schema = tenantName - - // Execute pgschema dump and capture output - actualOutput := executePgSchemaDump(t, fmt.Sprintf("tenant %s", tenantName)) + // Create dump configuration for this tenant + config := &DumpConfig{ + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, + Schema: tenantName, + MultiFile: false, + File: "", + } + + // Execute pgschema dump + actualOutput, err := ExecuteDump(config) + if err != nil { + t.Fatalf("Dump command failed for tenant %s: %v", tenantName, err) + } dumps = append(dumps, actualOutput) } @@ -248,56 +236,6 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { } } -func executePgSchemaDump(t *testing.T, contextInfo string) string { - // Capture output by redirecting stdout - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Read from pipe in a goroutine to avoid deadlock - var actualOutput string - var readErr error - done := make(chan bool) - - go func() { - defer close(done) - output, err := io.ReadAll(r) - if err != nil { - readErr = err - return - } - actualOutput = string(output) - }() - - // Run the dump command - // Logger setup handled by root command - err := runDump(nil, nil) - - // Close write end and restore stdout - w.Close() - os.Stdout = originalStdout - - if err != nil { - if contextInfo != "" { - t.Fatalf("Dump command failed for %s: %v", contextInfo, err) - } else { - t.Fatalf("Dump command failed: %v", err) - } - } - - // Wait for reading to complete - <-done - if readErr != nil { - if contextInfo != "" { - t.Fatalf("Failed to read captured output for %s: %v", contextInfo, readErr) - } else { - t.Fatalf("Failed to read captured output: %v", readErr) - } - } - - return actualOutput -} - // normalizeSchemaOutput removes version-specific lines for comparison func normalizeSchemaOutput(output string) string { lines := strings.Split(output, "\n") diff --git a/cmd/dump/dump_permission_integration_test.go b/cmd/dump/dump_permission_integration_test.go index bf099e2a..73acddc8 100644 --- a/cmd/dump/dump_permission_integration_test.go +++ b/cmd/dump/dump_permission_integration_test.go @@ -12,7 +12,6 @@ import ( "context" "database/sql" "fmt" - "io" "os" "strings" "testing" @@ -31,8 +30,27 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { ctx := context.Background() // Start single PostgreSQL container for all permission tests - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "postgres", "testpwd") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Run each permission test with its own isolated database t.Run("ProcedureAndFunctionSourceAccess", func(t *testing.T) { @@ -45,7 +63,14 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { } // setupTestDatabase creates a new database with permission test roles -func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.ContainerInfo, dbName string) *sql.DB { +func setupTestDatabase(ctx context.Context, t *testing.T, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) *sql.DB { // Create the database _, err := container.Conn.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", dbName)) if err != nil { @@ -74,8 +99,8 @@ func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.Co Host: container.Host, Port: container.Port, Database: dbName, - User: "postgres", - Password: "testpwd", + User: container.User, + Password: container.Password, SSLMode: "prefer", ApplicationName: "pgschema", } @@ -102,7 +127,14 @@ func getRoleNames(dbName string) (restrictedRole string, regularUser string) { } // testIgnoredObjects tests procedures ignored via .pgschemaignore -func testIgnoredObjects(t *testing.T, ctx context.Context, container *testutil.ContainerInfo, dbName string) { +func testIgnoredObjects(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) { // This test verifies that when procedures/functions are explicitly ignored // via .pgschemaignore, permission issues should not cause the dump to fail @@ -253,7 +285,14 @@ patterns = ["*_restricted"] // testProcedureAndFunctionSourceAccess tests that procedure and function source code is readable // via p.prosrc even when information_schema.routines.routine_definition is NULL -func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *testutil.ContainerInfo, dbName string) { +func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) { // Setup isolated database dbConn := setupTestDatabase(ctx, t, container, dbName) defer dbConn.Close() @@ -434,63 +473,18 @@ func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, con // This helper is used specifically for permission testing where we need to run // the dump command with restricted database user credentials. func executeDumpCommandAsUser(hostArg string, portArg int, database, userArg, password, schemaArg string) (string, error) { - // Store original connection parameters and restore them later - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, - } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - - // Set connection parameters for this specific dump - host = hostArg - port = portArg - db = database - user = userArg - schema = schemaArg - testutil.SetEnvPassword(password) - - // Capture output by redirecting stdout - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Read from pipe in a goroutine to avoid deadlock - var actualOutput string - var readErr error - done := make(chan bool) - - go func() { - defer close(done) - output, err := io.ReadAll(r) - if err != nil { - readErr = err - return - } - actualOutput = string(output) - }() - - // Run the dump command - // Logger setup handled by root command - err := runDump(nil, nil) - - // Close write end and restore stdout - w.Close() - os.Stdout = originalStdout - - // Wait for reading to complete - <-done - if readErr != nil { - return "", fmt.Errorf("failed to read captured output: %w", readErr) + // Create dump configuration + config := &DumpConfig{ + Host: hostArg, + Port: portArg, + DB: database, + User: userArg, + Password: password, + Schema: schemaArg, + MultiFile: false, + File: "", } - return actualOutput, err + // Execute dump + return ExecuteDump(config) } diff --git a/cmd/ignore_integration_test.go b/cmd/ignore_integration_test.go index 5ef1b2ad..f3ea305b 100644 --- a/cmd/ignore_integration_test.go +++ b/cmd/ignore_integration_test.go @@ -6,7 +6,6 @@ package cmd // various database object types and ignore patterns including wildcards and negation. import ( - "context" "database/sql" "fmt" "os" @@ -28,11 +27,28 @@ func TestIgnoreIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() - // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer containerInfo.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create containerInfo struct to match old API for minimal changes + containerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Create the test schema with various object types createTestSchema(t, containerInfo.Conn) @@ -68,9 +84,27 @@ func TestIgnoreIntegration(t *testing.T) { t.Run("apply", func(t *testing.T) { // Create a fresh container for apply test to avoid fingerprint conflicts - ctx := context.Background() - applyContainerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer applyContainerInfo.Terminate(ctx, t) + applyEmbeddedPG := testutil.SetupPostgres(t) + defer applyEmbeddedPG.Stop() + applyConn, applyHost, applyPort, applyDbname, applyUser, applyPassword := testutil.ConnectToPostgres(t, applyEmbeddedPG) + defer applyConn.Close() + + // Create applyContainerInfo struct to match old API + applyContainerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: applyConn, + Host: applyHost, + Port: applyPort, + DBName: applyDbname, + User: applyUser, + Password: applyPassword, + } // Create the test schema in the fresh container createTestSchema(t, applyContainerInfo.Conn) @@ -266,7 +300,14 @@ patterns = ["seq_temp_*"] } // testIgnoreDump tests the dump command with ignore functionality -func testIgnoreDump(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnoreDump(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -281,7 +322,14 @@ func testIgnoreDump(t *testing.T, containerInfo *testutil.ContainerInfo) { // testIgnorePlanWithTriggerOnIgnoredTable tests that triggers can be defined on ignored tables // This tests the scenario where users manage triggers on externally-managed tables // without managing the table schema itself -func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file - temp_* pattern will ignore temp_external_users cleanup := createIgnoreFile(t) defer cleanup() @@ -348,7 +396,14 @@ CREATE TRIGGER on_external_user_created } // testIgnorePlan tests the plan command with ignore functionality -func testIgnorePlan(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnorePlan(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -400,7 +455,14 @@ CREATE TABLE test_core_config ( // testIgnoreApply tests the apply command with ignore functionality // This test verifies that ignored objects are excluded from fingerprint calculation -func testIgnoreApply(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnoreApply(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -500,7 +562,14 @@ $$; } // executeIgnoreDumpCommand runs the dump command and returns the output -func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.ContainerInfo) string { +func executeIgnoreDumpCommand(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) string { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", @@ -526,9 +595,9 @@ func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.ContainerInf "dump", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", } rootCmd.SetArgs(args) @@ -547,14 +616,21 @@ func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.ContainerInf } // executeIgnorePlanCommand runs the plan command and returns the output -func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.ContainerInfo, schemaFile string) string { +func executeIgnorePlanCommand(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, schemaFile string) string { // Create plan configuration with shared embedded postgres for performance config := &planCmd.PlanConfig{ Host: containerInfo.Host, Port: containerInfo.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: containerInfo.DBName, + User: containerInfo.User, + Password: containerInfo.Password, Schema: "public", File: schemaFile, ApplicationName: "pgschema", @@ -571,7 +647,14 @@ func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.ContainerInf } // executeIgnoreApplyCommandWithError runs the apply command and returns any error -func executeIgnoreApplyCommandWithError(containerInfo *testutil.ContainerInfo, schemaFile string) error { +func executeIgnoreApplyCommandWithError(containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, schemaFile string) error { rootCmd := &cobra.Command{ Use: "pgschema", } @@ -581,9 +664,9 @@ func executeIgnoreApplyCommandWithError(containerInfo *testutil.ContainerInfo, s "apply", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", "--file", schemaFile, "--auto-approve", diff --git a/cmd/include_integration_test.go b/cmd/include_integration_test.go index d146e36c..81122f79 100644 --- a/cmd/include_integration_test.go +++ b/cmd/include_integration_test.go @@ -8,7 +8,7 @@ package cmd // the same organized file structure. import ( - "context" + "database/sql" "fmt" "os" "path/filepath" @@ -26,11 +26,28 @@ func TestIncludeIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() - // Setup PostgreSQL container with specific database - containerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer containerInfo.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create containerInfo struct to match old API for minimal changes + containerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Apply the include-based schema using the apply command applyIncludeSchema(t, containerInfo) @@ -47,7 +64,14 @@ func TestIncludeIntegration(t *testing.T) { } // applyIncludeSchema applies the testdata/include/main.sql schema using the apply command -func applyIncludeSchema(t *testing.T, containerInfo *testutil.ContainerInfo) { +func applyIncludeSchema(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { mainSQLPath := "../testdata/include/main.sql" // Create a new root command with apply as subcommand @@ -63,9 +87,9 @@ func applyIncludeSchema(t *testing.T, containerInfo *testutil.ContainerInfo) { "apply", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--file", mainSQLPath, "--auto-approve", // Skip interactive confirmation } @@ -81,7 +105,14 @@ func applyIncludeSchema(t *testing.T, containerInfo *testutil.ContainerInfo) { } // executeMultiFileDump runs pgschema dump --multi-file using the CLI command -func executeMultiFileDump(t *testing.T, containerInfo *testutil.ContainerInfo, outputPath string) { +func executeMultiFileDump(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, outputPath string) { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", @@ -95,9 +126,9 @@ func executeMultiFileDump(t *testing.T, containerInfo *testutil.ContainerInfo, o "dump", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", "--multi-file", "--file", outputPath, diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index de145f1c..ce8deaa8 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -11,13 +11,12 @@ import ( "strings" "testing" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/google/go-cmp/cmp" _ "github.com/jackc/pgx/v5/stdlib" "github.com/pgschema/pgschema/cmd/apply" planCmd "github.com/pgschema/pgschema/cmd/plan" - "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/testutil" ) @@ -25,7 +24,7 @@ var ( generate = flag.Bool("generate", false, "generate expected test output files instead of comparing") // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *util.EmbeddedPostgres + sharedEmbeddedPG *postgres.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package @@ -35,7 +34,7 @@ func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance (from ~60s to ~10s per test) - sharedEmbeddedPG = util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupPostgres(nil) defer sharedEmbeddedPG.Stop() // Run tests @@ -79,11 +78,27 @@ func TestPlanAndApply(t *testing.T) { testDataRoot := "../testdata/diff" // Start a single PostgreSQL container for all test cases - container := testutil.SetupPostgresContainerWithDB(ctx, t, "postgres", "testuser", "testpass") - defer container.Terminate(ctx, t) - - containerHost := container.Host - portMapped := container.Port + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Get test filter from environment variable testFilter := os.Getenv("PGSCHEMA_TEST_FILTER") @@ -167,7 +182,7 @@ func TestPlanAndApply(t *testing.T) { // Run all test cases using the shared container for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - runPlanAndApplyTest(t, ctx, containerHost, portMapped, tc) + runPlanAndApplyTest(t, ctx, container, tc) }) } } @@ -182,7 +197,16 @@ type testCase struct { } // runPlanAndApplyTest executes a single plan and apply test case with test-specific database -func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string, portMapped int, tc testCase) { +func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, tc testCase) { + containerHost := container.Host + portMapped := container.Port // Create a unique database name for this test case (replace invalid chars) dbName := "test_" + strings.ReplaceAll(strings.ReplaceAll(tc.name, "/", "_"), "-", "_") // PostgreSQL identifiers are limited to 63 characters @@ -214,12 +238,12 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string // STEP 2: Test plan command with new.sql as target t.Logf("--- Testing plan command outputs ---") - testPlanOutputs(t, containerHost, portMapped, dbName, tc.newFile, tc.planSQLFile, tc.planJSONFile, tc.planTXTFile) + testPlanOutputs(t, container, dbName, tc.newFile, tc.planSQLFile, tc.planJSONFile, tc.planTXTFile) if !*generate { // STEP 3: Apply the migration using apply command t.Logf("--- Applying migration using apply command ---") - err = applySchemaChanges(containerHost, portMapped, dbName, "testuser", "testpass", "public", tc.newFile) + err = applySchemaChanges(containerHost, portMapped, dbName, container.User, container.Password, "public", tc.newFile) if err != nil { t.Fatalf("Failed to apply schema changes using pgschema apply: %v", err) } @@ -227,7 +251,7 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string // STEP 4: Test idempotency - plan should produce no changes t.Logf("--- Testing idempotency ---") - secondPlanOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, "testuser", "testpass", "public", tc.newFile) + secondPlanOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, container.User, container.Password, "public", tc.newFile) if err != nil { t.Fatalf("Failed to generate plan SQL for idempotency check: %v", err) } @@ -243,14 +267,23 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string } // testPlanOutputs tests all plan output formats against expected files -func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { +func testPlanOutputs(t *testing.T, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { + containerHost := container.Host + portMapped := container.Port // Set fixed timestamp for generate mode to ensure deterministic output if *generate { os.Setenv("PGSCHEMA_TEST_TIME", "1970-01-01T00:00:00Z") defer os.Unsetenv("PGSCHEMA_TEST_TIME") } // Test SQL format - sqlFormattedOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + sqlFormattedOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan SQL formatted output: %v", err) } @@ -283,7 +316,7 @@ func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, } // Test human-readable format - humanOutput, err := generatePlanHuman(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + humanOutput, err := generatePlanHuman(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan human output: %v", err) } @@ -316,7 +349,7 @@ func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, } // Test JSON format - jsonOutput, err := generatePlanJSON(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + jsonOutput, err := generatePlanJSON(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan JSON output: %v", err) } diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 40e88ba2..ec5daeaf 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -11,6 +11,7 @@ import ( "github.com/pgschema/pgschema/internal/fingerprint" "github.com/pgschema/pgschema/internal/include" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/spf13/cobra" ) @@ -122,30 +123,27 @@ type PlanConfig struct { // CreateEmbeddedPostgresForPlan creates a temporary embedded PostgreSQL instance // for validating the desired state schema. The instance should be stopped by the caller. -func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*util.EmbeddedPostgres, error) { +func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*postgres.EmbeddedPostgres, error) { // Detect target database PostgreSQL version - targetDBConfig := &util.ConnectionConfig{ - Host: config.Host, - Port: config.Port, - Database: config.DB, - User: config.User, - Password: config.Password, - SSLMode: "prefer", - ApplicationName: config.ApplicationName, - } - pgVersion, err := util.DetectPostgresVersionFromConfig(targetDBConfig) + pgVersion, err := postgres.DetectPostgresVersionFromDB( + config.Host, + config.Port, + config.DB, + config.User, + config.Password, + ) if err != nil { return nil, fmt.Errorf("failed to detect PostgreSQL version: %w", err) } // Start embedded PostgreSQL with matching version - embeddedConfig := &util.EmbeddedPostgresConfig{ + embeddedConfig := &postgres.EmbeddedPostgresConfig{ Version: pgVersion, Database: "pgschema_temp", Username: "pgschema", Password: "pgschema", } - embeddedPG, err := util.StartEmbeddedPostgres(embeddedConfig) + embeddedPG, err := postgres.StartEmbeddedPostgres(embeddedConfig) if err != nil { return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) } @@ -156,7 +154,7 @@ func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*util.EmbeddedPostgres, // GeneratePlan generates a migration plan from configuration. // The caller must provide a non-nil embeddedPG instance for validating the desired state schema. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func GeneratePlan(config *PlanConfig, embeddedPG *util.EmbeddedPostgres) (*plan.Plan, error) { +func GeneratePlan(config *PlanConfig, embeddedPG *postgres.EmbeddedPostgres) (*plan.Plan, error) { // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { @@ -184,19 +182,13 @@ func GeneratePlan(config *PlanConfig, embeddedPG *util.EmbeddedPostgres) (*plan. ctx := context.Background() - // Reset the schema to ensure clean state - if err := embeddedPG.ResetSchema(ctx, config.Schema); err != nil { - return nil, fmt.Errorf("failed to reset schema in embedded PostgreSQL: %w", err) - } - - // Apply desired state SQL to embedded PostgreSQL - if err := embeddedPG.ApplySchemaSQL(ctx, config.Schema, desiredState); err != nil { + // Apply desired state SQL to embedded PostgreSQL (resets schema first) + if err := embeddedPG.ApplySchema(ctx, config.Schema, desiredState); err != nil { return nil, fmt.Errorf("failed to apply desired state to embedded PostgreSQL: %w", err) } // Inspect embedded PostgreSQL to get desired state IR - embeddedHost, embeddedPort, embeddedDB := embeddedPG.GetConnectionInfo() - embeddedUsername, embeddedPassword := embeddedPG.GetCredentials() + embeddedHost, embeddedPort, embeddedDB, embeddedUsername, embeddedPassword := embeddedPG.GetConnectionDetails() desiredStateIR, err := util.GetIRFromDatabase(embeddedHost, embeddedPort, embeddedDB, embeddedUsername, embeddedPassword, config.Schema, config.ApplicationName, ignoreConfig) if err != nil { return nil, fmt.Errorf("failed to get desired state from embedded PostgreSQL: %w", err) diff --git a/cmd/plan/plan_integration_test.go b/cmd/plan/plan_integration_test.go index 16063414..5cc3a14e 100644 --- a/cmd/plan/plan_integration_test.go +++ b/cmd/plan/plan_integration_test.go @@ -2,6 +2,7 @@ package plan import ( "context" + "database/sql" "fmt" "os" "path/filepath" @@ -20,11 +21,29 @@ func TestPlanCommand_DatabaseIntegration(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -90,9 +109,9 @@ func TestPlanCommand_DatabaseIntegration(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, "--output-human", "stdout", } @@ -117,11 +136,29 @@ func TestPlanCommand_OutputFormats(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup simple database schema - conn := container.Conn simpleSQL := ` CREATE TABLE users ( @@ -184,9 +221,9 @@ func TestPlanCommand_OutputFormats(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, tc.outputFlag, "stdout", } @@ -212,11 +249,29 @@ func TestPlanCommand_SchemaFiltering(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with multiple schemas - conn := container.Conn multiSchemaSQL := ` CREATE SCHEMA app; @@ -279,9 +334,9 @@ func TestPlanCommand_SchemaFiltering(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--schema", "public", // Filter to only public schema "--file", publicSchemaFile, "--output-human", "stdout", @@ -302,12 +357,30 @@ func TestPlanCommand_EmptyDatabase(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() var err error // Start PostgreSQL container with empty database - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Create desired state schema file tmpDir := t.TempDir() @@ -348,9 +421,9 @@ func TestPlanCommand_EmptyDatabase(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, "--output-human", "stdout", } diff --git a/cmd/schema_integration_test.go b/cmd/schema_integration_test.go index a8a67c45..58f8889d 100644 --- a/cmd/schema_integration_test.go +++ b/cmd/schema_integration_test.go @@ -24,11 +24,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { ctx := context.Background() - // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") - defer container.Terminate(ctx, t) - - conn := container.Conn + // Start PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Test Case 1: Plan and Apply to tenant schema using CLI t.Run("cli_plan_and_apply_tenant_schema", func(t *testing.T) { @@ -61,11 +61,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 1: Generate plan using CLI planOutput, err := executePlanCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + host, + port, + dbname, + user, + password, "tenant", // Non-public schema desiredStateFile, ) @@ -82,11 +82,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 2: Apply changes using CLI err = executeApplyCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + host, + port, + dbname, + user, + password, "tenant", // Non-public schema desiredStateFile, ) @@ -169,11 +169,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Apply changes ONLY to app_a schema err = executeApplyCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + host, + port, + dbname, + user, + password, "app_a", // Target only app_a desiredStateFile, ) @@ -250,11 +250,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 1: Generate plan using CLI for mixed-case schema planOutput, err := executePlanCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + host, + port, + dbname, + user, + password, "MyApp", // Mixed-case schema desiredStateFile, ) @@ -271,11 +271,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 2: Apply changes using CLI for mixed-case schema err = executeApplyCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + host, + port, + dbname, + user, + password, "MyApp", // Mixed-case schema desiredStateFile, ) diff --git a/cmd/util/embedded_postgres_test.go b/cmd/util/embedded_postgres_test.go deleted file mode 100644 index c1290de0..00000000 --- a/cmd/util/embedded_postgres_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package util - -import ( - "context" - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -func TestStartEmbeddedPostgres(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - // Test connection - ctx := context.Background() - if err := ep.GetDB().PingContext(ctx); err != nil { - t.Fatalf("Failed to ping embedded postgres: %v", err) - } - - // Verify connection info - host, port, database := ep.GetConnectionInfo() - if host != "localhost" { - t.Errorf("Expected host 'localhost', got '%s'", host) - } - if port == 0 { - t.Error("Port should not be 0") - } - if database != "testdb" { - t.Errorf("Expected database 'testdb', got '%s'", database) - } -} - -func TestApplySchemaSQL(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test applying schema SQL - schemaSQL := ` - CREATE TABLE users ( - id SERIAL PRIMARY KEY, - name TEXT NOT NULL, - email TEXT UNIQUE - ); - - CREATE INDEX idx_users_email ON users(email); - ` - - err = ep.ApplySchemaSQL(ctx, "public", schemaSQL) - if err != nil { - t.Fatalf("Failed to apply schema SQL: %v", err) - } - - // Verify table was created - var tableName string - query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&tableName) - if err != nil { - t.Fatalf("Failed to query table: %v", err) - } - if tableName != "users" { - t.Errorf("Expected table 'users', got '%s'", tableName) - } - - // Verify index was created - var indexName string - query = "SELECT indexname FROM pg_indexes WHERE tablename = 'users' AND indexname = 'idx_users_email'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&indexName) - if err != nil { - t.Fatalf("Failed to query index: %v", err) - } - if indexName != "idx_users_email" { - t.Errorf("Expected index 'idx_users_email', got '%s'", indexName) - } -} - -func TestApplySchemaSQL_CustomSchema(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test applying schema SQL to custom schema - schemaSQL := ` - CREATE TABLE products ( - id SERIAL PRIMARY KEY, - name TEXT NOT NULL, - price NUMERIC(10, 2) - ); - ` - - err = ep.ApplySchemaSQL(ctx, "myschema", schemaSQL) - if err != nil { - t.Fatalf("Failed to apply schema SQL: %v", err) - } - - // Verify schema was created - var schemaName string - query := "SELECT schema_name FROM information_schema.schemata WHERE schema_name = 'myschema'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&schemaName) - if err != nil { - t.Fatalf("Failed to query schema: %v", err) - } - if schemaName != "myschema" { - t.Errorf("Expected schema 'myschema', got '%s'", schemaName) - } - - // Verify table was created in custom schema - var tableName string - query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'myschema' AND table_name = 'products'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&tableName) - if err != nil { - t.Fatalf("Failed to query table: %v", err) - } - if tableName != "products" { - t.Errorf("Expected table 'products', got '%s'", tableName) - } -} - -func TestApplySchemaSQL_InvalidSQL(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test with invalid SQL - invalidSQL := "CREATE TABLE invalid syntax here" - err = ep.ApplySchemaSQL(ctx, "public", invalidSQL) - if err == nil { - t.Error("Expected error for invalid SQL, got none") - } -} - -func TestStop(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - - // Test stopping - err = ep.Stop() - if err != nil { - t.Fatalf("Failed to stop embedded postgres: %v", err) - } - - // Verify connection is closed - ctx := context.Background() - err = ep.GetDB().PingContext(ctx) - if err == nil { - t.Error("Expected error pinging stopped postgres, got none") - } -} diff --git a/cmd/util/embedded_postgres_test_helper.go b/cmd/util/embedded_postgres_test_helper.go deleted file mode 100644 index b81b3add..00000000 --- a/cmd/util/embedded_postgres_test_helper.go +++ /dev/null @@ -1,50 +0,0 @@ -package util - -import ( - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -// SetupSharedEmbeddedPostgres creates a shared embedded PostgreSQL instance for test suites. -// This instance can be reused across multiple test cases to significantly improve test performance. -// -// Usage example: -// -// func TestMain(m *testing.M) { -// // Create shared embedded postgres for all tests -// embeddedPG := util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) -// defer embeddedPG.Stop() -// -// // Run tests -// code := m.Run() -// os.Exit(code) -// } -// -// func TestMyFeature(t *testing.T) { -// config := &plan.PlanConfig{ -// // ... other config ... -// EmbeddedPG: embeddedPG, // Reuse shared instance -// } -// plan, err := plan.GeneratePlan(config) -// // ... -// } -func SetupSharedEmbeddedPostgres(t testing.TB, version embeddedpostgres.PostgresVersion) *EmbeddedPostgres { - config := &EmbeddedPostgresConfig{ - Version: version, - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - embeddedPG, err := StartEmbeddedPostgres(config) - if err != nil { - if t != nil { - t.Fatalf("Failed to start shared embedded PostgreSQL: %v", err) - } else { - panic("Failed to start shared embedded PostgreSQL: " + err.Error()) - } - } - - return embeddedPG -} diff --git a/cmd/util/postgres_version.go b/cmd/util/postgres_version.go deleted file mode 100644 index cb951ee1..00000000 --- a/cmd/util/postgres_version.go +++ /dev/null @@ -1,84 +0,0 @@ -package util - -import ( - "context" - "database/sql" - "fmt" - "strconv" - "strings" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -// DetectPostgresVersion queries the target database to determine its PostgreSQL version -// and returns the corresponding embedded-postgres version string -func DetectPostgresVersion(db *sql.DB) (embeddedpostgres.PostgresVersion, error) { - ctx := context.Background() - - // Query PostgreSQL version number (e.g., 170005 for 17.5) - var versionNum int - err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) - if err != nil { - return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) - } - - // Extract major version: version_num / 10000 - // e.g., 170005 / 10000 = 17 - majorVersion := versionNum / 10000 - - // Map to embedded-postgres version - return mapToEmbeddedPostgresVersion(majorVersion) -} - -// DetectPostgresVersionFromConfig queries the target database using connection config -// and returns the corresponding embedded-postgres version string -func DetectPostgresVersionFromConfig(config *ConnectionConfig) (embeddedpostgres.PostgresVersion, error) { - // Connect to target database - db, err := Connect(config) - if err != nil { - return "", fmt.Errorf("failed to connect to detect version: %w", err) - } - defer db.Close() - - return DetectPostgresVersion(db) -} - -// mapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version -// Supported versions: 14, 15, 16, 17 -func mapToEmbeddedPostgresVersion(majorVersion int) (embeddedpostgres.PostgresVersion, error) { - switch majorVersion { - case 14: - return embeddedpostgres.PostgresVersion("14.18.0"), nil - case 15: - return embeddedpostgres.PostgresVersion("15.13.0"), nil - case 16: - return embeddedpostgres.PostgresVersion("16.9.0"), nil - case 17: - return embeddedpostgres.PostgresVersion("17.5.0"), nil - default: - return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) - } -} - -// ParseVersionString parses a PostgreSQL version string (e.g., "17.5") and returns major version -func ParseVersionString(versionStr string) (int, error) { - // Handle various formats: "17.5", "17.5.0", "PostgreSQL 17.5", etc. - // Extract the version number part - versionStr = strings.TrimSpace(versionStr) - - // Remove "PostgreSQL " prefix if present - versionStr = strings.TrimPrefix(versionStr, "PostgreSQL ") - - // Split by "." and take the first part (major version) - parts := strings.Split(versionStr, ".") - if len(parts) == 0 { - return 0, fmt.Errorf("invalid version string: %s", versionStr) - } - - majorVersion, err := strconv.Atoi(parts[0]) - if err != nil { - return 0, fmt.Errorf("failed to parse major version from %s: %w", versionStr, err) - } - - return majorVersion, nil -} diff --git a/cmd/util/postgres_version_test.go b/cmd/util/postgres_version_test.go deleted file mode 100644 index a9661786..00000000 --- a/cmd/util/postgres_version_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package util - -import ( - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -func TestMapToEmbeddedPostgresVersion(t *testing.T) { - tests := []struct { - name string - majorVersion int - expected embeddedpostgres.PostgresVersion - expectError bool - }{ - { - name: "PostgreSQL 14", - majorVersion: 14, - expected: embeddedpostgres.PostgresVersion("14.18.0"), - expectError: false, - }, - { - name: "PostgreSQL 15", - majorVersion: 15, - expected: embeddedpostgres.PostgresVersion("15.13.0"), - expectError: false, - }, - { - name: "PostgreSQL 16", - majorVersion: 16, - expected: embeddedpostgres.PostgresVersion("16.9.0"), - expectError: false, - }, - { - name: "PostgreSQL 17", - majorVersion: 17, - expected: embeddedpostgres.PostgresVersion("17.5.0"), - expectError: false, - }, - { - name: "Unsupported version 13", - majorVersion: 13, - expected: "", - expectError: true, - }, - { - name: "Unsupported version 18", - majorVersion: 18, - expected: "", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := mapToEmbeddedPostgresVersion(tt.majorVersion) - - if tt.expectError { - if err == nil { - t.Errorf("expected error but got none") - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("expected version %s, got %s", tt.expected, result) - } - } - }) - } -} - -func TestParseVersionString(t *testing.T) { - tests := []struct { - name string - versionStr string - expected int - expectError bool - }{ - { - name: "Simple version 17.5", - versionStr: "17.5", - expected: 17, - expectError: false, - }, - { - name: "Version with patch 17.5.0", - versionStr: "17.5.0", - expected: 17, - expectError: false, - }, - { - name: "Version with prefix", - versionStr: "PostgreSQL 17.5", - expected: 17, - expectError: false, - }, - { - name: "Version 14.18.0", - versionStr: "14.18.0", - expected: 14, - expectError: false, - }, - { - name: "Version with whitespace", - versionStr: " 16.9 ", - expected: 16, - expectError: false, - }, - { - name: "Invalid version", - versionStr: "invalid", - expected: 0, - expectError: true, - }, - { - name: "Empty string", - versionStr: "", - expected: 0, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ParseVersionString(tt.versionStr) - - if tt.expectError { - if err == nil { - t.Errorf("expected error but got none") - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("expected version %d, got %d", tt.expected, result) - } - } - }) - } -} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go index 4f6aedf5..53ddeba6 100644 --- a/internal/diff/diff_test.go +++ b/internal/diff/diff_test.go @@ -1,21 +1,24 @@ package diff import ( - "context" "os" "path/filepath" "strings" "testing" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/ir" + "github.com/pgschema/pgschema/testutil" ) +// sharedTestPostgres is the shared embedded postgres instance for all tests in this package +var sharedTestPostgres *postgres.EmbeddedPostgres + // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - ctx := context.Background() - container := ir.SetupSharedTestContainer(ctx, nil) - defer container.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupPostgres(nil) + defer sharedTestPostgres.Stop() // Run tests code := m.Run() @@ -53,7 +56,8 @@ func buildSQLFromSteps(diffs []Diff) string { // parseSQL is a helper function to convert SQL string to IR for tests // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { - return ir.ParseSQLForTest(t, sql, "public") + t.Helper() + return testutil.ParseSQLToIR(t, sharedTestPostgres, sql, "public") } // TestDiffFromFiles runs file-based diff tests from testdata directory. diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go index ec7cacb7..7b2d1239 100644 --- a/internal/plan/plan_test.go +++ b/internal/plan/plan_test.go @@ -1,7 +1,6 @@ package plan import ( - "context" "encoding/json" "fmt" "os" @@ -13,15 +12,19 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pgschema/pgschema/internal/diff" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/ir" + "github.com/pgschema/pgschema/testutil" ) +// sharedTestPostgres is the shared embedded postgres instance for all tests in this package +var sharedTestPostgres *postgres.EmbeddedPostgres + // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - ctx := context.Background() - container := ir.SetupSharedTestContainer(ctx, nil) - defer container.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupPostgres(nil) + defer sharedTestPostgres.Stop() // Run tests code := m.Run() @@ -54,7 +57,8 @@ func discoverTestDataVersions(testdataDir string) ([]string, error) { // parseSQL is a helper function to convert SQL string to IR for tests // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { - return ir.ParseSQLForTest(t, sql, "public") + t.Helper() + return testutil.ParseSQLToIR(t, sharedTestPostgres, sql, "public") } func TestPlanSummary(t *testing.T) { @@ -178,62 +182,6 @@ func TestPlanJSONRoundTrip(t *testing.T) { } } -func TestPlanToJSON(t *testing.T) { - oldSQL := `CREATE TABLE users ( - id integer NOT NULL - );` - - newSQL := `CREATE TABLE users ( - id integer NOT NULL, - name text NOT NULL - );` - - oldIR := parseSQL(t, oldSQL) - newIR := parseSQL(t, newSQL) - diffs := diff.GenerateMigration(oldIR, newIR, "public") - - plan := NewPlan(diffs) - - // Test non-debug version (default behavior) - should NOT contain source field - jsonOutput, err := plan.ToJSON() - if err != nil { - t.Fatalf("Failed to generate JSON: %v", err) - } - - if !strings.Contains(jsonOutput, `"groups"`) { - t.Error("JSON output should contain groups") - } - - // Non-debug version should NOT contain source field - if strings.Contains(jsonOutput, `"source"`) { - t.Error("JSON output should NOT contain source field when debug is disabled") - } - - // Test debug version - should contain source field - jsonDebugOutput, err := plan.ToJSONWithDebug(true) - if err != nil { - t.Fatalf("Failed to generate debug JSON: %v", err) - } - - if !strings.Contains(jsonDebugOutput, `"groups"`) { - t.Error("Debug JSON output should contain groups") - } - - // Debug version should still work (but Steps don't have source field anymore) - // This is expected behavior after refactoring to Step structure - if jsonDebugOutput == "" { - t.Error("Debug JSON output should not be empty") - } - - if !strings.Contains(jsonOutput, `"version"`) { - t.Error("JSON output should contain version") - } - - if !strings.Contains(jsonOutput, `"created_at"`) { - t.Error("JSON output should contain created_at timestamp") - } -} - func TestPlanNoChanges(t *testing.T) { sql := `CREATE TABLE users ( id integer NOT NULL diff --git a/cmd/util/embedded_postgres.go b/internal/postgres/embedded.go similarity index 57% rename from cmd/util/embedded_postgres.go rename to internal/postgres/embedded.go index 5b765e95..1fce0b9d 100644 --- a/cmd/util/embedded_postgres.go +++ b/internal/postgres/embedded.go @@ -1,4 +1,7 @@ -package util +// Package postgres provides embedded PostgreSQL functionality for production use. +// This package is used by the plan command to create temporary PostgreSQL instances +// for validating desired state schemas. +package postgres import ( "context" @@ -12,14 +15,17 @@ import ( embeddedpostgres "github.com/fergusstrange/embedded-postgres" _ "github.com/jackc/pgx/v5/stdlib" - "github.com/pgschema/pgschema/internal/logger" ) -// EmbeddedPostgres manages a temporary embedded PostgreSQL instance +// PostgresVersion is an alias for the embedded-postgres version type. +type PostgresVersion = embeddedpostgres.PostgresVersion + +// EmbeddedPostgres manages a temporary embedded PostgreSQL instance. +// This is used by the plan command to validate desired state schemas. type EmbeddedPostgres struct { instance *embeddedpostgres.EmbeddedPostgres db *sql.DB - version embeddedpostgres.PostgresVersion + version PostgresVersion host string port int database string @@ -30,26 +36,38 @@ type EmbeddedPostgres struct { // EmbeddedPostgresConfig holds configuration for starting embedded PostgreSQL type EmbeddedPostgresConfig struct { - Version embeddedpostgres.PostgresVersion + Version PostgresVersion Database string Username string Password string } -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") +// DetectPostgresVersionFromDB connects to a database and detects its version +// This is a convenience function that opens a connection, detects the version, and closes it +func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (PostgresVersion, error) { + // Build connection string + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer", + user, password, host, port, database) + + // Connect to database + db, err := sql.Open("pgx", dsn) if err != nil { - return 0, err + return "", fmt.Errorf("failed to connect to database: %w", err) } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil + defer db.Close() + + // Test the connection + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + return "", fmt.Errorf("failed to ping database: %w", err) + } + + // Detect version + return detectPostgresVersion(db) } // StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { - log := logger.Get() - // Create unique runtime path with timestamp (using nanoseconds for uniqueness) timestamp := time.Now().Format("20060102_150405.000000000") runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-plan-%s", timestamp)) @@ -60,13 +78,6 @@ func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, e return nil, fmt.Errorf("failed to find available port: %w", err) } - log.Debug("Starting embedded PostgreSQL", - "version", config.Version, - "port", port, - "database", config.Database, - "runtime_path", runtimePath, - ) - // Configure embedded postgres pgConfig := embeddedpostgres.DefaultConfig(). Version(config.Version). @@ -78,7 +89,7 @@ func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, e DataPath(filepath.Join(runtimePath, "data")). Logger(io.Discard). // Suppress embedded-postgres startup logs StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector + "logging_collector": "off", // Disable log collector "log_destination": "stderr", // Send logs to stderr (which we discard) "log_min_messages": "PANIC", // Only log PANIC level messages "log_statement": "none", // Don't log SQL statements @@ -113,11 +124,6 @@ func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, e return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err) } - log.Debug("Embedded PostgreSQL started successfully", - "host", host, - "port", port, - ) - return &EmbeddedPostgres{ instance: instance, db: db, @@ -131,61 +137,56 @@ func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, e }, nil } -// GetDB returns the database connection -func (ep *EmbeddedPostgres) GetDB() *sql.DB { - return ep.db -} - -// GetConnectionInfo returns connection details -func (ep *EmbeddedPostgres) GetConnectionInfo() (host string, port int, database string) { - return ep.host, ep.port, ep.database -} - -// GetCredentials returns the username and password for the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) GetCredentials() (username string, password string) { - return ep.username, ep.password -} +// Stop stops and cleans up the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) Stop() error { + // Close database connection + if ep.db != nil { + ep.db.Close() + } -// ResetSchema drops and recreates a schema, clearing all objects -// This is useful for tests that want to reuse the same embedded postgres instance -func (ep *EmbeddedPostgres) ResetSchema(ctx context.Context, schema string) error { - log := logger.Get() - log.Debug("Resetting schema in embedded PostgreSQL", - "schema", schema, - ) + // Stop PostgreSQL instance + var stopErr error + if ep.instance != nil { + stopErr = ep.instance.Stop() + } - // Drop the schema if it exists (CASCADE to drop all objects) - dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { - return fmt.Errorf("failed to drop schema %s: %w", schema, err) + // Clean up runtime directory + if ep.runtimePath != "" { + if err := os.RemoveAll(ep.runtimePath); err != nil { + // Don't return error here - just ignore cleanup failures + // This can happen on Windows when files are still in use + } } - // Recreate the schema - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA %s", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { - return fmt.Errorf("failed to create schema %s: %w", schema, err) + if stopErr != nil { + return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) } - log.Debug("Schema reset successfully") return nil } -// ApplySchemaSQL applies SQL schema to the embedded PostgreSQL database -func (ep *EmbeddedPostgres) ApplySchemaSQL(ctx context.Context, schema string, sql string) error { - log := logger.Get() - log.Debug("Applying schema SQL to embedded PostgreSQL", - "schema", schema, - "sql_length", len(sql), - ) +// GetConnectionDetails returns all connection details needed to connect to the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) GetConnectionDetails() (host string, port int, database, username, password string) { + return ep.host, ep.port, ep.database, ep.username, ep.password +} + +// ApplySchema resets a schema (drops and recreates it) and applies SQL to it. +// This ensures a clean state before applying the desired schema definition. +func (ep *EmbeddedPostgres) ApplySchema(ctx context.Context, schema string, sql string) error { + // Drop the schema if it exists (CASCADE to drop all objects) + dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { + return fmt.Errorf("failed to drop schema %s: %w", schema, err) + } - // Create the schema if it doesn't exist - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", QuoteIdentifier(schema)) + // Create the schema + createSchemaSQL := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { return fmt.Errorf("failed to create schema %s: %w", schema, err) } // Set search_path to the target schema - setSearchPathSQL := fmt.Sprintf("SET search_path TO %s", QuoteIdentifier(schema)) + setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) if _, err := ep.db.ExecContext(ctx, setSearchPathSQL); err != nil { return fmt.Errorf("failed to set search_path: %w", err) } @@ -197,51 +198,52 @@ func (ep *EmbeddedPostgres) ApplySchemaSQL(ctx context.Context, schema string, s return fmt.Errorf("failed to apply schema SQL: %w", err) } - log.Debug("Schema SQL applied successfully") return nil } -// Stop stops and cleans up the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) Stop() error { - log := logger.Get() - log.Debug("Stopping embedded PostgreSQL", - "runtime_path", ep.runtimePath, - ) - - // Close database connection - if ep.db != nil { - ep.db.Close() +// findAvailablePort finds an available TCP port for PostgreSQL to use +func findAvailablePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err } + defer listener.Close() + return listener.Addr().(*net.TCPAddr).Port, nil +} - // Stop PostgreSQL instance - var stopErr error - if ep.instance != nil { - stopErr = ep.instance.Stop() +// mapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version +// Supported versions: 14, 15, 16, 17 +func mapToEmbeddedPostgresVersion(majorVersion int) (PostgresVersion, error) { + switch majorVersion { + case 14: + return PostgresVersion("14.18.0"), nil + case 15: + return PostgresVersion("15.13.0"), nil + case 16: + return PostgresVersion("16.9.0"), nil + case 17: + return PostgresVersion("17.5.0"), nil + default: + return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) } +} - // Clean up runtime directory - if ep.runtimePath != "" { - if err := os.RemoveAll(ep.runtimePath); err != nil { - log.Debug("Failed to clean up runtime directory", - "path", ep.runtimePath, - "error", err, - ) - // Don't return error here - just log it - } - } +// detectPostgresVersion queries the target database to determine its PostgreSQL version +// and returns the corresponding embedded-postgres version string +func detectPostgresVersion(db *sql.DB) (PostgresVersion, error) { + ctx := context.Background() - if stopErr != nil { - return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) + // Query PostgreSQL version number (e.g., 170005 for 17.5) + var versionNum int + err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) + if err != nil { + return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) } - log.Debug("Embedded PostgreSQL stopped and cleaned up") - return nil -} + // Extract major version: version_num / 10000 + // e.g., 170005 / 10000 = 17 + majorVersion := versionNum / 10000 -// QuoteIdentifier quotes a PostgreSQL identifier (schema, table, column name) -// This is a simple implementation - for production use, consider using pq.QuoteIdentifier -func QuoteIdentifier(identifier string) string { - // For now, just use the IR package's quote function - // In a production system, you might want to use a proper quoting library - return fmt.Sprintf("\"%s\"", identifier) + // Map to embedded-postgres version + return mapToEmbeddedPostgresVersion(majorVersion) } diff --git a/ir/testutil.go b/ir/testutil.go deleted file mode 100644 index ee9c0fe0..00000000 --- a/ir/testutil.go +++ /dev/null @@ -1,281 +0,0 @@ -// Package ir provides an intermediate representation for PostgreSQL schemas -package ir - -import ( - "context" - "database/sql" - "fmt" - "io" - "net" - "os" - "path/filepath" - "strings" - "testing" - "time" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" - _ "github.com/jackc/pgx/v5/stdlib" -) - -// getPostgresVersion returns the PostgreSQL version to use for testing. -// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, -// defaulting to "17" if not set. -// Returns an error if an unsupported version is specified. -func getPostgresVersion() (embeddedpostgres.PostgresVersion, error) { - versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") - if versionStr == "" { - return embeddedpostgres.PostgresVersion("17.5.0"), nil - } - - switch versionStr { - case "14": - return embeddedpostgres.PostgresVersion("14.18.0"), nil - case "15": - return embeddedpostgres.PostgresVersion("15.13.0"), nil - case "16": - return embeddedpostgres.PostgresVersion("16.9.0"), nil - case "17": - return embeddedpostgres.PostgresVersion("17.5.0"), nil - default: - return "", fmt.Errorf("unsupported PGSCHEMA_POSTGRES_VERSION: %s (supported versions: 14, 15, 16, 17)", versionStr) - } -} - -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, err - } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} - -// ContainerInfo holds PostgreSQL instance connection details for testing -type ContainerInfo struct { - Database *embeddedpostgres.EmbeddedPostgres - Host string - Port int - DSN string - Conn *sql.DB - RuntimePath string -} - -// setupPostgresContainer creates a new PostgreSQL test container -func setupPostgresContainer(ctx context.Context, t *testing.T) *ContainerInfo { - return setupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") -} - -// fatalError handles test failures - uses t.Fatalf if available, otherwise panics -func fatalError(t *testing.T, format string, args ...interface{}) { - if t != nil { - t.Fatalf(format, args...) - } else { - panic(fmt.Sprintf(format, args...)) - } -} - -// setupPostgresContainerWithDB creates a new PostgreSQL instance with custom database settings -func setupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, username, password string) *ContainerInfo { - // Extract test name and create unique runtime path - testName := "shared" - if t != nil { - testName = strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names - } - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-test-%s-%s", testName, timestamp)) - - // Find an available port - port, err := findAvailablePort() - if err != nil { - fatalError(t, "Failed to find available port: %v", err) - } - - // Get PostgreSQL version - pgVersion, err := getPostgresVersion() - if err != nil { - fatalError(t, "Failed to get PostgreSQL version: %v", err) - } - - // Configure embedded postgres with unique runtime path and dynamic port - config := embeddedpostgres.DefaultConfig(). - Version(pgVersion). - Database(database). - Username(username). - Password(password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard above) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - postgres := embeddedpostgres.NewDatabase(config) - err = postgres.Start() - if err != nil { - fatalError(t, "Failed to start embedded postgres: %v", err) - } - - // Build connection string - host := "localhost" - testDSN := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", - username, password, host, port, database) - - // Connect to database - conn, err := sql.Open("pgx", testDSN) - if err != nil { - postgres.Stop() - fatalError(t, "Failed to connect to database: %v", err) - } - - // Test the connection - if err := conn.PingContext(ctx); err != nil { - conn.Close() - postgres.Stop() - fatalError(t, "Failed to ping database: %v", err) - } - - return &ContainerInfo{ - Database: postgres, - Host: host, - Port: port, - DSN: testDSN, - Conn: conn, - RuntimePath: runtimePath, - } -} - -// terminate cleans up the database instance and connection -func (ci *ContainerInfo) terminate(ctx context.Context, t *testing.T) { - ci.Conn.Close() - if err := ci.Database.Stop(); err != nil { - t.Logf("Failed to stop embedded postgres: %v", err) - } - // Clean up the runtime directory - if ci.RuntimePath != "" { - if err := os.RemoveAll(ci.RuntimePath); err != nil { - t.Logf("Failed to clean up runtime directory: %v", err) - } - } -} - -// sharedTestContainer holds an optional shared embedded postgres instance for tests -var sharedTestContainer *ContainerInfo - -// SetSharedTestContainer sets a shared embedded postgres instance for ParseSQLForTest to reuse. -// This significantly improves test performance by avoiding repeated postgres startup/shutdown. -// -// Usage in test packages: -// -// func TestMain(m *testing.M) { -// ctx := context.Background() -// container := ir.SetupSharedTestContainer(ctx, nil) -// defer container.Terminate(ctx, nil) -// -// code := m.Run() -// os.Exit(code) -// } -func SetupSharedTestContainer(ctx context.Context, t testing.TB) *ContainerInfo { - // Convert testing.TB to *testing.T if needed - var tt *testing.T - if t != nil { - if tPtr, ok := t.(*testing.T); ok { - tt = tPtr - } else { - // For testing.TB that's not *testing.T (like *testing.M), create a dummy *testing.T - // This is safe because setupPostgresContainer only uses t for Fatalf on errors - panic("SetupSharedTestContainer requires *testing.T or nil") - } - } - container := setupPostgresContainer(ctx, tt) - sharedTestContainer = container - return container -} - -// Terminate cleans up the container (exported for use by test packages) -func (ci *ContainerInfo) Terminate(ctx context.Context, t testing.TB) { - // Convert testing.TB to *testing.T if needed - var tt *testing.T - if t != nil { - if tPtr, ok := t.(*testing.T); ok { - tt = tPtr - } - // For nil or other types, tt remains nil which is fine for terminate - } - ci.terminate(ctx, tt) -} - -// ParseSQLForTest is a test helper that converts SQL to IR using embedded PostgreSQL. -// This replaces the old parser-based approach for tests. -// -// If a shared test container has been set via SetupSharedTestContainer, it will be reused -// (with the schema reset between calls). Otherwise, a new temporary instance is created. -// -// This ensures tests use the same code path as production (database inspection) rather than parsing. -func ParseSQLForTest(t *testing.T, sqlContent string, schema string) *IR { - t.Helper() - - ctx := context.Background() - - var conn *sql.DB - var needsCleanup bool - - if sharedTestContainer != nil { - // Reuse shared container - reset the schema for clean state - conn = sharedTestContainer.Conn - needsCleanup = false - - // Drop and recreate schema - dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) - if _, err := conn.ExecContext(ctx, dropSchema); err != nil { - t.Fatalf("Failed to drop schema: %v", err) - } - createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) - if _, err := conn.ExecContext(ctx, createSchema); err != nil { - t.Fatalf("Failed to create schema: %v", err) - } - } else { - // Create new container for this test - container := setupPostgresContainer(ctx, t) - defer container.terminate(ctx, t) - conn = container.Conn - needsCleanup = true - - // Create schema if not public - if schema != "public" { - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS \"%s\"", schema) - if _, err := conn.ExecContext(ctx, createSchemaSQL); err != nil { - t.Fatalf("Failed to create schema: %v", err) - } - } - } - - // Set search_path to target schema - setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) - if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { - t.Fatalf("Failed to set search_path: %v", err) - } - - // Execute the SQL - if _, err := conn.ExecContext(ctx, sqlContent); err != nil { - t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) - } - - // Inspect the database to get IR - inspector := NewInspector(conn, nil) - ir, err := inspector.BuildIR(ctx, schema) - if err != nil { - t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) - } - - // If we created a container just for this test, cleanup happens via defer above - _ = needsCleanup - - return ir -} \ No newline at end of file diff --git a/testutil/postgres.go b/testutil/postgres.go index 7dd0724f..6ab6592a 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -5,153 +5,147 @@ import ( "context" "database/sql" "fmt" - "io" - "net" "os" - "path/filepath" - "strings" "testing" - "time" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" _ "github.com/jackc/pgx/v5/stdlib" + "github.com/pgschema/pgschema/internal/postgres" + "github.com/pgschema/pgschema/ir" ) -// getPostgresVersion returns the PostgreSQL version to use for testing. -// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, -// defaulting to "17" if not set. -func getPostgresVersion() embeddedpostgres.PostgresVersion { - versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") - switch versionStr { - case "14": - return embeddedpostgres.PostgresVersion("14.18.0") - case "15": - return embeddedpostgres.PostgresVersion("15.13.0") - case "16": - return embeddedpostgres.PostgresVersion("16.9.0") - case "17", "": - return embeddedpostgres.PostgresVersion("17.5.0") - default: - return embeddedpostgres.PostgresVersion("17.5.0") +// SetupPostgres creates a PostgreSQL instance for testing. +// It uses the production postgres.EmbeddedPostgres implementation. +// PostgreSQL version is determined from PGSCHEMA_POSTGRES_VERSION environment variable. +func SetupPostgres(t testing.TB) *postgres.EmbeddedPostgres { + + // Determine PostgreSQL version from environment + version := getPostgresVersion() + + // Create configuration for production postgres package + config := &postgres.EmbeddedPostgresConfig{ + Version: version, + Database: "testdb", + Username: "testuser", + Password: "testpass", } -} -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") + // Start embedded PostgreSQL using production code + embeddedPG, err := postgres.StartEmbeddedPostgres(config) if err != nil { - return 0, err + if t != nil { + t.Fatalf("Failed to start embedded PostgreSQL: %v", err) + } else { + panic("Failed to start embedded PostgreSQL: " + err.Error()) + } } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} -// ContainerInfo holds PostgreSQL instance connection details -type ContainerInfo struct { - Database *embeddedpostgres.EmbeddedPostgres - Host string - Port int - DSN string - Conn *sql.DB - RuntimePath string + return embeddedPG } -// SetupPostgresContainer creates a new PostgreSQL test container -func SetupPostgresContainer(ctx context.Context, t *testing.T) *ContainerInfo { - return SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") -} +// ParseSQLToIR is a test helper that parses SQL and returns its IR representation. +// It applies the SQL to an embedded PostgreSQL instance, inspects it, and returns the IR. +// The schema will be reset (dropped and recreated) to ensure clean state between test calls. +// This ensures tests use the same code path as production (database inspection) rather than parsing. +func ParseSQLToIR(t *testing.T, embeddedPG *postgres.EmbeddedPostgres, sqlContent string, schema string) *ir.IR { + t.Helper() -// SetupPostgresContainerWithDB creates a new PostgreSQL instance with custom database settings -func SetupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, username, password string) *ContainerInfo { - // Extract test name and create unique runtime path - testName := strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-test-%s-%s", testName, timestamp)) + ctx := context.Background() - // Find an available port - port, err := findAvailablePort() - if err != nil { - t.Fatalf("Failed to find available port: %v", err) - } - - // Configure embedded postgres with unique runtime path and dynamic port - config := embeddedpostgres.DefaultConfig(). - Version(getPostgresVersion()). - Database(database). - Username(username). - Password(password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard above) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - postgres := embeddedpostgres.NewDatabase(config) - err = postgres.Start() - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } + // Get connection details from embedded postgres + host, port, database, username, password := embeddedPG.GetConnectionDetails() // Build connection string - host := "localhost" - testDSN := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", username, password, host, port, database) // Connect to database - conn, err := sql.Open("pgx", testDSN) + conn, err := sql.Open("pgx", dsn) if err != nil { - postgres.Stop() t.Fatalf("Failed to connect to database: %v", err) } + defer conn.Close() // Test the connection if err := conn.PingContext(ctx); err != nil { - conn.Close() - postgres.Stop() t.Fatalf("Failed to ping database: %v", err) } - return &ContainerInfo{ - Database: postgres, - Host: host, - Port: port, - DSN: testDSN, - Conn: conn, - RuntimePath: runtimePath, + // Drop and recreate schema for clean state + dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := conn.ExecContext(ctx, dropSchema); err != nil { + t.Fatalf("Failed to drop schema: %v", err) + } + createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) + if _, err := conn.ExecContext(ctx, createSchema); err != nil { + t.Fatalf("Failed to create schema: %v", err) } -} -// Terminate cleans up the database instance and connection -func (ci *ContainerInfo) Terminate(ctx context.Context, t *testing.T) { - ci.Conn.Close() - if err := ci.Database.Stop(); err != nil { - t.Logf("Failed to stop embedded postgres: %v", err) + // Set search_path to target schema + setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) + if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { + t.Fatalf("Failed to set search_path: %v", err) } - // Clean up the runtime directory - if ci.RuntimePath != "" { - if err := os.RemoveAll(ci.RuntimePath); err != nil { - t.Logf("Failed to clean up runtime directory: %v", err) - } + + // Execute the SQL + if _, err := conn.ExecContext(ctx, sqlContent); err != nil { + t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) + } + + // Inspect the database to get IR + inspector := ir.NewInspector(conn, nil) + irResult, err := inspector.BuildIR(ctx, schema) + if err != nil { + t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) } + + return irResult } -// SetEnvPassword sets the PGPASSWORD environment variable -func SetEnvPassword(password string) { - os.Setenv("PGPASSWORD", password) +// ConnectToPostgres connects to an embedded PostgreSQL instance and returns connection details. +// This is a helper for tests that need database connection information. +// The caller is responsible for closing the returned *sql.DB connection. +func ConnectToPostgres(t testing.TB, embeddedPG *postgres.EmbeddedPostgres) (conn *sql.DB, host string, port int, dbname, user, password string) { + t.Helper() + + ctx := context.Background() + + // Get connection details from embedded postgres + host, port, dbname, user, password = embeddedPG.GetConnectionDetails() + + // Build connection string + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + user, password, host, port, dbname) + + // Connect to database + conn, err := sql.Open("pgx", dsn) + if err != nil { + t.Fatalf("Failed to connect to database: %v", err) + } + + // Test the connection + if err := conn.PingContext(ctx); err != nil { + conn.Close() + t.Fatalf("Failed to ping database: %v", err) + } + + return conn, host, port, dbname, user, password } -// TestConnectionConfig stores connection settings for save/restore operations -type TestConnectionConfig struct { - Host string - Port int - DB string - User string - Schema string +// getPostgresVersion returns the PostgreSQL version to use for testing. +// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, +// defaulting to "17" if not set. +func getPostgresVersion() postgres.PostgresVersion { + versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") + switch versionStr { + case "14": + return postgres.PostgresVersion("14.18.0") + case "15": + return postgres.PostgresVersion("15.13.0") + case "16": + return postgres.PostgresVersion("16.9.0") + case "17", "": + return postgres.PostgresVersion("17.5.0") + default: + return postgres.PostgresVersion("17.5.0") + } }