From 9e10e79e8c9990dd8e8c58687ffc68e7e69748e4 Mon Sep 17 00:00:00 2001 From: tianzhou Date: Thu, 23 Oct 2025 19:01:20 +0800 Subject: [PATCH 1/6] refactor: simplify embeded pg --- cmd/apply/apply.go | 5 +- cmd/apply/apply_integration_test.go | 16 +- cmd/dump/dump_integration_test.go | 4 +- cmd/dump/dump_permission_integration_test.go | 12 +- cmd/ignore_integration_test.go | 18 +- cmd/include_integration_test.go | 6 +- cmd/migrate_integration_test.go | 8 +- cmd/plan/plan.go | 26 +- cmd/plan/plan_integration_test.go | 8 +- cmd/schema_integration_test.go | 2 +- cmd/util/embedded_postgres.go | 247 ---------- cmd/util/embedded_postgres_test.go | 220 --------- cmd/util/embedded_postgres_test_helper.go | 50 -- cmd/util/postgres_version.go | 84 ---- cmd/util/postgres_version_test.go | 143 ------ internal/diff/diff_test.go | 23 +- internal/plan/plan_test.go | 79 +--- ir/testutil.go | 281 ------------ testutil/postgres.go | 451 ++++++++++++++++++- 19 files changed, 520 insertions(+), 1163 deletions(-) delete mode 100644 cmd/util/embedded_postgres.go delete mode 100644 cmd/util/embedded_postgres_test.go delete mode 100644 cmd/util/embedded_postgres_test_helper.go delete mode 100644 cmd/util/postgres_version.go delete mode 100644 cmd/util/postgres_version_test.go delete mode 100644 ir/testutil.go diff --git a/cmd/apply/apply.go b/cmd/apply/apply.go index 8bca60cc..f0a2b438 100644 --- a/cmd/apply/apply.go +++ b/cmd/apply/apply.go @@ -14,6 +14,7 @@ import ( "github.com/pgschema/pgschema/internal/plan" "github.com/pgschema/pgschema/internal/version" "github.com/pgschema/pgschema/ir" + "github.com/pgschema/pgschema/testutil" "github.com/spf13/cobra" ) @@ -89,7 +90,7 @@ type ApplyConfig struct { // // If config.File is provided, embeddedPG is used to generate the plan. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func ApplyMigration(config *ApplyConfig, embeddedPG *util.EmbeddedPostgres) error { +func ApplyMigration(config *ApplyConfig, embeddedPG *testutil.EmbeddedPostgres) error { var migrationPlan *plan.Plan var err error @@ -253,7 +254,7 @@ func RunApply(cmd *cobra.Command, args []string) error { ApplicationName: applyApplicationName, } - var embeddedPG *util.EmbeddedPostgres + var embeddedPG *testutil.EmbeddedPostgres var err error // If using --plan flag, load plan from JSON file diff --git a/cmd/apply/apply_integration_test.go b/cmd/apply/apply_integration_test.go index 780d52d3..072070c5 100644 --- a/cmd/apply/apply_integration_test.go +++ b/cmd/apply/apply_integration_test.go @@ -7,9 +7,7 @@ import ( "strings" "testing" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" planCmd "github.com/pgschema/pgschema/cmd/plan" - "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/plan" "github.com/pgschema/pgschema/testutil" ) @@ -17,14 +15,14 @@ import ( var ( // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *util.EmbeddedPostgres + sharedEmbeddedPG *testutil.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance by reusing the same instance - sharedEmbeddedPG = util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupSharedEmbeddedPostgres(nil, testutil.PostgresVersion("17.5.0")) defer sharedEmbeddedPG.Stop() // Run tests @@ -62,7 +60,7 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema @@ -319,7 +317,7 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema @@ -520,7 +518,7 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema @@ -689,7 +687,7 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema @@ -880,7 +878,7 @@ func TestApplyCommand_WaitDirective(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema and data diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index 7b2b72f5..4de5406d 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -81,7 +81,7 @@ func runExactMatchTest(t *testing.T, testDataDir string) { func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir string) { // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainer(ctx, t) + containerInfo := testutil.SetupTestPostgres(ctx, t) defer containerInfo.Terminate(ctx, t) // Read and execute the pgdump.sql file @@ -139,7 +139,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { ctx := context.Background() // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainer(ctx, t) + containerInfo := testutil.SetupTestPostgres(ctx, t) defer containerInfo.Terminate(ctx, t) // Load public schema types first diff --git a/cmd/dump/dump_permission_integration_test.go b/cmd/dump/dump_permission_integration_test.go index bf099e2a..51cfaa0e 100644 --- a/cmd/dump/dump_permission_integration_test.go +++ b/cmd/dump/dump_permission_integration_test.go @@ -31,7 +31,7 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { ctx := context.Background() // Start single PostgreSQL container for all permission tests - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "postgres", "testpwd") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Run each permission test with its own isolated database @@ -45,7 +45,7 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { } // setupTestDatabase creates a new database with permission test roles -func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.ContainerInfo, dbName string) *sql.DB { +func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.TestPostgres, dbName string) *sql.DB { // Create the database _, err := container.Conn.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", dbName)) if err != nil { @@ -74,8 +74,8 @@ func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.Co Host: container.Host, Port: container.Port, Database: dbName, - User: "postgres", - Password: "testpwd", + User: "testuser", + Password: "testpass", SSLMode: "prefer", ApplicationName: "pgschema", } @@ -102,7 +102,7 @@ func getRoleNames(dbName string) (restrictedRole string, regularUser string) { } // testIgnoredObjects tests procedures ignored via .pgschemaignore -func testIgnoredObjects(t *testing.T, ctx context.Context, container *testutil.ContainerInfo, dbName string) { +func testIgnoredObjects(t *testing.T, ctx context.Context, container *testutil.TestPostgres, dbName string) { // This test verifies that when procedures/functions are explicitly ignored // via .pgschemaignore, permission issues should not cause the dump to fail @@ -253,7 +253,7 @@ patterns = ["*_restricted"] // testProcedureAndFunctionSourceAccess tests that procedure and function source code is readable // via p.prosrc even when information_schema.routines.routine_definition is NULL -func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *testutil.ContainerInfo, dbName string) { +func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *testutil.TestPostgres, dbName string) { // Setup isolated database dbConn := setupTestDatabase(ctx, t, container, dbName) defer dbConn.Close() diff --git a/cmd/ignore_integration_test.go b/cmd/ignore_integration_test.go index 5ef1b2ad..5fe74275 100644 --- a/cmd/ignore_integration_test.go +++ b/cmd/ignore_integration_test.go @@ -31,7 +31,7 @@ func TestIgnoreIntegration(t *testing.T) { ctx := context.Background() // Setup PostgreSQL container - containerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + containerInfo := testutil.SetupTestPostgres(ctx, t) defer containerInfo.Terminate(ctx, t) // Create the test schema with various object types @@ -69,7 +69,7 @@ func TestIgnoreIntegration(t *testing.T) { t.Run("apply", func(t *testing.T) { // Create a fresh container for apply test to avoid fingerprint conflicts ctx := context.Background() - applyContainerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + applyContainerInfo := testutil.SetupTestPostgres(ctx, t) defer applyContainerInfo.Terminate(ctx, t) // Create the test schema in the fresh container @@ -266,7 +266,7 @@ patterns = ["seq_temp_*"] } // testIgnoreDump tests the dump command with ignore functionality -func testIgnoreDump(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnoreDump(t *testing.T, containerInfo *testutil.TestPostgres) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -281,7 +281,7 @@ func testIgnoreDump(t *testing.T, containerInfo *testutil.ContainerInfo) { // testIgnorePlanWithTriggerOnIgnoredTable tests that triggers can be defined on ignored tables // This tests the scenario where users manage triggers on externally-managed tables // without managing the table schema itself -func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *testutil.TestPostgres) { // Create .pgschemaignore file - temp_* pattern will ignore temp_external_users cleanup := createIgnoreFile(t) defer cleanup() @@ -348,7 +348,7 @@ CREATE TRIGGER on_external_user_created } // testIgnorePlan tests the plan command with ignore functionality -func testIgnorePlan(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnorePlan(t *testing.T, containerInfo *testutil.TestPostgres) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -400,7 +400,7 @@ CREATE TABLE test_core_config ( // testIgnoreApply tests the apply command with ignore functionality // This test verifies that ignored objects are excluded from fingerprint calculation -func testIgnoreApply(t *testing.T, containerInfo *testutil.ContainerInfo) { +func testIgnoreApply(t *testing.T, containerInfo *testutil.TestPostgres) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -500,7 +500,7 @@ $$; } // executeIgnoreDumpCommand runs the dump command and returns the output -func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.ContainerInfo) string { +func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.TestPostgres) string { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", @@ -547,7 +547,7 @@ func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.ContainerInf } // executeIgnorePlanCommand runs the plan command and returns the output -func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.ContainerInfo, schemaFile string) string { +func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.TestPostgres, schemaFile string) string { // Create plan configuration with shared embedded postgres for performance config := &planCmd.PlanConfig{ Host: containerInfo.Host, @@ -571,7 +571,7 @@ func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.ContainerInf } // executeIgnoreApplyCommandWithError runs the apply command and returns any error -func executeIgnoreApplyCommandWithError(containerInfo *testutil.ContainerInfo, schemaFile string) error { +func executeIgnoreApplyCommandWithError(containerInfo *testutil.TestPostgres, schemaFile string) error { rootCmd := &cobra.Command{ Use: "pgschema", } diff --git a/cmd/include_integration_test.go b/cmd/include_integration_test.go index d146e36c..20439e08 100644 --- a/cmd/include_integration_test.go +++ b/cmd/include_integration_test.go @@ -29,7 +29,7 @@ func TestIncludeIntegration(t *testing.T) { ctx := context.Background() // Setup PostgreSQL container with specific database - containerInfo := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + containerInfo := testutil.SetupTestPostgres(ctx, t) defer containerInfo.Terminate(ctx, t) // Apply the include-based schema using the apply command @@ -47,7 +47,7 @@ func TestIncludeIntegration(t *testing.T) { } // applyIncludeSchema applies the testdata/include/main.sql schema using the apply command -func applyIncludeSchema(t *testing.T, containerInfo *testutil.ContainerInfo) { +func applyIncludeSchema(t *testing.T, containerInfo *testutil.TestPostgres) { mainSQLPath := "../testdata/include/main.sql" // Create a new root command with apply as subcommand @@ -81,7 +81,7 @@ func applyIncludeSchema(t *testing.T, containerInfo *testutil.ContainerInfo) { } // executeMultiFileDump runs pgschema dump --multi-file using the CLI command -func executeMultiFileDump(t *testing.T, containerInfo *testutil.ContainerInfo, outputPath string) { +func executeMultiFileDump(t *testing.T, containerInfo *testutil.TestPostgres, outputPath string) { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index de145f1c..72295c41 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -11,12 +11,10 @@ import ( "strings" "testing" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" "github.com/google/go-cmp/cmp" _ "github.com/jackc/pgx/v5/stdlib" "github.com/pgschema/pgschema/cmd/apply" planCmd "github.com/pgschema/pgschema/cmd/plan" - "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/plan" "github.com/pgschema/pgschema/testutil" ) @@ -25,7 +23,7 @@ var ( generate = flag.Bool("generate", false, "generate expected test output files instead of comparing") // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *util.EmbeddedPostgres + sharedEmbeddedPG *testutil.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package @@ -35,7 +33,7 @@ func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance (from ~60s to ~10s per test) - sharedEmbeddedPG = util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupSharedEmbeddedPostgres(nil, testutil.PostgresVersion("17.5.0")) defer sharedEmbeddedPG.Stop() // Run tests @@ -79,7 +77,7 @@ func TestPlanAndApply(t *testing.T) { testDataRoot := "../testdata/diff" // Start a single PostgreSQL container for all test cases - container := testutil.SetupPostgresContainerWithDB(ctx, t, "postgres", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) containerHost := container.Host diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 40e88ba2..a9c9c3cd 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -11,6 +11,7 @@ import ( "github.com/pgschema/pgschema/internal/fingerprint" "github.com/pgschema/pgschema/internal/include" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/testutil" "github.com/spf13/cobra" ) @@ -122,30 +123,27 @@ type PlanConfig struct { // CreateEmbeddedPostgresForPlan creates a temporary embedded PostgreSQL instance // for validating the desired state schema. The instance should be stopped by the caller. -func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*util.EmbeddedPostgres, error) { +func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*testutil.EmbeddedPostgres, error) { // Detect target database PostgreSQL version - targetDBConfig := &util.ConnectionConfig{ - Host: config.Host, - Port: config.Port, - Database: config.DB, - User: config.User, - Password: config.Password, - SSLMode: "prefer", - ApplicationName: config.ApplicationName, - } - pgVersion, err := util.DetectPostgresVersionFromConfig(targetDBConfig) + pgVersion, err := testutil.DetectPostgresVersionFromDB( + config.Host, + config.Port, + config.DB, + config.User, + config.Password, + ) if err != nil { return nil, fmt.Errorf("failed to detect PostgreSQL version: %w", err) } // Start embedded PostgreSQL with matching version - embeddedConfig := &util.EmbeddedPostgresConfig{ + embeddedConfig := &testutil.EmbeddedPostgresConfig{ Version: pgVersion, Database: "pgschema_temp", Username: "pgschema", Password: "pgschema", } - embeddedPG, err := util.StartEmbeddedPostgres(embeddedConfig) + embeddedPG, err := testutil.StartEmbeddedPostgres(embeddedConfig) if err != nil { return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) } @@ -156,7 +154,7 @@ func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*util.EmbeddedPostgres, // GeneratePlan generates a migration plan from configuration. // The caller must provide a non-nil embeddedPG instance for validating the desired state schema. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func GeneratePlan(config *PlanConfig, embeddedPG *util.EmbeddedPostgres) (*plan.Plan, error) { +func GeneratePlan(config *PlanConfig, embeddedPG *testutil.EmbeddedPostgres) (*plan.Plan, error) { // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { diff --git a/cmd/plan/plan_integration_test.go b/cmd/plan/plan_integration_test.go index 16063414..82b67065 100644 --- a/cmd/plan/plan_integration_test.go +++ b/cmd/plan/plan_integration_test.go @@ -20,7 +20,7 @@ func TestPlanCommand_DatabaseIntegration(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with initial schema @@ -117,7 +117,7 @@ func TestPlanCommand_OutputFormats(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup simple database schema @@ -212,7 +212,7 @@ func TestPlanCommand_SchemaFiltering(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Setup database with multiple schemas @@ -306,7 +306,7 @@ func TestPlanCommand_EmptyDatabase(t *testing.T) { var err error // Start PostgreSQL container with empty database - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) // Create desired state schema file diff --git a/cmd/schema_integration_test.go b/cmd/schema_integration_test.go index a8a67c45..b9f19cd3 100644 --- a/cmd/schema_integration_test.go +++ b/cmd/schema_integration_test.go @@ -25,7 +25,7 @@ func TestNonPublicSchemaOperations(t *testing.T) { ctx := context.Background() // Start PostgreSQL container - container := testutil.SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") + container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) conn := container.Conn diff --git a/cmd/util/embedded_postgres.go b/cmd/util/embedded_postgres.go deleted file mode 100644 index 5b765e95..00000000 --- a/cmd/util/embedded_postgres.go +++ /dev/null @@ -1,247 +0,0 @@ -package util - -import ( - "context" - "database/sql" - "fmt" - "io" - "net" - "os" - "path/filepath" - "time" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" - _ "github.com/jackc/pgx/v5/stdlib" - "github.com/pgschema/pgschema/internal/logger" -) - -// EmbeddedPostgres manages a temporary embedded PostgreSQL instance -type EmbeddedPostgres struct { - instance *embeddedpostgres.EmbeddedPostgres - db *sql.DB - version embeddedpostgres.PostgresVersion - host string - port int - database string - username string - password string - runtimePath string -} - -// EmbeddedPostgresConfig holds configuration for starting embedded PostgreSQL -type EmbeddedPostgresConfig struct { - Version embeddedpostgres.PostgresVersion - Database string - Username string - Password string -} - -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, err - } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} - -// StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance -func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { - log := logger.Get() - - // Create unique runtime path with timestamp (using nanoseconds for uniqueness) - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-plan-%s", timestamp)) - - // Find an available port - port, err := findAvailablePort() - if err != nil { - return nil, fmt.Errorf("failed to find available port: %w", err) - } - - log.Debug("Starting embedded PostgreSQL", - "version", config.Version, - "port", port, - "database", config.Database, - "runtime_path", runtimePath, - ) - - // Configure embedded postgres - pgConfig := embeddedpostgres.DefaultConfig(). - Version(config.Version). - Database(config.Database). - Username(config.Username). - Password(config.Password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - instance := embeddedpostgres.NewDatabase(pgConfig) - if err := instance.Start(); err != nil { - return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) - } - - // Build connection string - host := "localhost" - dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", - config.Username, config.Password, host, port, config.Database) - - // Connect to database - db, err := sql.Open("pgx", dsn) - if err != nil { - instance.Stop() - os.RemoveAll(runtimePath) - return nil, fmt.Errorf("failed to connect to embedded PostgreSQL: %w", err) - } - - // Test the connection - ctx := context.Background() - if err := db.PingContext(ctx); err != nil { - db.Close() - instance.Stop() - os.RemoveAll(runtimePath) - return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err) - } - - log.Debug("Embedded PostgreSQL started successfully", - "host", host, - "port", port, - ) - - return &EmbeddedPostgres{ - instance: instance, - db: db, - version: config.Version, - host: host, - port: port, - database: config.Database, - username: config.Username, - password: config.Password, - runtimePath: runtimePath, - }, nil -} - -// GetDB returns the database connection -func (ep *EmbeddedPostgres) GetDB() *sql.DB { - return ep.db -} - -// GetConnectionInfo returns connection details -func (ep *EmbeddedPostgres) GetConnectionInfo() (host string, port int, database string) { - return ep.host, ep.port, ep.database -} - -// GetCredentials returns the username and password for the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) GetCredentials() (username string, password string) { - return ep.username, ep.password -} - -// ResetSchema drops and recreates a schema, clearing all objects -// This is useful for tests that want to reuse the same embedded postgres instance -func (ep *EmbeddedPostgres) ResetSchema(ctx context.Context, schema string) error { - log := logger.Get() - log.Debug("Resetting schema in embedded PostgreSQL", - "schema", schema, - ) - - // Drop the schema if it exists (CASCADE to drop all objects) - dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { - return fmt.Errorf("failed to drop schema %s: %w", schema, err) - } - - // Recreate the schema - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA %s", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { - return fmt.Errorf("failed to create schema %s: %w", schema, err) - } - - log.Debug("Schema reset successfully") - return nil -} - -// ApplySchemaSQL applies SQL schema to the embedded PostgreSQL database -func (ep *EmbeddedPostgres) ApplySchemaSQL(ctx context.Context, schema string, sql string) error { - log := logger.Get() - log.Debug("Applying schema SQL to embedded PostgreSQL", - "schema", schema, - "sql_length", len(sql), - ) - - // Create the schema if it doesn't exist - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { - return fmt.Errorf("failed to create schema %s: %w", schema, err) - } - - // Set search_path to the target schema - setSearchPathSQL := fmt.Sprintf("SET search_path TO %s", QuoteIdentifier(schema)) - if _, err := ep.db.ExecContext(ctx, setSearchPathSQL); err != nil { - return fmt.Errorf("failed to set search_path: %w", err) - } - - // Execute the SQL directly - // Note: Desired state SQL should never contain operations like CREATE INDEX CONCURRENTLY - // that cannot run in transactions. Those are migration details, not state declarations. - if _, err := ep.db.ExecContext(ctx, sql); err != nil { - return fmt.Errorf("failed to apply schema SQL: %w", err) - } - - log.Debug("Schema SQL applied successfully") - return nil -} - -// Stop stops and cleans up the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) Stop() error { - log := logger.Get() - log.Debug("Stopping embedded PostgreSQL", - "runtime_path", ep.runtimePath, - ) - - // Close database connection - if ep.db != nil { - ep.db.Close() - } - - // Stop PostgreSQL instance - var stopErr error - if ep.instance != nil { - stopErr = ep.instance.Stop() - } - - // Clean up runtime directory - if ep.runtimePath != "" { - if err := os.RemoveAll(ep.runtimePath); err != nil { - log.Debug("Failed to clean up runtime directory", - "path", ep.runtimePath, - "error", err, - ) - // Don't return error here - just log it - } - } - - if stopErr != nil { - return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) - } - - log.Debug("Embedded PostgreSQL stopped and cleaned up") - return nil -} - -// QuoteIdentifier quotes a PostgreSQL identifier (schema, table, column name) -// This is a simple implementation - for production use, consider using pq.QuoteIdentifier -func QuoteIdentifier(identifier string) string { - // For now, just use the IR package's quote function - // In a production system, you might want to use a proper quoting library - return fmt.Sprintf("\"%s\"", identifier) -} diff --git a/cmd/util/embedded_postgres_test.go b/cmd/util/embedded_postgres_test.go deleted file mode 100644 index c1290de0..00000000 --- a/cmd/util/embedded_postgres_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package util - -import ( - "context" - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -func TestStartEmbeddedPostgres(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - // Test connection - ctx := context.Background() - if err := ep.GetDB().PingContext(ctx); err != nil { - t.Fatalf("Failed to ping embedded postgres: %v", err) - } - - // Verify connection info - host, port, database := ep.GetConnectionInfo() - if host != "localhost" { - t.Errorf("Expected host 'localhost', got '%s'", host) - } - if port == 0 { - t.Error("Port should not be 0") - } - if database != "testdb" { - t.Errorf("Expected database 'testdb', got '%s'", database) - } -} - -func TestApplySchemaSQL(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test applying schema SQL - schemaSQL := ` - CREATE TABLE users ( - id SERIAL PRIMARY KEY, - name TEXT NOT NULL, - email TEXT UNIQUE - ); - - CREATE INDEX idx_users_email ON users(email); - ` - - err = ep.ApplySchemaSQL(ctx, "public", schemaSQL) - if err != nil { - t.Fatalf("Failed to apply schema SQL: %v", err) - } - - // Verify table was created - var tableName string - query := "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&tableName) - if err != nil { - t.Fatalf("Failed to query table: %v", err) - } - if tableName != "users" { - t.Errorf("Expected table 'users', got '%s'", tableName) - } - - // Verify index was created - var indexName string - query = "SELECT indexname FROM pg_indexes WHERE tablename = 'users' AND indexname = 'idx_users_email'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&indexName) - if err != nil { - t.Fatalf("Failed to query index: %v", err) - } - if indexName != "idx_users_email" { - t.Errorf("Expected index 'idx_users_email', got '%s'", indexName) - } -} - -func TestApplySchemaSQL_CustomSchema(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test applying schema SQL to custom schema - schemaSQL := ` - CREATE TABLE products ( - id SERIAL PRIMARY KEY, - name TEXT NOT NULL, - price NUMERIC(10, 2) - ); - ` - - err = ep.ApplySchemaSQL(ctx, "myschema", schemaSQL) - if err != nil { - t.Fatalf("Failed to apply schema SQL: %v", err) - } - - // Verify schema was created - var schemaName string - query := "SELECT schema_name FROM information_schema.schemata WHERE schema_name = 'myschema'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&schemaName) - if err != nil { - t.Fatalf("Failed to query schema: %v", err) - } - if schemaName != "myschema" { - t.Errorf("Expected schema 'myschema', got '%s'", schemaName) - } - - // Verify table was created in custom schema - var tableName string - query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'myschema' AND table_name = 'products'" - err = ep.GetDB().QueryRowContext(ctx, query).Scan(&tableName) - if err != nil { - t.Fatalf("Failed to query table: %v", err) - } - if tableName != "products" { - t.Errorf("Expected table 'products', got '%s'", tableName) - } -} - -func TestApplySchemaSQL_InvalidSQL(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - defer ep.Stop() - - ctx := context.Background() - - // Test with invalid SQL - invalidSQL := "CREATE TABLE invalid syntax here" - err = ep.ApplySchemaSQL(ctx, "public", invalidSQL) - if err == nil { - t.Error("Expected error for invalid SQL, got none") - } -} - -func TestStop(t *testing.T) { - if testing.Short() { - t.Skip("Skipping embedded postgres test in short mode") - } - - config := &EmbeddedPostgresConfig{ - Version: embeddedpostgres.PostgresVersion("17.5.0"), - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - ep, err := StartEmbeddedPostgres(config) - if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } - - // Test stopping - err = ep.Stop() - if err != nil { - t.Fatalf("Failed to stop embedded postgres: %v", err) - } - - // Verify connection is closed - ctx := context.Background() - err = ep.GetDB().PingContext(ctx) - if err == nil { - t.Error("Expected error pinging stopped postgres, got none") - } -} diff --git a/cmd/util/embedded_postgres_test_helper.go b/cmd/util/embedded_postgres_test_helper.go deleted file mode 100644 index b81b3add..00000000 --- a/cmd/util/embedded_postgres_test_helper.go +++ /dev/null @@ -1,50 +0,0 @@ -package util - -import ( - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -// SetupSharedEmbeddedPostgres creates a shared embedded PostgreSQL instance for test suites. -// This instance can be reused across multiple test cases to significantly improve test performance. -// -// Usage example: -// -// func TestMain(m *testing.M) { -// // Create shared embedded postgres for all tests -// embeddedPG := util.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) -// defer embeddedPG.Stop() -// -// // Run tests -// code := m.Run() -// os.Exit(code) -// } -// -// func TestMyFeature(t *testing.T) { -// config := &plan.PlanConfig{ -// // ... other config ... -// EmbeddedPG: embeddedPG, // Reuse shared instance -// } -// plan, err := plan.GeneratePlan(config) -// // ... -// } -func SetupSharedEmbeddedPostgres(t testing.TB, version embeddedpostgres.PostgresVersion) *EmbeddedPostgres { - config := &EmbeddedPostgresConfig{ - Version: version, - Database: "testdb", - Username: "testuser", - Password: "testpass", - } - - embeddedPG, err := StartEmbeddedPostgres(config) - if err != nil { - if t != nil { - t.Fatalf("Failed to start shared embedded PostgreSQL: %v", err) - } else { - panic("Failed to start shared embedded PostgreSQL: " + err.Error()) - } - } - - return embeddedPG -} diff --git a/cmd/util/postgres_version.go b/cmd/util/postgres_version.go deleted file mode 100644 index cb951ee1..00000000 --- a/cmd/util/postgres_version.go +++ /dev/null @@ -1,84 +0,0 @@ -package util - -import ( - "context" - "database/sql" - "fmt" - "strconv" - "strings" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -// DetectPostgresVersion queries the target database to determine its PostgreSQL version -// and returns the corresponding embedded-postgres version string -func DetectPostgresVersion(db *sql.DB) (embeddedpostgres.PostgresVersion, error) { - ctx := context.Background() - - // Query PostgreSQL version number (e.g., 170005 for 17.5) - var versionNum int - err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) - if err != nil { - return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) - } - - // Extract major version: version_num / 10000 - // e.g., 170005 / 10000 = 17 - majorVersion := versionNum / 10000 - - // Map to embedded-postgres version - return mapToEmbeddedPostgresVersion(majorVersion) -} - -// DetectPostgresVersionFromConfig queries the target database using connection config -// and returns the corresponding embedded-postgres version string -func DetectPostgresVersionFromConfig(config *ConnectionConfig) (embeddedpostgres.PostgresVersion, error) { - // Connect to target database - db, err := Connect(config) - if err != nil { - return "", fmt.Errorf("failed to connect to detect version: %w", err) - } - defer db.Close() - - return DetectPostgresVersion(db) -} - -// mapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version -// Supported versions: 14, 15, 16, 17 -func mapToEmbeddedPostgresVersion(majorVersion int) (embeddedpostgres.PostgresVersion, error) { - switch majorVersion { - case 14: - return embeddedpostgres.PostgresVersion("14.18.0"), nil - case 15: - return embeddedpostgres.PostgresVersion("15.13.0"), nil - case 16: - return embeddedpostgres.PostgresVersion("16.9.0"), nil - case 17: - return embeddedpostgres.PostgresVersion("17.5.0"), nil - default: - return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) - } -} - -// ParseVersionString parses a PostgreSQL version string (e.g., "17.5") and returns major version -func ParseVersionString(versionStr string) (int, error) { - // Handle various formats: "17.5", "17.5.0", "PostgreSQL 17.5", etc. - // Extract the version number part - versionStr = strings.TrimSpace(versionStr) - - // Remove "PostgreSQL " prefix if present - versionStr = strings.TrimPrefix(versionStr, "PostgreSQL ") - - // Split by "." and take the first part (major version) - parts := strings.Split(versionStr, ".") - if len(parts) == 0 { - return 0, fmt.Errorf("invalid version string: %s", versionStr) - } - - majorVersion, err := strconv.Atoi(parts[0]) - if err != nil { - return 0, fmt.Errorf("failed to parse major version from %s: %w", versionStr, err) - } - - return majorVersion, nil -} diff --git a/cmd/util/postgres_version_test.go b/cmd/util/postgres_version_test.go deleted file mode 100644 index a9661786..00000000 --- a/cmd/util/postgres_version_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package util - -import ( - "testing" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" -) - -func TestMapToEmbeddedPostgresVersion(t *testing.T) { - tests := []struct { - name string - majorVersion int - expected embeddedpostgres.PostgresVersion - expectError bool - }{ - { - name: "PostgreSQL 14", - majorVersion: 14, - expected: embeddedpostgres.PostgresVersion("14.18.0"), - expectError: false, - }, - { - name: "PostgreSQL 15", - majorVersion: 15, - expected: embeddedpostgres.PostgresVersion("15.13.0"), - expectError: false, - }, - { - name: "PostgreSQL 16", - majorVersion: 16, - expected: embeddedpostgres.PostgresVersion("16.9.0"), - expectError: false, - }, - { - name: "PostgreSQL 17", - majorVersion: 17, - expected: embeddedpostgres.PostgresVersion("17.5.0"), - expectError: false, - }, - { - name: "Unsupported version 13", - majorVersion: 13, - expected: "", - expectError: true, - }, - { - name: "Unsupported version 18", - majorVersion: 18, - expected: "", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := mapToEmbeddedPostgresVersion(tt.majorVersion) - - if tt.expectError { - if err == nil { - t.Errorf("expected error but got none") - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("expected version %s, got %s", tt.expected, result) - } - } - }) - } -} - -func TestParseVersionString(t *testing.T) { - tests := []struct { - name string - versionStr string - expected int - expectError bool - }{ - { - name: "Simple version 17.5", - versionStr: "17.5", - expected: 17, - expectError: false, - }, - { - name: "Version with patch 17.5.0", - versionStr: "17.5.0", - expected: 17, - expectError: false, - }, - { - name: "Version with prefix", - versionStr: "PostgreSQL 17.5", - expected: 17, - expectError: false, - }, - { - name: "Version 14.18.0", - versionStr: "14.18.0", - expected: 14, - expectError: false, - }, - { - name: "Version with whitespace", - versionStr: " 16.9 ", - expected: 16, - expectError: false, - }, - { - name: "Invalid version", - versionStr: "invalid", - expected: 0, - expectError: true, - }, - { - name: "Empty string", - versionStr: "", - expected: 0, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ParseVersionString(tt.versionStr) - - if tt.expectError { - if err == nil { - t.Errorf("expected error but got none") - } - } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("expected version %d, got %d", tt.expected, result) - } - } - }) - } -} diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go index 4f6aedf5..5e89b437 100644 --- a/internal/diff/diff_test.go +++ b/internal/diff/diff_test.go @@ -8,14 +8,18 @@ import ( "testing" "github.com/pgschema/pgschema/ir" + "github.com/pgschema/pgschema/testutil" ) +// sharedTestPostgres is the shared embedded postgres instance for all tests in this package +var sharedTestPostgres *testutil.TestPostgres + // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance ctx := context.Background() - container := ir.SetupSharedTestContainer(ctx, nil) - defer container.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) + defer sharedTestPostgres.Terminate(ctx, nil) // Run tests code := m.Run() @@ -53,7 +57,20 @@ func buildSQLFromSteps(diffs []Diff) string { // parseSQL is a helper function to convert SQL string to IR for tests // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { - return ir.ParseSQLForTest(t, sql, "public") + t.Helper() + + // Use testutil to apply SQL to embedded postgres + conn := testutil.ParseSQLForTest(t, sharedTestPostgres, sql, "public") + + // Inspect the database to get IR + ctx := context.Background() + inspector := ir.NewInspector(conn, nil) + irResult, err := inspector.BuildIR(ctx, "public") + if err != nil { + t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) + } + + return irResult } // TestDiffFromFiles runs file-based diff tests from testdata directory. diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go index ec7cacb7..a21a38d9 100644 --- a/internal/plan/plan_test.go +++ b/internal/plan/plan_test.go @@ -14,14 +14,18 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pgschema/pgschema/internal/diff" "github.com/pgschema/pgschema/ir" + "github.com/pgschema/pgschema/testutil" ) +// sharedTestPostgres is the shared embedded postgres instance for all tests in this package +var sharedTestPostgres *testutil.TestPostgres + // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance ctx := context.Background() - container := ir.SetupSharedTestContainer(ctx, nil) - defer container.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) + defer sharedTestPostgres.Terminate(ctx, nil) // Run tests code := m.Run() @@ -54,7 +58,20 @@ func discoverTestDataVersions(testdataDir string) ([]string, error) { // parseSQL is a helper function to convert SQL string to IR for tests // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { - return ir.ParseSQLForTest(t, sql, "public") + t.Helper() + + // Use testutil to apply SQL to embedded postgres + conn := testutil.ParseSQLForTest(t, sharedTestPostgres, sql, "public") + + // Inspect the database to get IR + ctx := context.Background() + inspector := ir.NewInspector(conn, nil) + irResult, err := inspector.BuildIR(ctx, "public") + if err != nil { + t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) + } + + return irResult } func TestPlanSummary(t *testing.T) { @@ -178,62 +195,6 @@ func TestPlanJSONRoundTrip(t *testing.T) { } } -func TestPlanToJSON(t *testing.T) { - oldSQL := `CREATE TABLE users ( - id integer NOT NULL - );` - - newSQL := `CREATE TABLE users ( - id integer NOT NULL, - name text NOT NULL - );` - - oldIR := parseSQL(t, oldSQL) - newIR := parseSQL(t, newSQL) - diffs := diff.GenerateMigration(oldIR, newIR, "public") - - plan := NewPlan(diffs) - - // Test non-debug version (default behavior) - should NOT contain source field - jsonOutput, err := plan.ToJSON() - if err != nil { - t.Fatalf("Failed to generate JSON: %v", err) - } - - if !strings.Contains(jsonOutput, `"groups"`) { - t.Error("JSON output should contain groups") - } - - // Non-debug version should NOT contain source field - if strings.Contains(jsonOutput, `"source"`) { - t.Error("JSON output should NOT contain source field when debug is disabled") - } - - // Test debug version - should contain source field - jsonDebugOutput, err := plan.ToJSONWithDebug(true) - if err != nil { - t.Fatalf("Failed to generate debug JSON: %v", err) - } - - if !strings.Contains(jsonDebugOutput, `"groups"`) { - t.Error("Debug JSON output should contain groups") - } - - // Debug version should still work (but Steps don't have source field anymore) - // This is expected behavior after refactoring to Step structure - if jsonDebugOutput == "" { - t.Error("Debug JSON output should not be empty") - } - - if !strings.Contains(jsonOutput, `"version"`) { - t.Error("JSON output should contain version") - } - - if !strings.Contains(jsonOutput, `"created_at"`) { - t.Error("JSON output should contain created_at timestamp") - } -} - func TestPlanNoChanges(t *testing.T) { sql := `CREATE TABLE users ( id integer NOT NULL diff --git a/ir/testutil.go b/ir/testutil.go deleted file mode 100644 index ee9c0fe0..00000000 --- a/ir/testutil.go +++ /dev/null @@ -1,281 +0,0 @@ -// Package ir provides an intermediate representation for PostgreSQL schemas -package ir - -import ( - "context" - "database/sql" - "fmt" - "io" - "net" - "os" - "path/filepath" - "strings" - "testing" - "time" - - embeddedpostgres "github.com/fergusstrange/embedded-postgres" - _ "github.com/jackc/pgx/v5/stdlib" -) - -// getPostgresVersion returns the PostgreSQL version to use for testing. -// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, -// defaulting to "17" if not set. -// Returns an error if an unsupported version is specified. -func getPostgresVersion() (embeddedpostgres.PostgresVersion, error) { - versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") - if versionStr == "" { - return embeddedpostgres.PostgresVersion("17.5.0"), nil - } - - switch versionStr { - case "14": - return embeddedpostgres.PostgresVersion("14.18.0"), nil - case "15": - return embeddedpostgres.PostgresVersion("15.13.0"), nil - case "16": - return embeddedpostgres.PostgresVersion("16.9.0"), nil - case "17": - return embeddedpostgres.PostgresVersion("17.5.0"), nil - default: - return "", fmt.Errorf("unsupported PGSCHEMA_POSTGRES_VERSION: %s (supported versions: 14, 15, 16, 17)", versionStr) - } -} - -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, err - } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} - -// ContainerInfo holds PostgreSQL instance connection details for testing -type ContainerInfo struct { - Database *embeddedpostgres.EmbeddedPostgres - Host string - Port int - DSN string - Conn *sql.DB - RuntimePath string -} - -// setupPostgresContainer creates a new PostgreSQL test container -func setupPostgresContainer(ctx context.Context, t *testing.T) *ContainerInfo { - return setupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") -} - -// fatalError handles test failures - uses t.Fatalf if available, otherwise panics -func fatalError(t *testing.T, format string, args ...interface{}) { - if t != nil { - t.Fatalf(format, args...) - } else { - panic(fmt.Sprintf(format, args...)) - } -} - -// setupPostgresContainerWithDB creates a new PostgreSQL instance with custom database settings -func setupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, username, password string) *ContainerInfo { - // Extract test name and create unique runtime path - testName := "shared" - if t != nil { - testName = strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names - } - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-test-%s-%s", testName, timestamp)) - - // Find an available port - port, err := findAvailablePort() - if err != nil { - fatalError(t, "Failed to find available port: %v", err) - } - - // Get PostgreSQL version - pgVersion, err := getPostgresVersion() - if err != nil { - fatalError(t, "Failed to get PostgreSQL version: %v", err) - } - - // Configure embedded postgres with unique runtime path and dynamic port - config := embeddedpostgres.DefaultConfig(). - Version(pgVersion). - Database(database). - Username(username). - Password(password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard above) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - postgres := embeddedpostgres.NewDatabase(config) - err = postgres.Start() - if err != nil { - fatalError(t, "Failed to start embedded postgres: %v", err) - } - - // Build connection string - host := "localhost" - testDSN := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", - username, password, host, port, database) - - // Connect to database - conn, err := sql.Open("pgx", testDSN) - if err != nil { - postgres.Stop() - fatalError(t, "Failed to connect to database: %v", err) - } - - // Test the connection - if err := conn.PingContext(ctx); err != nil { - conn.Close() - postgres.Stop() - fatalError(t, "Failed to ping database: %v", err) - } - - return &ContainerInfo{ - Database: postgres, - Host: host, - Port: port, - DSN: testDSN, - Conn: conn, - RuntimePath: runtimePath, - } -} - -// terminate cleans up the database instance and connection -func (ci *ContainerInfo) terminate(ctx context.Context, t *testing.T) { - ci.Conn.Close() - if err := ci.Database.Stop(); err != nil { - t.Logf("Failed to stop embedded postgres: %v", err) - } - // Clean up the runtime directory - if ci.RuntimePath != "" { - if err := os.RemoveAll(ci.RuntimePath); err != nil { - t.Logf("Failed to clean up runtime directory: %v", err) - } - } -} - -// sharedTestContainer holds an optional shared embedded postgres instance for tests -var sharedTestContainer *ContainerInfo - -// SetSharedTestContainer sets a shared embedded postgres instance for ParseSQLForTest to reuse. -// This significantly improves test performance by avoiding repeated postgres startup/shutdown. -// -// Usage in test packages: -// -// func TestMain(m *testing.M) { -// ctx := context.Background() -// container := ir.SetupSharedTestContainer(ctx, nil) -// defer container.Terminate(ctx, nil) -// -// code := m.Run() -// os.Exit(code) -// } -func SetupSharedTestContainer(ctx context.Context, t testing.TB) *ContainerInfo { - // Convert testing.TB to *testing.T if needed - var tt *testing.T - if t != nil { - if tPtr, ok := t.(*testing.T); ok { - tt = tPtr - } else { - // For testing.TB that's not *testing.T (like *testing.M), create a dummy *testing.T - // This is safe because setupPostgresContainer only uses t for Fatalf on errors - panic("SetupSharedTestContainer requires *testing.T or nil") - } - } - container := setupPostgresContainer(ctx, tt) - sharedTestContainer = container - return container -} - -// Terminate cleans up the container (exported for use by test packages) -func (ci *ContainerInfo) Terminate(ctx context.Context, t testing.TB) { - // Convert testing.TB to *testing.T if needed - var tt *testing.T - if t != nil { - if tPtr, ok := t.(*testing.T); ok { - tt = tPtr - } - // For nil or other types, tt remains nil which is fine for terminate - } - ci.terminate(ctx, tt) -} - -// ParseSQLForTest is a test helper that converts SQL to IR using embedded PostgreSQL. -// This replaces the old parser-based approach for tests. -// -// If a shared test container has been set via SetupSharedTestContainer, it will be reused -// (with the schema reset between calls). Otherwise, a new temporary instance is created. -// -// This ensures tests use the same code path as production (database inspection) rather than parsing. -func ParseSQLForTest(t *testing.T, sqlContent string, schema string) *IR { - t.Helper() - - ctx := context.Background() - - var conn *sql.DB - var needsCleanup bool - - if sharedTestContainer != nil { - // Reuse shared container - reset the schema for clean state - conn = sharedTestContainer.Conn - needsCleanup = false - - // Drop and recreate schema - dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) - if _, err := conn.ExecContext(ctx, dropSchema); err != nil { - t.Fatalf("Failed to drop schema: %v", err) - } - createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) - if _, err := conn.ExecContext(ctx, createSchema); err != nil { - t.Fatalf("Failed to create schema: %v", err) - } - } else { - // Create new container for this test - container := setupPostgresContainer(ctx, t) - defer container.terminate(ctx, t) - conn = container.Conn - needsCleanup = true - - // Create schema if not public - if schema != "public" { - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS \"%s\"", schema) - if _, err := conn.ExecContext(ctx, createSchemaSQL); err != nil { - t.Fatalf("Failed to create schema: %v", err) - } - } - } - - // Set search_path to target schema - setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) - if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { - t.Fatalf("Failed to set search_path: %v", err) - } - - // Execute the SQL - if _, err := conn.ExecContext(ctx, sqlContent); err != nil { - t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) - } - - // Inspect the database to get IR - inspector := NewInspector(conn, nil) - ir, err := inspector.BuildIR(ctx, schema) - if err != nil { - t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) - } - - // If we created a container just for this test, cleanup happens via defer above - _ = needsCleanup - - return ir -} \ No newline at end of file diff --git a/testutil/postgres.go b/testutil/postgres.go index 7dd0724f..87948384 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -9,6 +9,7 @@ import ( "net" "os" "path/filepath" + "strconv" "strings" "testing" "time" @@ -17,6 +18,10 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" ) +// PostgresVersion is an alias for the embedded-postgres version type +// This allows test files to reference the version type without directly importing embedded-postgres +type PostgresVersion = embeddedpostgres.PostgresVersion + // getPostgresVersion returns the PostgreSQL version to use for testing. // It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, // defaulting to "17" if not set. @@ -46,8 +51,8 @@ func findAvailablePort() (int, error) { return listener.Addr().(*net.TCPAddr).Port, nil } -// ContainerInfo holds PostgreSQL instance connection details -type ContainerInfo struct { +// TestPostgres holds PostgreSQL instance connection details for testing +type TestPostgres struct { Database *embeddedpostgres.EmbeddedPostgres Host string Port int @@ -56,22 +61,29 @@ type ContainerInfo struct { RuntimePath string } -// SetupPostgresContainer creates a new PostgreSQL test container -func SetupPostgresContainer(ctx context.Context, t *testing.T) *ContainerInfo { - return SetupPostgresContainerWithDB(ctx, t, "testdb", "testuser", "testpass") -} +// SetupTestPostgres creates a new PostgreSQL test instance with standard credentials +func SetupTestPostgres(ctx context.Context, t *testing.T) *TestPostgres { + // Standard test database credentials + database := "testdb" + username := "testuser" + password := "testpass" -// SetupPostgresContainerWithDB creates a new PostgreSQL instance with custom database settings -func SetupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, username, password string) *ContainerInfo { // Extract test name and create unique runtime path - testName := strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names + testName := "shared" + if t != nil { + testName = strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names + } timestamp := time.Now().Format("20060102_150405.000000000") runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-test-%s-%s", testName, timestamp)) // Find an available port port, err := findAvailablePort() if err != nil { - t.Fatalf("Failed to find available port: %v", err) + if t != nil { + t.Fatalf("Failed to find available port: %v", err) + } else { + panic(fmt.Sprintf("Failed to find available port: %v", err)) + } } // Configure embedded postgres with unique runtime path and dynamic port @@ -96,7 +108,11 @@ func SetupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, u postgres := embeddedpostgres.NewDatabase(config) err = postgres.Start() if err != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) + if t != nil { + t.Fatalf("Failed to start embedded postgres: %v", err) + } else { + panic(fmt.Sprintf("Failed to start embedded postgres: %v", err)) + } } // Build connection string @@ -108,17 +124,25 @@ func SetupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, u conn, err := sql.Open("pgx", testDSN) if err != nil { postgres.Stop() - t.Fatalf("Failed to connect to database: %v", err) + if t != nil { + t.Fatalf("Failed to connect to database: %v", err) + } else { + panic(fmt.Sprintf("Failed to connect to database: %v", err)) + } } // Test the connection if err := conn.PingContext(ctx); err != nil { conn.Close() postgres.Stop() - t.Fatalf("Failed to ping database: %v", err) + if t != nil { + t.Fatalf("Failed to ping database: %v", err) + } else { + panic(fmt.Sprintf("Failed to ping database: %v", err)) + } } - return &ContainerInfo{ + return &TestPostgres{ Database: postgres, Host: host, Port: port, @@ -129,15 +153,21 @@ func SetupPostgresContainerWithDB(ctx context.Context, t *testing.T, database, u } // Terminate cleans up the database instance and connection -func (ci *ContainerInfo) Terminate(ctx context.Context, t *testing.T) { - ci.Conn.Close() - if err := ci.Database.Stop(); err != nil { - t.Logf("Failed to stop embedded postgres: %v", err) +func (tp *TestPostgres) Terminate(ctx context.Context, t *testing.T) { + tp.Conn.Close() + if err := tp.Database.Stop(); err != nil { + if t != nil { + t.Logf("Failed to stop embedded postgres: %v", err) + } + // Silently ignore errors if t is nil (called from TestMain cleanup) } // Clean up the runtime directory - if ci.RuntimePath != "" { - if err := os.RemoveAll(ci.RuntimePath); err != nil { - t.Logf("Failed to clean up runtime directory: %v", err) + if tp.RuntimePath != "" { + if err := os.RemoveAll(tp.RuntimePath); err != nil { + if t != nil { + t.Logf("Failed to clean up runtime directory: %v", err) + } + // Silently ignore errors if t is nil } } } @@ -155,3 +185,382 @@ type TestConnectionConfig struct { User string Schema string } + +// ============================================================================ +// Version Detection and Mapping +// ============================================================================ + +// MapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version +// Supported versions: 14, 15, 16, 17 +func MapToEmbeddedPostgresVersion(majorVersion int) (embeddedpostgres.PostgresVersion, error) { + switch majorVersion { + case 14: + return embeddedpostgres.PostgresVersion("14.18.0"), nil + case 15: + return embeddedpostgres.PostgresVersion("15.13.0"), nil + case 16: + return embeddedpostgres.PostgresVersion("16.9.0"), nil + case 17: + return embeddedpostgres.PostgresVersion("17.5.0"), nil + default: + return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) + } +} + +// DetectPostgresVersion queries the target database to determine its PostgreSQL version +// and returns the corresponding embedded-postgres version string +func DetectPostgresVersion(db *sql.DB) (embeddedpostgres.PostgresVersion, error) { + ctx := context.Background() + + // Query PostgreSQL version number (e.g., 170005 for 17.5) + var versionNum int + err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) + if err != nil { + return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) + } + + // Extract major version: version_num / 10000 + // e.g., 170005 / 10000 = 17 + majorVersion := versionNum / 10000 + + // Map to embedded-postgres version + return MapToEmbeddedPostgresVersion(majorVersion) +} + +// ParseVersionString parses a PostgreSQL version string (e.g., "17.5") and returns major version +func ParseVersionString(versionStr string) (int, error) { + // Handle various formats: "17.5", "17.5.0", "PostgreSQL 17.5", etc. + // Extract the version number part + versionStr = strings.TrimSpace(versionStr) + + // Remove "PostgreSQL " prefix if present + versionStr = strings.TrimPrefix(versionStr, "PostgreSQL ") + + // Split by "." and take the first part (major version) + parts := strings.Split(versionStr, ".") + if len(parts) == 0 { + return 0, fmt.Errorf("invalid version string: %s", versionStr) + } + + majorVersion, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, fmt.Errorf("failed to parse major version from %s: %w", versionStr, err) + } + + return majorVersion, nil +} + +// ============================================================================ +// Production EmbeddedPostgres Wrapper +// ============================================================================ + +// EmbeddedPostgres manages a temporary embedded PostgreSQL instance +// This is used both for testing and for the plan command in production +type EmbeddedPostgres struct { + instance *embeddedpostgres.EmbeddedPostgres + db *sql.DB + version embeddedpostgres.PostgresVersion + host string + port int + database string + username string + password string + runtimePath string +} + +// EmbeddedPostgresConfig holds configuration for starting embedded PostgreSQL +type EmbeddedPostgresConfig struct { + Version embeddedpostgres.PostgresVersion + Database string + Username string + Password string +} + +// StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance +func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { + // Create unique runtime path with timestamp (using nanoseconds for uniqueness) + timestamp := time.Now().Format("20060102_150405.000000000") + runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-plan-%s", timestamp)) + + // Find an available port + port, err := findAvailablePort() + if err != nil { + return nil, fmt.Errorf("failed to find available port: %w", err) + } + + // Configure embedded postgres + pgConfig := embeddedpostgres.DefaultConfig(). + Version(config.Version). + Database(config.Database). + Username(config.Username). + Password(config.Password). + Port(uint32(port)). + RuntimePath(runtimePath). + DataPath(filepath.Join(runtimePath, "data")). + Logger(io.Discard). // Suppress embedded-postgres startup logs + StartParameters(map[string]string{ + "logging_collector": "off", // Disable log collector + "log_destination": "stderr", // Send logs to stderr (which we discard) + "log_min_messages": "PANIC", // Only log PANIC level messages + "log_statement": "none", // Don't log SQL statements + "log_min_duration_statement": "-1", // Don't log slow queries + }) + + // Create and start PostgreSQL instance + instance := embeddedpostgres.NewDatabase(pgConfig) + if err := instance.Start(); err != nil { + return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) + } + + // Build connection string + host := "localhost" + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + config.Username, config.Password, host, port, config.Database) + + // Connect to database + db, err := sql.Open("pgx", dsn) + if err != nil { + instance.Stop() + os.RemoveAll(runtimePath) + return nil, fmt.Errorf("failed to connect to embedded PostgreSQL: %w", err) + } + + // Test the connection + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + db.Close() + instance.Stop() + os.RemoveAll(runtimePath) + return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err) + } + + return &EmbeddedPostgres{ + instance: instance, + db: db, + version: config.Version, + host: host, + port: port, + database: config.Database, + username: config.Username, + password: config.Password, + runtimePath: runtimePath, + }, nil +} + +// GetDB returns the database connection +func (ep *EmbeddedPostgres) GetDB() *sql.DB { + return ep.db +} + +// GetConnectionInfo returns connection details +func (ep *EmbeddedPostgres) GetConnectionInfo() (host string, port int, database string) { + return ep.host, ep.port, ep.database +} + +// GetCredentials returns the username and password for the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) GetCredentials() (username string, password string) { + return ep.username, ep.password +} + +// ResetSchema drops and recreates a schema, clearing all objects +// This is useful for tests that want to reuse the same embedded postgres instance +func (ep *EmbeddedPostgres) ResetSchema(ctx context.Context, schema string) error { + // Drop the schema if it exists (CASCADE to drop all objects) + dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { + return fmt.Errorf("failed to drop schema %s: %w", schema, err) + } + + // Recreate the schema + createSchemaSQL := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) + if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { + return fmt.Errorf("failed to create schema %s: %w", schema, err) + } + + return nil +} + +// ApplySchemaSQL applies SQL schema to the embedded PostgreSQL database +func (ep *EmbeddedPostgres) ApplySchemaSQL(ctx context.Context, schema string, sql string) error { + // Create the schema if it doesn't exist + createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS \"%s\"", schema) + if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { + return fmt.Errorf("failed to create schema %s: %w", schema, err) + } + + // Set search_path to the target schema + setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) + if _, err := ep.db.ExecContext(ctx, setSearchPathSQL); err != nil { + return fmt.Errorf("failed to set search_path: %w", err) + } + + // Execute the SQL directly + // Note: Desired state SQL should never contain operations like CREATE INDEX CONCURRENTLY + // that cannot run in transactions. Those are migration details, not state declarations. + if _, err := ep.db.ExecContext(ctx, sql); err != nil { + return fmt.Errorf("failed to apply schema SQL: %w", err) + } + + return nil +} + +// Stop stops and cleans up the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) Stop() error { + // Close database connection + if ep.db != nil { + ep.db.Close() + } + + // Stop PostgreSQL instance + var stopErr error + if ep.instance != nil { + stopErr = ep.instance.Stop() + } + + // Clean up runtime directory + if ep.runtimePath != "" { + if err := os.RemoveAll(ep.runtimePath); err != nil { + // Don't return error here - just ignore cleanup failures + // This can happen on Windows when files are still in use + } + } + + if stopErr != nil { + return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) + } + + return nil +} + +// SetupSharedEmbeddedPostgres creates a shared embedded PostgreSQL instance for test suites. +// This instance can be reused across multiple test cases to significantly improve test performance. +// +// Usage example: +// +// func TestMain(m *testing.M) { +// // Create shared embedded postgres for all tests +// embeddedPG := testutil.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) +// defer embeddedPG.Stop() +// +// // Run tests +// code := m.Run() +// os.Exit(code) +// } +func SetupSharedEmbeddedPostgres(t testing.TB, version embeddedpostgres.PostgresVersion) *EmbeddedPostgres { + config := &EmbeddedPostgresConfig{ + Version: version, + Database: "testdb", + Username: "testuser", + Password: "testpass", + } + + embeddedPG, err := StartEmbeddedPostgres(config) + if err != nil { + if t != nil { + t.Fatalf("Failed to start shared embedded PostgreSQL: %v", err) + } else { + panic("Failed to start shared embedded PostgreSQL: " + err.Error()) + } + } + + return embeddedPG +} + +// DetectPostgresVersionFromDB connects to a database and detects its version +// This is a convenience function that opens a connection, detects the version, and closes it +func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (embeddedpostgres.PostgresVersion, error) { + // Build connection string + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer", + user, password, host, port, database) + + // Connect to database + db, err := sql.Open("pgx", dsn) + if err != nil { + return "", fmt.Errorf("failed to connect to database: %w", err) + } + defer db.Close() + + // Test the connection + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + return "", fmt.Errorf("failed to ping database: %w", err) + } + + // Detect version + return DetectPostgresVersion(db) +} + +// ============================================================================ +// Shared Test Postgres for IR Tests +// ============================================================================ + +// SetupSharedTestPostgres creates a shared embedded postgres instance for test packages. +// This significantly improves test performance by avoiding repeated postgres startup/shutdown. +// The returned instance should be stored by the caller and passed to ParseSQLForTest. +// +// Usage in test packages: +// +// var sharedTestPostgres *testutil.TestPostgres +// +// func TestMain(m *testing.M) { +// ctx := context.Background() +// sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) +// defer sharedTestPostgres.Terminate(ctx, nil) +// +// code := m.Run() +// os.Exit(code) +// } +func SetupSharedTestPostgres(ctx context.Context, t testing.TB) *TestPostgres { + // Convert testing.TB to *testing.T if needed + var tt *testing.T + if t != nil { + if tPtr, ok := t.(*testing.T); ok { + tt = tPtr + } else { + // For testing.TB that's not *testing.T (like *testing.M), create a dummy *testing.T + // This is safe because SetupTestPostgres only uses t for Fatalf on errors + panic("SetupSharedTestPostgres requires *testing.T or nil") + } + } + return SetupTestPostgres(ctx, tt) +} + +// ParseSQLForTest is a test helper that converts SQL to an inspectable database state +// using embedded PostgreSQL. This replaces the old parser-based approach for tests. +// +// The caller must provide a TestPostgres instance (typically from SetupSharedTestPostgres). +// The schema will be reset (dropped and recreated) to ensure clean state between test calls. +// +// This function returns the database connection that can be inspected. The caller should NOT +// close this connection as it belongs to the test postgres instance. +// +// This ensures tests use the same code path as production (database inspection) rather than parsing. +func ParseSQLForTest(t *testing.T, testPG *TestPostgres, sqlContent string, schema string) *sql.DB { + t.Helper() + + ctx := context.Background() + conn := testPG.Conn + + // Drop and recreate schema for clean state + dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := conn.ExecContext(ctx, dropSchema); err != nil { + t.Fatalf("Failed to drop schema: %v", err) + } + createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) + if _, err := conn.ExecContext(ctx, createSchema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + // Set search_path to target schema + setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) + if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { + t.Fatalf("Failed to set search_path: %v", err) + } + + // Execute the SQL + if _, err := conn.ExecContext(ctx, sqlContent); err != nil { + t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) + } + + return conn +} From 29b7b5e6e79dba8156e7541fcc6114d3131c3987 Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Fri, 24 Oct 2025 01:02:01 +0800 Subject: [PATCH 2/6] chore: make dump use util.GetIRFromDatabase --- cmd/dump/dump.go | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index fb237525..76476044 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -1,7 +1,6 @@ package dump import ( - "context" "fmt" "os" @@ -60,36 +59,16 @@ func runDump(cmd *cobra.Command, args []string) error { } } - // Build database connection - config := &util.ConnectionConfig{ - Host: host, - Port: port, - Database: db, - User: user, - Password: finalPassword, - SSLMode: "prefer", - ApplicationName: "pgschema", - } - - dbConn, err := util.Connect(config) - if err != nil { - return err - } - defer dbConn.Close() - - ctx := context.Background() - // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { return fmt.Errorf("failed to load .pgschemaignore: %w", err) } - // Build IR using the IR system - inspector := ir.NewInspector(dbConn, ignoreConfig) - schemaIR, err := inspector.BuildIR(ctx, schema) + // Get IR from database using the shared utility + schemaIR, err := util.GetIRFromDatabase(host, port, db, user, finalPassword, schema, "pgschema", ignoreConfig) if err != nil { - return fmt.Errorf("failed to build IR: %w", err) + return fmt.Errorf("failed to get database schema: %w", err) } // Create an empty schema for comparison to generate a dump diff From 65362122aaaa3d84dd54429586f9bfae45a5c30f Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Fri, 24 Oct 2025 01:13:40 +0800 Subject: [PATCH 3/6] refactor: dump test conflg --- cmd/dump/dump.go | 79 ++++++++--- cmd/dump/dump_integration_test.go | 134 +++++-------------- cmd/dump/dump_permission_integration_test.go | 70 ++-------- testutil/postgres.go | 9 -- 4 files changed, 104 insertions(+), 188 deletions(-) diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index 76476044..e1d2ca72 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -22,6 +22,18 @@ var ( file string ) +// DumpConfig holds configuration for dump execution +type DumpConfig struct { + Host string + Port int + DB string + User string + Password string + Schema string + MultiFile bool + File string +} + var DumpCmd = &cobra.Command{ Use: "dump", @@ -43,52 +55,79 @@ func init() { DumpCmd.Flags().StringVar(&file, "file", "", "Output file path (required when --multi-file is used)") } -func runDump(cmd *cobra.Command, args []string) error { +// ExecuteDump executes the dump operation with the given configuration +func ExecuteDump(config *DumpConfig) (string, error) { // Validate flags - if multiFile && file == "" { + if config.MultiFile && config.File == "" { // When --multi-file is used but no --file specified, emit warning and use single-file mode fmt.Fprintf(os.Stderr, "Warning: --multi-file flag requires --file to be specified. Fallback to single-file mode.\n") - multiFile = false - } - - // Derive final password: use flag if provided, otherwise check environment variable - finalPassword := password - if finalPassword == "" { - if envPassword := os.Getenv("PGPASSWORD"); envPassword != "" { - finalPassword = envPassword - } + config.MultiFile = false } // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { - return fmt.Errorf("failed to load .pgschemaignore: %w", err) + return "", fmt.Errorf("failed to load .pgschemaignore: %w", err) } // Get IR from database using the shared utility - schemaIR, err := util.GetIRFromDatabase(host, port, db, user, finalPassword, schema, "pgschema", ignoreConfig) + schemaIR, err := util.GetIRFromDatabase(config.Host, config.Port, config.DB, config.User, config.Password, config.Schema, "pgschema", ignoreConfig) if err != nil { - return fmt.Errorf("failed to get database schema: %w", err) + return "", fmt.Errorf("failed to get database schema: %w", err) } // Create an empty schema for comparison to generate a dump diff emptyIR := ir.NewIR() // Generate diff between empty schema and target schema (this represents a complete dump) - diffs := diff.GenerateMigration(emptyIR, schemaIR, schema) + diffs := diff.GenerateMigration(emptyIR, schemaIR, config.Schema) // Create dump formatter - formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, schema) + formatter := dump.NewDumpFormatter(schemaIR.Metadata.DatabaseVersion, config.Schema) - if multiFile { + if config.MultiFile { // Multi-file mode - output to files - err := formatter.FormatMultiFile(diffs, file) + err := formatter.FormatMultiFile(diffs, config.File) if err != nil { - return fmt.Errorf("failed to create multi-file output: %w", err) + return "", fmt.Errorf("failed to create multi-file output: %w", err) } + return "", nil } else { - // Single file mode - output to stdout + // Single file mode - return output as string output := formatter.FormatSingleFile(diffs) + return output, nil + } +} + +func runDump(cmd *cobra.Command, args []string) error { + // Derive final password: use flag if provided, otherwise check environment variable + finalPassword := password + if finalPassword == "" { + if envPassword := os.Getenv("PGPASSWORD"); envPassword != "" { + finalPassword = envPassword + } + } + + // Create config from command-line flags + config := &DumpConfig{ + Host: host, + Port: port, + DB: db, + User: user, + Password: finalPassword, + Schema: schema, + MultiFile: multiFile, + File: file, + } + + // Execute dump + output, err := ExecuteDump(config) + if err != nil { + return err + } + + // Print output to stdout (only in single-file mode) + if output != "" { fmt.Print(output) } diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index 4de5406d..dd8fc4fa 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -10,7 +10,6 @@ package dump import ( "context" "fmt" - "io" "os" "strings" "testing" @@ -97,31 +96,23 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir t.Fatalf("Failed to execute pgdump.sql: %v", err) } - // Store original connection parameters and restore them later - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, + // Create dump configuration + config := &DumpConfig{ + Host: containerInfo.Host, + Port: containerInfo.Port, + DB: "testdb", + User: "testuser", + Password: "testpass", + Schema: "public", + MultiFile: false, + File: "", + } + + // Execute pgschema dump + actualOutput, err := ExecuteDump(config) + if err != nil { + t.Fatalf("Dump command failed: %v", err) } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - - // Configure connection parameters - host = containerInfo.Host - port = containerInfo.Port - db = "testdb" - user = "testuser" - testutil.SetEnvPassword("testpass") - - // Execute pgschema dump and capture output - actualOutput := executePgSchemaDump(t, "") // Read expected output expectedPath := fmt.Sprintf("../../testdata/dump/%s/pgschema.sql", testDataDir) @@ -195,35 +186,26 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { } } - // Save original command variables - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, - } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - // Dump both tenant schemas using pgschema dump command var dumps []string for _, tenantName := range tenants { - // Set connection parameters for this specific tenant dump - host = containerInfo.Host - port = containerInfo.Port - db = "testdb" - user = "testuser" - testutil.SetEnvPassword("testpass") - schema = tenantName - - // Execute pgschema dump and capture output - actualOutput := executePgSchemaDump(t, fmt.Sprintf("tenant %s", tenantName)) + // Create dump configuration for this tenant + config := &DumpConfig{ + Host: containerInfo.Host, + Port: containerInfo.Port, + DB: "testdb", + User: "testuser", + Password: "testpass", + Schema: tenantName, + MultiFile: false, + File: "", + } + + // Execute pgschema dump + actualOutput, err := ExecuteDump(config) + if err != nil { + t.Fatalf("Dump command failed for tenant %s: %v", tenantName, err) + } dumps = append(dumps, actualOutput) } @@ -248,56 +230,6 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { } } -func executePgSchemaDump(t *testing.T, contextInfo string) string { - // Capture output by redirecting stdout - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Read from pipe in a goroutine to avoid deadlock - var actualOutput string - var readErr error - done := make(chan bool) - - go func() { - defer close(done) - output, err := io.ReadAll(r) - if err != nil { - readErr = err - return - } - actualOutput = string(output) - }() - - // Run the dump command - // Logger setup handled by root command - err := runDump(nil, nil) - - // Close write end and restore stdout - w.Close() - os.Stdout = originalStdout - - if err != nil { - if contextInfo != "" { - t.Fatalf("Dump command failed for %s: %v", contextInfo, err) - } else { - t.Fatalf("Dump command failed: %v", err) - } - } - - // Wait for reading to complete - <-done - if readErr != nil { - if contextInfo != "" { - t.Fatalf("Failed to read captured output for %s: %v", contextInfo, readErr) - } else { - t.Fatalf("Failed to read captured output: %v", readErr) - } - } - - return actualOutput -} - // normalizeSchemaOutput removes version-specific lines for comparison func normalizeSchemaOutput(output string) string { lines := strings.Split(output, "\n") diff --git a/cmd/dump/dump_permission_integration_test.go b/cmd/dump/dump_permission_integration_test.go index 51cfaa0e..e16b9aa3 100644 --- a/cmd/dump/dump_permission_integration_test.go +++ b/cmd/dump/dump_permission_integration_test.go @@ -12,7 +12,6 @@ import ( "context" "database/sql" "fmt" - "io" "os" "strings" "testing" @@ -434,63 +433,18 @@ func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, con // This helper is used specifically for permission testing where we need to run // the dump command with restricted database user credentials. func executeDumpCommandAsUser(hostArg string, portArg int, database, userArg, password, schemaArg string) (string, error) { - // Store original connection parameters and restore them later - originalConfig := testutil.TestConnectionConfig{ - Host: host, - Port: port, - DB: db, - User: user, - Schema: schema, - } - defer func() { - host = originalConfig.Host - port = originalConfig.Port - db = originalConfig.DB - user = originalConfig.User - schema = originalConfig.Schema - }() - - // Set connection parameters for this specific dump - host = hostArg - port = portArg - db = database - user = userArg - schema = schemaArg - testutil.SetEnvPassword(password) - - // Capture output by redirecting stdout - originalStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Read from pipe in a goroutine to avoid deadlock - var actualOutput string - var readErr error - done := make(chan bool) - - go func() { - defer close(done) - output, err := io.ReadAll(r) - if err != nil { - readErr = err - return - } - actualOutput = string(output) - }() - - // Run the dump command - // Logger setup handled by root command - err := runDump(nil, nil) - - // Close write end and restore stdout - w.Close() - os.Stdout = originalStdout - - // Wait for reading to complete - <-done - if readErr != nil { - return "", fmt.Errorf("failed to read captured output: %w", readErr) + // Create dump configuration + config := &DumpConfig{ + Host: hostArg, + Port: portArg, + DB: database, + User: userArg, + Password: password, + Schema: schemaArg, + MultiFile: false, + File: "", } - return actualOutput, err + // Execute dump + return ExecuteDump(config) } diff --git a/testutil/postgres.go b/testutil/postgres.go index 87948384..d345185e 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -177,15 +177,6 @@ func SetEnvPassword(password string) { os.Setenv("PGPASSWORD", password) } -// TestConnectionConfig stores connection settings for save/restore operations -type TestConnectionConfig struct { - Host string - Port int - DB string - User string - Schema string -} - // ============================================================================ // Version Detection and Mapping // ============================================================================ From ed5134985f1f3e6298144cdf98e7dfedd702fdc5 Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Fri, 24 Oct 2025 01:31:04 +0800 Subject: [PATCH 4/6] refactor: consolidate test conn info --- cmd/apply/apply_integration_test.go | 60 ++++++++++---------- cmd/dump/dump_integration_test.go | 12 ++-- cmd/dump/dump_permission_integration_test.go | 4 +- cmd/ignore_integration_test.go | 18 +++--- cmd/include_integration_test.go | 12 ++-- cmd/migrate_integration_test.go | 25 ++++---- cmd/plan/plan_integration_test.go | 24 ++++---- cmd/schema_integration_test.go | 34 +++++------ testutil/postgres.go | 6 ++ 9 files changed, 101 insertions(+), 94 deletions(-) diff --git a/cmd/apply/apply_integration_test.go b/cmd/apply/apply_integration_test.go index 072070c5..0adff9d0 100644 --- a/cmd/apply/apply_integration_test.go +++ b/cmd/apply/apply_integration_test.go @@ -136,9 +136,9 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -195,9 +195,9 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan with injected failure AutoApprove: true, @@ -373,9 +373,9 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -412,9 +412,9 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, @@ -567,9 +567,9 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -584,9 +584,9 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, @@ -742,9 +742,9 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -796,9 +796,9 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { applyConfig := &ApplyConfig{ Host: containerHost, Port: portMapped, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan with old fingerprint AutoApprove: true, @@ -929,9 +929,9 @@ func TestApplyCommand_WaitDirective(t *testing.T) { planConfig := &planCmd.PlanConfig{ Host: container.Host, Port: container.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", File: desiredStateFile, ApplicationName: "pgschema", @@ -946,9 +946,9 @@ func TestApplyCommand_WaitDirective(t *testing.T) { applyConfig := &ApplyConfig{ Host: container.Host, Port: container.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: container.DBName, + User: container.User, + Password: container.Password, Schema: "public", Plan: migrationPlan, // Use pre-generated plan AutoApprove: true, diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index dd8fc4fa..781bde43 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -100,9 +100,9 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir config := &DumpConfig{ Host: containerInfo.Host, Port: containerInfo.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: containerInfo.DBName, + User: containerInfo.User, + Password: containerInfo.Password, Schema: "public", MultiFile: false, File: "", @@ -193,9 +193,9 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { config := &DumpConfig{ Host: containerInfo.Host, Port: containerInfo.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: containerInfo.DBName, + User: containerInfo.User, + Password: containerInfo.Password, Schema: tenantName, MultiFile: false, File: "", diff --git a/cmd/dump/dump_permission_integration_test.go b/cmd/dump/dump_permission_integration_test.go index e16b9aa3..110d71ec 100644 --- a/cmd/dump/dump_permission_integration_test.go +++ b/cmd/dump/dump_permission_integration_test.go @@ -73,8 +73,8 @@ func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.Te Host: container.Host, Port: container.Port, Database: dbName, - User: "testuser", - Password: "testpass", + User: container.User, + Password: container.Password, SSLMode: "prefer", ApplicationName: "pgschema", } diff --git a/cmd/ignore_integration_test.go b/cmd/ignore_integration_test.go index 5fe74275..1b75fd23 100644 --- a/cmd/ignore_integration_test.go +++ b/cmd/ignore_integration_test.go @@ -526,9 +526,9 @@ func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.TestPostgres "dump", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", } rootCmd.SetArgs(args) @@ -552,9 +552,9 @@ func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.TestPostgres config := &planCmd.PlanConfig{ Host: containerInfo.Host, Port: containerInfo.Port, - DB: "testdb", - User: "testuser", - Password: "testpass", + DB: containerInfo.DBName, + User: containerInfo.User, + Password: containerInfo.Password, Schema: "public", File: schemaFile, ApplicationName: "pgschema", @@ -581,9 +581,9 @@ func executeIgnoreApplyCommandWithError(containerInfo *testutil.TestPostgres, sc "apply", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", "--file", schemaFile, "--auto-approve", diff --git a/cmd/include_integration_test.go b/cmd/include_integration_test.go index 20439e08..8a76daef 100644 --- a/cmd/include_integration_test.go +++ b/cmd/include_integration_test.go @@ -63,9 +63,9 @@ func applyIncludeSchema(t *testing.T, containerInfo *testutil.TestPostgres) { "apply", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--file", mainSQLPath, "--auto-approve", // Skip interactive confirmation } @@ -95,9 +95,9 @@ func executeMultiFileDump(t *testing.T, containerInfo *testutil.TestPostgres, ou "dump", "--host", containerInfo.Host, "--port", fmt.Sprintf("%d", containerInfo.Port), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", containerInfo.DBName, + "--user", containerInfo.User, + "--password", containerInfo.Password, "--schema", "public", "--multi-file", "--file", outputPath, diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index 72295c41..c3883bf5 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -80,9 +80,6 @@ func TestPlanAndApply(t *testing.T) { container := testutil.SetupTestPostgres(ctx, t) defer container.Terminate(ctx, t) - containerHost := container.Host - portMapped := container.Port - // Get test filter from environment variable testFilter := os.Getenv("PGSCHEMA_TEST_FILTER") @@ -165,7 +162,7 @@ func TestPlanAndApply(t *testing.T) { // Run all test cases using the shared container for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - runPlanAndApplyTest(t, ctx, containerHost, portMapped, tc) + runPlanAndApplyTest(t, ctx, container, tc) }) } } @@ -180,7 +177,9 @@ type testCase struct { } // runPlanAndApplyTest executes a single plan and apply test case with test-specific database -func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string, portMapped int, tc testCase) { +func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *testutil.TestPostgres, tc testCase) { + containerHost := container.Host + portMapped := container.Port // Create a unique database name for this test case (replace invalid chars) dbName := "test_" + strings.ReplaceAll(strings.ReplaceAll(tc.name, "/", "_"), "-", "_") // PostgreSQL identifiers are limited to 63 characters @@ -212,12 +211,12 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string // STEP 2: Test plan command with new.sql as target t.Logf("--- Testing plan command outputs ---") - testPlanOutputs(t, containerHost, portMapped, dbName, tc.newFile, tc.planSQLFile, tc.planJSONFile, tc.planTXTFile) + testPlanOutputs(t, container, dbName, tc.newFile, tc.planSQLFile, tc.planJSONFile, tc.planTXTFile) if !*generate { // STEP 3: Apply the migration using apply command t.Logf("--- Applying migration using apply command ---") - err = applySchemaChanges(containerHost, portMapped, dbName, "testuser", "testpass", "public", tc.newFile) + err = applySchemaChanges(containerHost, portMapped, dbName, container.User, container.Password, "public", tc.newFile) if err != nil { t.Fatalf("Failed to apply schema changes using pgschema apply: %v", err) } @@ -225,7 +224,7 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string // STEP 4: Test idempotency - plan should produce no changes t.Logf("--- Testing idempotency ---") - secondPlanOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, "testuser", "testpass", "public", tc.newFile) + secondPlanOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, container.User, container.Password, "public", tc.newFile) if err != nil { t.Fatalf("Failed to generate plan SQL for idempotency check: %v", err) } @@ -241,14 +240,16 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, containerHost string } // testPlanOutputs tests all plan output formats against expected files -func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { +func testPlanOutputs(t *testing.T, container *testutil.TestPostgres, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { + containerHost := container.Host + portMapped := container.Port // Set fixed timestamp for generate mode to ensure deterministic output if *generate { os.Setenv("PGSCHEMA_TEST_TIME", "1970-01-01T00:00:00Z") defer os.Unsetenv("PGSCHEMA_TEST_TIME") } // Test SQL format - sqlFormattedOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + sqlFormattedOutput, err := generatePlanSQLFormatted(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan SQL formatted output: %v", err) } @@ -281,7 +282,7 @@ func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, } // Test human-readable format - humanOutput, err := generatePlanHuman(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + humanOutput, err := generatePlanHuman(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan human output: %v", err) } @@ -314,7 +315,7 @@ func testPlanOutputs(t *testing.T, containerHost string, portMapped int, dbName, } // Test JSON format - jsonOutput, err := generatePlanJSON(containerHost, portMapped, dbName, "testuser", "testpass", "public", schemaFile) + jsonOutput, err := generatePlanJSON(containerHost, portMapped, dbName, container.User, container.Password, "public", schemaFile) if err != nil { t.Fatalf("Failed to generate plan JSON output: %v", err) } diff --git a/cmd/plan/plan_integration_test.go b/cmd/plan/plan_integration_test.go index 82b67065..5021f00c 100644 --- a/cmd/plan/plan_integration_test.go +++ b/cmd/plan/plan_integration_test.go @@ -90,9 +90,9 @@ func TestPlanCommand_DatabaseIntegration(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, "--output-human", "stdout", } @@ -184,9 +184,9 @@ func TestPlanCommand_OutputFormats(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, tc.outputFlag, "stdout", } @@ -279,9 +279,9 @@ func TestPlanCommand_SchemaFiltering(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--schema", "public", // Filter to only public schema "--file", publicSchemaFile, "--output-human", "stdout", @@ -348,9 +348,9 @@ func TestPlanCommand_EmptyDatabase(t *testing.T) { args := []string{ "--host", containerHost, "--port", fmt.Sprintf("%d", portMapped), - "--db", "testdb", - "--user", "testuser", - "--password", "testpass", + "--db", container.DBName, + "--user", container.User, + "--password", container.Password, "--file", desiredStateFile, "--output-human", "stdout", } diff --git a/cmd/schema_integration_test.go b/cmd/schema_integration_test.go index b9f19cd3..fe7054b7 100644 --- a/cmd/schema_integration_test.go +++ b/cmd/schema_integration_test.go @@ -61,11 +61,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 1: Generate plan using CLI planOutput, err := executePlanCommand( - container.Host, - container.Port, - "testdb", - "testuser", - "testpass", + container.Host, + container.Port, + container.DBName, + container.User, + container.Password, "tenant", // Non-public schema desiredStateFile, ) @@ -84,9 +84,9 @@ func TestNonPublicSchemaOperations(t *testing.T) { err = executeApplyCommand( container.Host, container.Port, - "testdb", - "testuser", - "testpass", + container.DBName, + container.User, + container.Password, "tenant", // Non-public schema desiredStateFile, ) @@ -171,9 +171,9 @@ func TestNonPublicSchemaOperations(t *testing.T) { err = executeApplyCommand( container.Host, container.Port, - "testdb", - "testuser", - "testpass", + container.DBName, + container.User, + container.Password, "app_a", // Target only app_a desiredStateFile, ) @@ -252,9 +252,9 @@ func TestNonPublicSchemaOperations(t *testing.T) { planOutput, err := executePlanCommand( container.Host, container.Port, - "testdb", - "testuser", - "testpass", + container.DBName, + container.User, + container.Password, "MyApp", // Mixed-case schema desiredStateFile, ) @@ -273,9 +273,9 @@ func TestNonPublicSchemaOperations(t *testing.T) { err = executeApplyCommand( container.Host, container.Port, - "testdb", - "testuser", - "testpass", + container.DBName, + container.User, + container.Password, "MyApp", // Mixed-case schema desiredStateFile, ) diff --git a/testutil/postgres.go b/testutil/postgres.go index d345185e..95304b42 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -59,6 +59,9 @@ type TestPostgres struct { DSN string Conn *sql.DB RuntimePath string + DBName string // Database name + User string // Database user + Password string // Database password } // SetupTestPostgres creates a new PostgreSQL test instance with standard credentials @@ -149,6 +152,9 @@ func SetupTestPostgres(ctx context.Context, t *testing.T) *TestPostgres { DSN: testDSN, Conn: conn, RuntimePath: runtimePath, + DBName: database, + User: username, + Password: password, } } From 82407ba1df90fa2bdf6c3755a0247b10f64b9527 Mon Sep 17 00:00:00 2001 From: Tianzhou Date: Fri, 24 Oct 2025 02:10:19 +0800 Subject: [PATCH 5/6] chore: simplify embed pg code --- cmd/apply/apply.go | 6 +- cmd/apply/apply_integration_test.go | 5 +- cmd/migrate_integration_test.go | 5 +- cmd/plan/plan.go | 24 +- internal/diff/diff_test.go | 23 +- internal/plan/plan_test.go | 23 +- internal/postgres/embedded.go | 269 +++++++++++++ testutil/postgres.go | 591 +++++----------------------- 8 files changed, 405 insertions(+), 541 deletions(-) create mode 100644 internal/postgres/embedded.go diff --git a/cmd/apply/apply.go b/cmd/apply/apply.go index f0a2b438..6ac8915c 100644 --- a/cmd/apply/apply.go +++ b/cmd/apply/apply.go @@ -12,9 +12,9 @@ import ( "github.com/pgschema/pgschema/cmd/util" "github.com/pgschema/pgschema/internal/fingerprint" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/internal/version" "github.com/pgschema/pgschema/ir" - "github.com/pgschema/pgschema/testutil" "github.com/spf13/cobra" ) @@ -90,7 +90,7 @@ type ApplyConfig struct { // // If config.File is provided, embeddedPG is used to generate the plan. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func ApplyMigration(config *ApplyConfig, embeddedPG *testutil.EmbeddedPostgres) error { +func ApplyMigration(config *ApplyConfig, embeddedPG *postgres.EmbeddedPostgres) error { var migrationPlan *plan.Plan var err error @@ -254,7 +254,7 @@ func RunApply(cmd *cobra.Command, args []string) error { ApplicationName: applyApplicationName, } - var embeddedPG *testutil.EmbeddedPostgres + var embeddedPG *postgres.EmbeddedPostgres var err error // If using --plan flag, load plan from JSON file diff --git a/cmd/apply/apply_integration_test.go b/cmd/apply/apply_integration_test.go index 0adff9d0..f22381cc 100644 --- a/cmd/apply/apply_integration_test.go +++ b/cmd/apply/apply_integration_test.go @@ -9,20 +9,21 @@ import ( planCmd "github.com/pgschema/pgschema/cmd/plan" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/testutil" ) var ( // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *testutil.EmbeddedPostgres + sharedEmbeddedPG *postgres.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance by reusing the same instance - sharedEmbeddedPG = testutil.SetupSharedEmbeddedPostgres(nil, testutil.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupPostgres(nil, testutil.WithShared()) defer sharedEmbeddedPG.Stop() // Run tests diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index c3883bf5..1b95ae3c 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/pgschema/pgschema/cmd/apply" planCmd "github.com/pgschema/pgschema/cmd/plan" "github.com/pgschema/pgschema/internal/plan" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/testutil" ) @@ -23,7 +24,7 @@ var ( generate = flag.Bool("generate", false, "generate expected test output files instead of comparing") // sharedEmbeddedPG is a shared embedded PostgreSQL instance used across all integration tests // to significantly improve test performance by avoiding repeated startup/teardown - sharedEmbeddedPG *testutil.EmbeddedPostgres + sharedEmbeddedPG *postgres.EmbeddedPostgres ) // TestMain sets up shared resources for all tests in this package @@ -33,7 +34,7 @@ func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance (from ~60s to ~10s per test) - sharedEmbeddedPG = testutil.SetupSharedEmbeddedPostgres(nil, testutil.PostgresVersion("17.5.0")) + sharedEmbeddedPG = testutil.SetupPostgres(nil, testutil.WithShared()) defer sharedEmbeddedPG.Stop() // Run tests diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index a9c9c3cd..ec5daeaf 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -11,7 +11,7 @@ import ( "github.com/pgschema/pgschema/internal/fingerprint" "github.com/pgschema/pgschema/internal/include" "github.com/pgschema/pgschema/internal/plan" - "github.com/pgschema/pgschema/testutil" + "github.com/pgschema/pgschema/internal/postgres" "github.com/spf13/cobra" ) @@ -123,9 +123,9 @@ type PlanConfig struct { // CreateEmbeddedPostgresForPlan creates a temporary embedded PostgreSQL instance // for validating the desired state schema. The instance should be stopped by the caller. -func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*testutil.EmbeddedPostgres, error) { +func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*postgres.EmbeddedPostgres, error) { // Detect target database PostgreSQL version - pgVersion, err := testutil.DetectPostgresVersionFromDB( + pgVersion, err := postgres.DetectPostgresVersionFromDB( config.Host, config.Port, config.DB, @@ -137,13 +137,13 @@ func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*testutil.EmbeddedPostgr } // Start embedded PostgreSQL with matching version - embeddedConfig := &testutil.EmbeddedPostgresConfig{ + embeddedConfig := &postgres.EmbeddedPostgresConfig{ Version: pgVersion, Database: "pgschema_temp", Username: "pgschema", Password: "pgschema", } - embeddedPG, err := testutil.StartEmbeddedPostgres(embeddedConfig) + embeddedPG, err := postgres.StartEmbeddedPostgres(embeddedConfig) if err != nil { return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) } @@ -154,7 +154,7 @@ func CreateEmbeddedPostgresForPlan(config *PlanConfig) (*testutil.EmbeddedPostgr // GeneratePlan generates a migration plan from configuration. // The caller must provide a non-nil embeddedPG instance for validating the desired state schema. // The caller is responsible for managing the embeddedPG lifecycle (creation and cleanup). -func GeneratePlan(config *PlanConfig, embeddedPG *testutil.EmbeddedPostgres) (*plan.Plan, error) { +func GeneratePlan(config *PlanConfig, embeddedPG *postgres.EmbeddedPostgres) (*plan.Plan, error) { // Load ignore configuration ignoreConfig, err := util.LoadIgnoreFileWithStructure() if err != nil { @@ -182,19 +182,13 @@ func GeneratePlan(config *PlanConfig, embeddedPG *testutil.EmbeddedPostgres) (*p ctx := context.Background() - // Reset the schema to ensure clean state - if err := embeddedPG.ResetSchema(ctx, config.Schema); err != nil { - return nil, fmt.Errorf("failed to reset schema in embedded PostgreSQL: %w", err) - } - - // Apply desired state SQL to embedded PostgreSQL - if err := embeddedPG.ApplySchemaSQL(ctx, config.Schema, desiredState); err != nil { + // Apply desired state SQL to embedded PostgreSQL (resets schema first) + if err := embeddedPG.ApplySchema(ctx, config.Schema, desiredState); err != nil { return nil, fmt.Errorf("failed to apply desired state to embedded PostgreSQL: %w", err) } // Inspect embedded PostgreSQL to get desired state IR - embeddedHost, embeddedPort, embeddedDB := embeddedPG.GetConnectionInfo() - embeddedUsername, embeddedPassword := embeddedPG.GetCredentials() + embeddedHost, embeddedPort, embeddedDB, embeddedUsername, embeddedPassword := embeddedPG.GetConnectionDetails() desiredStateIR, err := util.GetIRFromDatabase(embeddedHost, embeddedPort, embeddedDB, embeddedUsername, embeddedPassword, config.Schema, config.ApplicationName, ignoreConfig) if err != nil { return nil, fmt.Errorf("failed to get desired state from embedded PostgreSQL: %w", err) diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go index 5e89b437..7589b2c5 100644 --- a/internal/diff/diff_test.go +++ b/internal/diff/diff_test.go @@ -1,25 +1,24 @@ package diff import ( - "context" "os" "path/filepath" "strings" "testing" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/ir" "github.com/pgschema/pgschema/testutil" ) // sharedTestPostgres is the shared embedded postgres instance for all tests in this package -var sharedTestPostgres *testutil.TestPostgres +var sharedTestPostgres *postgres.EmbeddedPostgres // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - ctx := context.Background() - sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) - defer sharedTestPostgres.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupPostgres(nil, testutil.WithShared()) + defer sharedTestPostgres.Stop() // Run tests code := m.Run() @@ -58,19 +57,7 @@ func buildSQLFromSteps(diffs []Diff) string { // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { t.Helper() - - // Use testutil to apply SQL to embedded postgres - conn := testutil.ParseSQLForTest(t, sharedTestPostgres, sql, "public") - - // Inspect the database to get IR - ctx := context.Background() - inspector := ir.NewInspector(conn, nil) - irResult, err := inspector.BuildIR(ctx, "public") - if err != nil { - t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) - } - - return irResult + return testutil.ParseSQLToIR(t, sharedTestPostgres, sql, "public") } // TestDiffFromFiles runs file-based diff tests from testdata directory. diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go index a21a38d9..79adc5b6 100644 --- a/internal/plan/plan_test.go +++ b/internal/plan/plan_test.go @@ -1,7 +1,6 @@ package plan import ( - "context" "encoding/json" "fmt" "os" @@ -13,19 +12,19 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pgschema/pgschema/internal/diff" + "github.com/pgschema/pgschema/internal/postgres" "github.com/pgschema/pgschema/ir" "github.com/pgschema/pgschema/testutil" ) // sharedTestPostgres is the shared embedded postgres instance for all tests in this package -var sharedTestPostgres *testutil.TestPostgres +var sharedTestPostgres *postgres.EmbeddedPostgres // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - ctx := context.Background() - sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) - defer sharedTestPostgres.Terminate(ctx, nil) + sharedTestPostgres = testutil.SetupPostgres(nil, testutil.WithShared()) + defer sharedTestPostgres.Stop() // Run tests code := m.Run() @@ -59,19 +58,7 @@ func discoverTestDataVersions(testdataDir string) ([]string, error) { // Uses embedded PostgreSQL to ensure tests use the same code path as production func parseSQL(t *testing.T, sql string) *ir.IR { t.Helper() - - // Use testutil to apply SQL to embedded postgres - conn := testutil.ParseSQLForTest(t, sharedTestPostgres, sql, "public") - - // Inspect the database to get IR - ctx := context.Background() - inspector := ir.NewInspector(conn, nil) - irResult, err := inspector.BuildIR(ctx, "public") - if err != nil { - t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) - } - - return irResult + return testutil.ParseSQLToIR(t, sharedTestPostgres, sql, "public") } func TestPlanSummary(t *testing.T) { diff --git a/internal/postgres/embedded.go b/internal/postgres/embedded.go new file mode 100644 index 00000000..426d2c71 --- /dev/null +++ b/internal/postgres/embedded.go @@ -0,0 +1,269 @@ +// Package postgres provides embedded PostgreSQL functionality for production use. +// This package is used by the plan command to create temporary PostgreSQL instances +// for validating desired state schemas. +package postgres + +import ( + "context" + "database/sql" + "fmt" + "io" + "net" + "os" + "path/filepath" + "time" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + _ "github.com/jackc/pgx/v5/stdlib" +) + +// ============================================================================ +// Type Definitions +// ============================================================================ + +// PostgresVersion is an alias for the embedded-postgres version type. +type PostgresVersion = embeddedpostgres.PostgresVersion + +// EmbeddedPostgres manages a temporary embedded PostgreSQL instance. +// This is used by the plan command to validate desired state schemas. +type EmbeddedPostgres struct { + instance *embeddedpostgres.EmbeddedPostgres + db *sql.DB + version PostgresVersion + host string + port int + database string + username string + password string + runtimePath string +} + +// EmbeddedPostgresConfig holds configuration for starting embedded PostgreSQL +type EmbeddedPostgresConfig struct { + Version PostgresVersion + Database string + Username string + Password string +} + +// ============================================================================ +// Version Detection +// ============================================================================ + +// DetectPostgresVersionFromDB connects to a database and detects its version +// This is a convenience function that opens a connection, detects the version, and closes it +func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (PostgresVersion, error) { + // Build connection string + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer", + user, password, host, port, database) + + // Connect to database + db, err := sql.Open("pgx", dsn) + if err != nil { + return "", fmt.Errorf("failed to connect to database: %w", err) + } + defer db.Close() + + // Test the connection + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + return "", fmt.Errorf("failed to ping database: %w", err) + } + + // Detect version + return detectPostgresVersion(db) +} + +// ============================================================================ +// EmbeddedPostgres Lifecycle +// ============================================================================ + +// StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance +func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { + // Create unique runtime path with timestamp (using nanoseconds for uniqueness) + timestamp := time.Now().Format("20060102_150405.000000000") + runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-plan-%s", timestamp)) + + // Find an available port + port, err := findAvailablePort() + if err != nil { + return nil, fmt.Errorf("failed to find available port: %w", err) + } + + // Configure embedded postgres + pgConfig := embeddedpostgres.DefaultConfig(). + Version(config.Version). + Database(config.Database). + Username(config.Username). + Password(config.Password). + Port(uint32(port)). + RuntimePath(runtimePath). + DataPath(filepath.Join(runtimePath, "data")). + Logger(io.Discard). // Suppress embedded-postgres startup logs + StartParameters(map[string]string{ + "logging_collector": "off", // Disable log collector + "log_destination": "stderr", // Send logs to stderr (which we discard) + "log_min_messages": "PANIC", // Only log PANIC level messages + "log_statement": "none", // Don't log SQL statements + "log_min_duration_statement": "-1", // Don't log slow queries + }) + + // Create and start PostgreSQL instance + instance := embeddedpostgres.NewDatabase(pgConfig) + if err := instance.Start(); err != nil { + return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) + } + + // Build connection string + host := "localhost" + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + config.Username, config.Password, host, port, config.Database) + + // Connect to database + db, err := sql.Open("pgx", dsn) + if err != nil { + instance.Stop() + os.RemoveAll(runtimePath) + return nil, fmt.Errorf("failed to connect to embedded PostgreSQL: %w", err) + } + + // Test the connection + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + db.Close() + instance.Stop() + os.RemoveAll(runtimePath) + return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err) + } + + return &EmbeddedPostgres{ + instance: instance, + db: db, + version: config.Version, + host: host, + port: port, + database: config.Database, + username: config.Username, + password: config.Password, + runtimePath: runtimePath, + }, nil +} + +// Stop stops and cleans up the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) Stop() error { + // Close database connection + if ep.db != nil { + ep.db.Close() + } + + // Stop PostgreSQL instance + var stopErr error + if ep.instance != nil { + stopErr = ep.instance.Stop() + } + + // Clean up runtime directory + if ep.runtimePath != "" { + if err := os.RemoveAll(ep.runtimePath); err != nil { + // Don't return error here - just ignore cleanup failures + // This can happen on Windows when files are still in use + } + } + + if stopErr != nil { + return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) + } + + return nil +} + +// ============================================================================ +// EmbeddedPostgres Operations +// ============================================================================ + +// GetConnectionDetails returns all connection details needed to connect to the embedded PostgreSQL instance +func (ep *EmbeddedPostgres) GetConnectionDetails() (host string, port int, database, username, password string) { + return ep.host, ep.port, ep.database, ep.username, ep.password +} + +// ApplySchema resets a schema (drops and recreates it) and applies SQL to it. +// This ensures a clean state before applying the desired schema definition. +func (ep *EmbeddedPostgres) ApplySchema(ctx context.Context, schema string, sql string) error { + // Drop the schema if it exists (CASCADE to drop all objects) + dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { + return fmt.Errorf("failed to drop schema %s: %w", schema, err) + } + + // Create the schema + createSchemaSQL := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) + if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { + return fmt.Errorf("failed to create schema %s: %w", schema, err) + } + + // Set search_path to the target schema + setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) + if _, err := ep.db.ExecContext(ctx, setSearchPathSQL); err != nil { + return fmt.Errorf("failed to set search_path: %w", err) + } + + // Execute the SQL directly + // Note: Desired state SQL should never contain operations like CREATE INDEX CONCURRENTLY + // that cannot run in transactions. Those are migration details, not state declarations. + if _, err := ep.db.ExecContext(ctx, sql); err != nil { + return fmt.Errorf("failed to apply schema SQL: %w", err) + } + + return nil +} + +// ============================================================================ +// Internal Helpers +// ============================================================================ + +// findAvailablePort finds an available TCP port for PostgreSQL to use +func findAvailablePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer listener.Close() + return listener.Addr().(*net.TCPAddr).Port, nil +} + +// mapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version +// Supported versions: 14, 15, 16, 17 +func mapToEmbeddedPostgresVersion(majorVersion int) (PostgresVersion, error) { + switch majorVersion { + case 14: + return PostgresVersion("14.18.0"), nil + case 15: + return PostgresVersion("15.13.0"), nil + case 16: + return PostgresVersion("16.9.0"), nil + case 17: + return PostgresVersion("17.5.0"), nil + default: + return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) + } +} + +// detectPostgresVersion queries the target database to determine its PostgreSQL version +// and returns the corresponding embedded-postgres version string +func detectPostgresVersion(db *sql.DB) (PostgresVersion, error) { + ctx := context.Background() + + // Query PostgreSQL version number (e.g., 170005 for 17.5) + var versionNum int + err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) + if err != nil { + return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) + } + + // Extract major version: version_num / 10000 + // e.g., 170005 / 10000 = 17 + majorVersion := versionNum / 10000 + + // Map to embedded-postgres version + return mapToEmbeddedPostgresVersion(majorVersion) +} diff --git a/testutil/postgres.go b/testutil/postgres.go index 95304b42..93e5e9bf 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -5,559 +5,184 @@ import ( "context" "database/sql" "fmt" - "io" - "net" "os" - "path/filepath" - "strconv" - "strings" "testing" - "time" - embeddedpostgres "github.com/fergusstrange/embedded-postgres" _ "github.com/jackc/pgx/v5/stdlib" + "github.com/pgschema/pgschema/internal/postgres" + "github.com/pgschema/pgschema/ir" ) -// PostgresVersion is an alias for the embedded-postgres version type -// This allows test files to reference the version type without directly importing embedded-postgres -type PostgresVersion = embeddedpostgres.PostgresVersion +// ============================================================================ +// PostgreSQL Setup +// ============================================================================ -// getPostgresVersion returns the PostgreSQL version to use for testing. -// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, -// defaulting to "17" if not set. -func getPostgresVersion() embeddedpostgres.PostgresVersion { - versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") - switch versionStr { - case "14": - return embeddedpostgres.PostgresVersion("14.18.0") - case "15": - return embeddedpostgres.PostgresVersion("15.13.0") - case "16": - return embeddedpostgres.PostgresVersion("16.9.0") - case "17", "": - return embeddedpostgres.PostgresVersion("17.5.0") - default: - return embeddedpostgres.PostgresVersion("17.5.0") - } -} +// PostgresOption is a functional option for configuring PostgreSQL setup +type PostgresOption func(*postgresConfig) -// findAvailablePort finds an available TCP port for PostgreSQL to use -func findAvailablePort() (int, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, err - } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil +// postgresConfig holds configuration for PostgreSQL setup +type postgresConfig struct { + shared bool // if true, use "shared" naming for runtime path } -// TestPostgres holds PostgreSQL instance connection details for testing -type TestPostgres struct { - Database *embeddedpostgres.EmbeddedPostgres - Host string - Port int - DSN string - Conn *sql.DB - RuntimePath string - DBName string // Database name - User string // Database user - Password string // Database password +// WithShared returns an option to create a shared PostgreSQL instance +// Shared instances are typically created once in TestMain and reused across tests +func WithShared() PostgresOption { + return func(c *postgresConfig) { + c.shared = true + } } -// SetupTestPostgres creates a new PostgreSQL test instance with standard credentials -func SetupTestPostgres(ctx context.Context, t *testing.T) *TestPostgres { - // Standard test database credentials - database := "testdb" - username := "testuser" - password := "testpass" - - // Extract test name and create unique runtime path - testName := "shared" - if t != nil { - testName = strings.ReplaceAll(t.Name(), "/", "_") // Replace slashes for subtest names +// SetupPostgres creates a PostgreSQL instance for testing. +// It uses the production postgres.EmbeddedPostgres implementation. +// PostgreSQL version is determined from PGSCHEMA_POSTGRES_VERSION environment variable. +// +// Usage: +// - Per-test instance: testutil.SetupPostgres(t) +// - Shared instance: testutil.SetupPostgres(nil, testutil.WithShared()) +func SetupPostgres(t testing.TB, opts ...PostgresOption) *postgres.EmbeddedPostgres { + // Apply options + cfg := &postgresConfig{shared: false} + for _, opt := range opts { + opt(cfg) } - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-test-%s-%s", testName, timestamp)) - // Find an available port - port, err := findAvailablePort() - if err != nil { - if t != nil { - t.Fatalf("Failed to find available port: %v", err) - } else { - panic(fmt.Sprintf("Failed to find available port: %v", err)) - } - } + // Determine PostgreSQL version from environment + version := getPostgresVersion() - // Configure embedded postgres with unique runtime path and dynamic port - config := embeddedpostgres.DefaultConfig(). - Version(getPostgresVersion()). - Database(database). - Username(username). - Password(password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard above) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - postgres := embeddedpostgres.NewDatabase(config) - err = postgres.Start() - if err != nil { - if t != nil { - t.Fatalf("Failed to start embedded postgres: %v", err) - } else { - panic(fmt.Sprintf("Failed to start embedded postgres: %v", err)) - } + // Create configuration for production postgres package + config := &postgres.EmbeddedPostgresConfig{ + Version: version, + Database: "testdb", + Username: "testuser", + Password: "testpass", } - // Build connection string - host := "localhost" - testDSN := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", - username, password, host, port, database) - - // Connect to database - conn, err := sql.Open("pgx", testDSN) + // Start embedded PostgreSQL using production code + embeddedPG, err := postgres.StartEmbeddedPostgres(config) if err != nil { - postgres.Stop() - if t != nil { - t.Fatalf("Failed to connect to database: %v", err) - } else { - panic(fmt.Sprintf("Failed to connect to database: %v", err)) - } - } - - // Test the connection - if err := conn.PingContext(ctx); err != nil { - conn.Close() - postgres.Stop() if t != nil { - t.Fatalf("Failed to ping database: %v", err) + t.Fatalf("Failed to start embedded PostgreSQL: %v", err) } else { - panic(fmt.Sprintf("Failed to ping database: %v", err)) - } - } - - return &TestPostgres{ - Database: postgres, - Host: host, - Port: port, - DSN: testDSN, - Conn: conn, - RuntimePath: runtimePath, - DBName: database, - User: username, - Password: password, - } -} - -// Terminate cleans up the database instance and connection -func (tp *TestPostgres) Terminate(ctx context.Context, t *testing.T) { - tp.Conn.Close() - if err := tp.Database.Stop(); err != nil { - if t != nil { - t.Logf("Failed to stop embedded postgres: %v", err) - } - // Silently ignore errors if t is nil (called from TestMain cleanup) - } - // Clean up the runtime directory - if tp.RuntimePath != "" { - if err := os.RemoveAll(tp.RuntimePath); err != nil { - if t != nil { - t.Logf("Failed to clean up runtime directory: %v", err) - } - // Silently ignore errors if t is nil + panic("Failed to start embedded PostgreSQL: " + err.Error()) } } -} -// SetEnvPassword sets the PGPASSWORD environment variable -func SetEnvPassword(password string) { - os.Setenv("PGPASSWORD", password) + return embeddedPG } // ============================================================================ -// Version Detection and Mapping +// Test Helpers // ============================================================================ -// MapToEmbeddedPostgresVersion maps a PostgreSQL major version to embedded-postgres version -// Supported versions: 14, 15, 16, 17 -func MapToEmbeddedPostgresVersion(majorVersion int) (embeddedpostgres.PostgresVersion, error) { - switch majorVersion { - case 14: - return embeddedpostgres.PostgresVersion("14.18.0"), nil - case 15: - return embeddedpostgres.PostgresVersion("15.13.0"), nil - case 16: - return embeddedpostgres.PostgresVersion("16.9.0"), nil - case 17: - return embeddedpostgres.PostgresVersion("17.5.0"), nil - default: - return "", fmt.Errorf("unsupported PostgreSQL version %d (supported: 14, 15, 16, 17)", majorVersion) - } -} +// ParseSQLToIR is a test helper that parses SQL and returns its IR representation. +// It applies the SQL to an embedded PostgreSQL instance, inspects it, and returns the IR. +// The schema will be reset (dropped and recreated) to ensure clean state between test calls. +// This ensures tests use the same code path as production (database inspection) rather than parsing. +func ParseSQLToIR(t *testing.T, embeddedPG *postgres.EmbeddedPostgres, sqlContent string, schema string) *ir.IR { + t.Helper() -// DetectPostgresVersion queries the target database to determine its PostgreSQL version -// and returns the corresponding embedded-postgres version string -func DetectPostgresVersion(db *sql.DB) (embeddedpostgres.PostgresVersion, error) { ctx := context.Background() - // Query PostgreSQL version number (e.g., 170005 for 17.5) - var versionNum int - err := db.QueryRowContext(ctx, "SHOW server_version_num").Scan(&versionNum) - if err != nil { - return "", fmt.Errorf("failed to query PostgreSQL version: %w", err) - } - - // Extract major version: version_num / 10000 - // e.g., 170005 / 10000 = 17 - majorVersion := versionNum / 10000 - - // Map to embedded-postgres version - return MapToEmbeddedPostgresVersion(majorVersion) -} - -// ParseVersionString parses a PostgreSQL version string (e.g., "17.5") and returns major version -func ParseVersionString(versionStr string) (int, error) { - // Handle various formats: "17.5", "17.5.0", "PostgreSQL 17.5", etc. - // Extract the version number part - versionStr = strings.TrimSpace(versionStr) - - // Remove "PostgreSQL " prefix if present - versionStr = strings.TrimPrefix(versionStr, "PostgreSQL ") - - // Split by "." and take the first part (major version) - parts := strings.Split(versionStr, ".") - if len(parts) == 0 { - return 0, fmt.Errorf("invalid version string: %s", versionStr) - } - - majorVersion, err := strconv.Atoi(parts[0]) - if err != nil { - return 0, fmt.Errorf("failed to parse major version from %s: %w", versionStr, err) - } - - return majorVersion, nil -} - -// ============================================================================ -// Production EmbeddedPostgres Wrapper -// ============================================================================ - -// EmbeddedPostgres manages a temporary embedded PostgreSQL instance -// This is used both for testing and for the plan command in production -type EmbeddedPostgres struct { - instance *embeddedpostgres.EmbeddedPostgres - db *sql.DB - version embeddedpostgres.PostgresVersion - host string - port int - database string - username string - password string - runtimePath string -} - -// EmbeddedPostgresConfig holds configuration for starting embedded PostgreSQL -type EmbeddedPostgresConfig struct { - Version embeddedpostgres.PostgresVersion - Database string - Username string - Password string -} - -// StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance -func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { - // Create unique runtime path with timestamp (using nanoseconds for uniqueness) - timestamp := time.Now().Format("20060102_150405.000000000") - runtimePath := filepath.Join(os.TempDir(), fmt.Sprintf("pgschema-plan-%s", timestamp)) - - // Find an available port - port, err := findAvailablePort() - if err != nil { - return nil, fmt.Errorf("failed to find available port: %w", err) - } - - // Configure embedded postgres - pgConfig := embeddedpostgres.DefaultConfig(). - Version(config.Version). - Database(config.Database). - Username(config.Username). - Password(config.Password). - Port(uint32(port)). - RuntimePath(runtimePath). - DataPath(filepath.Join(runtimePath, "data")). - Logger(io.Discard). // Suppress embedded-postgres startup logs - StartParameters(map[string]string{ - "logging_collector": "off", // Disable log collector - "log_destination": "stderr", // Send logs to stderr (which we discard) - "log_min_messages": "PANIC", // Only log PANIC level messages - "log_statement": "none", // Don't log SQL statements - "log_min_duration_statement": "-1", // Don't log slow queries - }) - - // Create and start PostgreSQL instance - instance := embeddedpostgres.NewDatabase(pgConfig) - if err := instance.Start(); err != nil { - return nil, fmt.Errorf("failed to start embedded PostgreSQL: %w", err) - } + // Get connection details from embedded postgres + host, port, database, username, password := embeddedPG.GetConnectionDetails() // Build connection string - host := "localhost" dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", - config.Username, config.Password, host, port, config.Database) + username, password, host, port, database) // Connect to database - db, err := sql.Open("pgx", dsn) + conn, err := sql.Open("pgx", dsn) if err != nil { - instance.Stop() - os.RemoveAll(runtimePath) - return nil, fmt.Errorf("failed to connect to embedded PostgreSQL: %w", err) + t.Fatalf("Failed to connect to database: %v", err) } + defer conn.Close() // Test the connection - ctx := context.Background() - if err := db.PingContext(ctx); err != nil { - db.Close() - instance.Stop() - os.RemoveAll(runtimePath) - return nil, fmt.Errorf("failed to ping embedded PostgreSQL: %w", err) - } - - return &EmbeddedPostgres{ - instance: instance, - db: db, - version: config.Version, - host: host, - port: port, - database: config.Database, - username: config.Username, - password: config.Password, - runtimePath: runtimePath, - }, nil -} - -// GetDB returns the database connection -func (ep *EmbeddedPostgres) GetDB() *sql.DB { - return ep.db -} - -// GetConnectionInfo returns connection details -func (ep *EmbeddedPostgres) GetConnectionInfo() (host string, port int, database string) { - return ep.host, ep.port, ep.database -} - -// GetCredentials returns the username and password for the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) GetCredentials() (username string, password string) { - return ep.username, ep.password -} - -// ResetSchema drops and recreates a schema, clearing all objects -// This is useful for tests that want to reuse the same embedded postgres instance -func (ep *EmbeddedPostgres) ResetSchema(ctx context.Context, schema string) error { - // Drop the schema if it exists (CASCADE to drop all objects) - dropSchemaSQL := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) - if _, err := ep.db.ExecContext(ctx, dropSchemaSQL); err != nil { - return fmt.Errorf("failed to drop schema %s: %w", schema, err) + if err := conn.PingContext(ctx); err != nil { + t.Fatalf("Failed to ping database: %v", err) } - // Recreate the schema - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) - if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { - return fmt.Errorf("failed to create schema %s: %w", schema, err) + // Drop and recreate schema for clean state + dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) + if _, err := conn.ExecContext(ctx, dropSchema); err != nil { + t.Fatalf("Failed to drop schema: %v", err) } - - return nil -} - -// ApplySchemaSQL applies SQL schema to the embedded PostgreSQL database -func (ep *EmbeddedPostgres) ApplySchemaSQL(ctx context.Context, schema string, sql string) error { - // Create the schema if it doesn't exist - createSchemaSQL := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS \"%s\"", schema) - if _, err := ep.db.ExecContext(ctx, createSchemaSQL); err != nil { - return fmt.Errorf("failed to create schema %s: %w", schema, err) + createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) + if _, err := conn.ExecContext(ctx, createSchema); err != nil { + t.Fatalf("Failed to create schema: %v", err) } - // Set search_path to the target schema + // Set search_path to target schema setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) - if _, err := ep.db.ExecContext(ctx, setSearchPathSQL); err != nil { - return fmt.Errorf("failed to set search_path: %w", err) - } - - // Execute the SQL directly - // Note: Desired state SQL should never contain operations like CREATE INDEX CONCURRENTLY - // that cannot run in transactions. Those are migration details, not state declarations. - if _, err := ep.db.ExecContext(ctx, sql); err != nil { - return fmt.Errorf("failed to apply schema SQL: %w", err) - } - - return nil -} - -// Stop stops and cleans up the embedded PostgreSQL instance -func (ep *EmbeddedPostgres) Stop() error { - // Close database connection - if ep.db != nil { - ep.db.Close() - } - - // Stop PostgreSQL instance - var stopErr error - if ep.instance != nil { - stopErr = ep.instance.Stop() + if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { + t.Fatalf("Failed to set search_path: %v", err) } - // Clean up runtime directory - if ep.runtimePath != "" { - if err := os.RemoveAll(ep.runtimePath); err != nil { - // Don't return error here - just ignore cleanup failures - // This can happen on Windows when files are still in use - } + // Execute the SQL + if _, err := conn.ExecContext(ctx, sqlContent); err != nil { + t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) } - if stopErr != nil { - return fmt.Errorf("failed to stop embedded PostgreSQL: %w", stopErr) + // Inspect the database to get IR + inspector := ir.NewInspector(conn, nil) + irResult, err := inspector.BuildIR(ctx, schema) + if err != nil { + t.Fatalf("Failed to inspect embedded PostgreSQL: %v", err) } - return nil + return irResult } -// SetupSharedEmbeddedPostgres creates a shared embedded PostgreSQL instance for test suites. -// This instance can be reused across multiple test cases to significantly improve test performance. -// -// Usage example: -// -// func TestMain(m *testing.M) { -// // Create shared embedded postgres for all tests -// embeddedPG := testutil.SetupSharedEmbeddedPostgres(nil, embeddedpostgres.PostgresVersion("17.5.0")) -// defer embeddedPG.Stop() -// -// // Run tests -// code := m.Run() -// os.Exit(code) -// } -func SetupSharedEmbeddedPostgres(t testing.TB, version embeddedpostgres.PostgresVersion) *EmbeddedPostgres { - config := &EmbeddedPostgresConfig{ - Version: version, - Database: "testdb", - Username: "testuser", - Password: "testpass", - } +// ConnectToPostgres connects to an embedded PostgreSQL instance and returns connection details. +// This is a helper for tests that need database connection information. +// The caller is responsible for closing the returned *sql.DB connection. +func ConnectToPostgres(t testing.TB, embeddedPG *postgres.EmbeddedPostgres) (conn *sql.DB, host string, port int, dbname, user, password string) { + t.Helper() - embeddedPG, err := StartEmbeddedPostgres(config) - if err != nil { - if t != nil { - t.Fatalf("Failed to start shared embedded PostgreSQL: %v", err) - } else { - panic("Failed to start shared embedded PostgreSQL: " + err.Error()) - } - } + ctx := context.Background() - return embeddedPG -} + // Get connection details from embedded postgres + host, port, dbname, user, password = embeddedPG.GetConnectionDetails() -// DetectPostgresVersionFromDB connects to a database and detects its version -// This is a convenience function that opens a connection, detects the version, and closes it -func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (embeddedpostgres.PostgresVersion, error) { // Build connection string - dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=prefer", - user, password, host, port, database) + dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", + user, password, host, port, dbname) // Connect to database - db, err := sql.Open("pgx", dsn) + conn, err := sql.Open("pgx", dsn) if err != nil { - return "", fmt.Errorf("failed to connect to database: %w", err) + t.Fatalf("Failed to connect to database: %v", err) } - defer db.Close() // Test the connection - ctx := context.Background() - if err := db.PingContext(ctx); err != nil { - return "", fmt.Errorf("failed to ping database: %w", err) + if err := conn.PingContext(ctx); err != nil { + conn.Close() + t.Fatalf("Failed to ping database: %v", err) } - // Detect version - return DetectPostgresVersion(db) + return conn, host, port, dbname, user, password } // ============================================================================ -// Shared Test Postgres for IR Tests +// Internal Helpers // ============================================================================ -// SetupSharedTestPostgres creates a shared embedded postgres instance for test packages. -// This significantly improves test performance by avoiding repeated postgres startup/shutdown. -// The returned instance should be stored by the caller and passed to ParseSQLForTest. -// -// Usage in test packages: -// -// var sharedTestPostgres *testutil.TestPostgres -// -// func TestMain(m *testing.M) { -// ctx := context.Background() -// sharedTestPostgres = testutil.SetupSharedTestPostgres(ctx, nil) -// defer sharedTestPostgres.Terminate(ctx, nil) -// -// code := m.Run() -// os.Exit(code) -// } -func SetupSharedTestPostgres(ctx context.Context, t testing.TB) *TestPostgres { - // Convert testing.TB to *testing.T if needed - var tt *testing.T - if t != nil { - if tPtr, ok := t.(*testing.T); ok { - tt = tPtr - } else { - // For testing.TB that's not *testing.T (like *testing.M), create a dummy *testing.T - // This is safe because SetupTestPostgres only uses t for Fatalf on errors - panic("SetupSharedTestPostgres requires *testing.T or nil") - } - } - return SetupTestPostgres(ctx, tt) -} - -// ParseSQLForTest is a test helper that converts SQL to an inspectable database state -// using embedded PostgreSQL. This replaces the old parser-based approach for tests. -// -// The caller must provide a TestPostgres instance (typically from SetupSharedTestPostgres). -// The schema will be reset (dropped and recreated) to ensure clean state between test calls. -// -// This function returns the database connection that can be inspected. The caller should NOT -// close this connection as it belongs to the test postgres instance. -// -// This ensures tests use the same code path as production (database inspection) rather than parsing. -func ParseSQLForTest(t *testing.T, testPG *TestPostgres, sqlContent string, schema string) *sql.DB { - t.Helper() - - ctx := context.Background() - conn := testPG.Conn - - // Drop and recreate schema for clean state - dropSchema := fmt.Sprintf("DROP SCHEMA IF EXISTS \"%s\" CASCADE", schema) - if _, err := conn.ExecContext(ctx, dropSchema); err != nil { - t.Fatalf("Failed to drop schema: %v", err) - } - createSchema := fmt.Sprintf("CREATE SCHEMA \"%s\"", schema) - if _, err := conn.ExecContext(ctx, createSchema); err != nil { - t.Fatalf("Failed to create schema: %v", err) - } - - // Set search_path to target schema - setSearchPathSQL := fmt.Sprintf("SET search_path TO \"%s\"", schema) - if _, err := conn.ExecContext(ctx, setSearchPathSQL); err != nil { - t.Fatalf("Failed to set search_path: %v", err) - } - - // Execute the SQL - if _, err := conn.ExecContext(ctx, sqlContent); err != nil { - t.Fatalf("Failed to apply SQL to embedded PostgreSQL: %v", err) +// getPostgresVersion returns the PostgreSQL version to use for testing. +// It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, +// defaulting to "17" if not set. +func getPostgresVersion() postgres.PostgresVersion { + versionStr := os.Getenv("PGSCHEMA_POSTGRES_VERSION") + switch versionStr { + case "14": + return postgres.PostgresVersion("14.18.0") + case "15": + return postgres.PostgresVersion("15.13.0") + case "16": + return postgres.PostgresVersion("16.9.0") + case "17", "": + return postgres.PostgresVersion("17.5.0") + default: + return postgres.PostgresVersion("17.5.0") } - - return conn } From 52143a8960ac6d60c19484b1b03297dcd5ec568b Mon Sep 17 00:00:00 2001 From: tianzhou Date: Fri, 24 Oct 2025 18:37:33 +0800 Subject: [PATCH 6/6] chore: address comment --- cmd/apply/apply_integration_test.go | 123 ++++++++++++++++--- cmd/dump/dump.go | 3 +- cmd/dump/dump_integration_test.go | 52 ++++---- cmd/dump/dump_permission_integration_test.go | 50 +++++++- cmd/ignore_integration_test.go | 113 ++++++++++++++--- cmd/include_integration_test.go | 45 +++++-- cmd/migrate_integration_test.go | 43 ++++++- cmd/plan/plan_integration_test.go | 97 +++++++++++++-- cmd/schema_integration_test.go | 60 ++++----- internal/diff/diff_test.go | 2 +- internal/plan/plan_test.go | 2 +- internal/postgres/embedded.go | 20 --- testutil/postgres.go | 39 +----- 13 files changed, 474 insertions(+), 175 deletions(-) diff --git a/cmd/apply/apply_integration_test.go b/cmd/apply/apply_integration_test.go index f22381cc..aebb2804 100644 --- a/cmd/apply/apply_integration_test.go +++ b/cmd/apply/apply_integration_test.go @@ -2,6 +2,7 @@ package apply import ( "context" + "database/sql" "os" "path/filepath" "strings" @@ -23,7 +24,7 @@ var ( func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance by reusing the same instance - sharedEmbeddedPG = testutil.SetupPostgres(nil, testutil.WithShared()) + sharedEmbeddedPG = testutil.SetupPostgres(nil) defer sharedEmbeddedPG.Stop() // Run tests @@ -61,11 +62,29 @@ func TestApplyCommand_TransactionRollback(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -318,11 +337,29 @@ func TestApplyCommand_CreateIndexConcurrently(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -519,11 +556,29 @@ func TestApplyCommand_WithPlanFile(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -688,11 +743,29 @@ func TestApplyCommand_FingerprintMismatch(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -879,11 +952,29 @@ func TestApplyCommand_WaitDirective(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema and data - conn := container.Conn initialSQL := ` CREATE TABLE users ( diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index e1d2ca72..dbce9052 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -34,14 +34,13 @@ type DumpConfig struct { File string } - var DumpCmd = &cobra.Command{ Use: "dump", Short: "Dump database schema for a specific schema", Long: "Dump and output database schema information for a specific schema. Uses the --schema flag to target a particular schema (defaults to 'public').", RunE: runDump, SilenceUsage: true, - PreRunE: util.PreRunEWithEnvVarsAndConnection(&db, &user, &host, &port), + PreRunE: util.PreRunEWithEnvVarsAndConnection(&db, &user, &host, &port), } func init() { diff --git a/cmd/dump/dump_integration_test.go b/cmd/dump/dump_integration_test.go index 781bde43..3207b612 100644 --- a/cmd/dump/dump_integration_test.go +++ b/cmd/dump/dump_integration_test.go @@ -79,9 +79,13 @@ func runExactMatchTest(t *testing.T, testDataDir string) { } func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir string) { - // Setup PostgreSQL container - containerInfo := testutil.SetupTestPostgres(ctx, t) - defer containerInfo.Terminate(ctx, t) + // Setup PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + + // Connect to database + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Read and execute the pgdump.sql file pgdumpPath := fmt.Sprintf("../../testdata/dump/%s/pgdump.sql", testDataDir) @@ -91,18 +95,18 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir } // Execute the SQL to create the schema - _, err = containerInfo.Conn.ExecContext(ctx, string(pgdumpContent)) + _, err = conn.ExecContext(ctx, string(pgdumpContent)) if err != nil { t.Fatalf("Failed to execute pgdump.sql: %v", err) } // Create dump configuration config := &DumpConfig{ - Host: containerInfo.Host, - Port: containerInfo.Port, - DB: containerInfo.DBName, - User: containerInfo.User, - Password: containerInfo.Password, + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, Schema: "public", MultiFile: false, File: "", @@ -127,11 +131,13 @@ func runExactMatchTestWithContext(t *testing.T, ctx context.Context, testDataDir } func runTenantSchemaTest(t *testing.T, testDataDir string) { - ctx := context.Background() + // Setup PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() - // Setup PostgreSQL container - containerInfo := testutil.SetupTestPostgres(ctx, t) - defer containerInfo.Terminate(ctx, t) + // Connect to database + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Load public schema types first publicSQL, err := os.ReadFile(fmt.Sprintf("../../testdata/dump/%s/public.sql", testDataDir)) @@ -139,7 +145,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { t.Fatalf("Failed to read public.sql: %v", err) } - _, err = containerInfo.Conn.Exec(string(publicSQL)) + _, err = conn.Exec(string(publicSQL)) if err != nil { t.Fatalf("Failed to load public types: %v", err) } @@ -147,7 +153,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { // Load utility functions (if util.sql exists) utilPath := fmt.Sprintf("../../testdata/dump/%s/util.sql", testDataDir) if utilSQL, err := os.ReadFile(utilPath); err == nil { - _, err = containerInfo.Conn.Exec(string(utilSQL)) + _, err = conn.Exec(string(utilSQL)) if err != nil { t.Fatalf("Failed to load utility functions from util.sql: %v", err) } @@ -158,7 +164,7 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { // Create two tenant schemas tenants := []string{"tenant1", "tenant2"} for _, tenant := range tenants { - _, err = containerInfo.Conn.Exec(fmt.Sprintf("CREATE SCHEMA %s", tenant)) + _, err = conn.Exec(fmt.Sprintf("CREATE SCHEMA %s", tenant)) if err != nil { t.Fatalf("Failed to create schema %s: %v", tenant, err) } @@ -174,13 +180,13 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { for _, tenant := range tenants { // Set search path to include public for the types, but target schema first quotedTenant := ir.QuoteIdentifier(tenant) - _, err = containerInfo.Conn.Exec(fmt.Sprintf("SET search_path TO %s, public", quotedTenant)) + _, err = conn.Exec(fmt.Sprintf("SET search_path TO %s, public", quotedTenant)) if err != nil { t.Fatalf("Failed to set search path to %s: %v", tenant, err) } // Execute the SQL - _, err = containerInfo.Conn.Exec(string(tenantSQL)) + _, err = conn.Exec(string(tenantSQL)) if err != nil { t.Fatalf("Failed to load SQL into schema %s: %v", tenant, err) } @@ -191,11 +197,11 @@ func runTenantSchemaTest(t *testing.T, testDataDir string) { for _, tenantName := range tenants { // Create dump configuration for this tenant config := &DumpConfig{ - Host: containerInfo.Host, - Port: containerInfo.Port, - DB: containerInfo.DBName, - User: containerInfo.User, - Password: containerInfo.Password, + Host: host, + Port: port, + DB: dbname, + User: user, + Password: password, Schema: tenantName, MultiFile: false, File: "", diff --git a/cmd/dump/dump_permission_integration_test.go b/cmd/dump/dump_permission_integration_test.go index 110d71ec..73acddc8 100644 --- a/cmd/dump/dump_permission_integration_test.go +++ b/cmd/dump/dump_permission_integration_test.go @@ -30,8 +30,27 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { ctx := context.Background() // Start single PostgreSQL container for all permission tests - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Run each permission test with its own isolated database t.Run("ProcedureAndFunctionSourceAccess", func(t *testing.T) { @@ -44,7 +63,14 @@ func TestDumpCommand_PermissionSuite(t *testing.T) { } // setupTestDatabase creates a new database with permission test roles -func setupTestDatabase(ctx context.Context, t *testing.T, container *testutil.TestPostgres, dbName string) *sql.DB { +func setupTestDatabase(ctx context.Context, t *testing.T, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) *sql.DB { // Create the database _, err := container.Conn.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", dbName)) if err != nil { @@ -101,7 +127,14 @@ func getRoleNames(dbName string) (restrictedRole string, regularUser string) { } // testIgnoredObjects tests procedures ignored via .pgschemaignore -func testIgnoredObjects(t *testing.T, ctx context.Context, container *testutil.TestPostgres, dbName string) { +func testIgnoredObjects(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) { // This test verifies that when procedures/functions are explicitly ignored // via .pgschemaignore, permission issues should not cause the dump to fail @@ -252,7 +285,14 @@ patterns = ["*_restricted"] // testProcedureAndFunctionSourceAccess tests that procedure and function source code is readable // via p.prosrc even when information_schema.routines.routine_definition is NULL -func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *testutil.TestPostgres, dbName string) { +func testProcedureAndFunctionSourceAccess(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName string) { // Setup isolated database dbConn := setupTestDatabase(ctx, t, container, dbName) defer dbConn.Close() diff --git a/cmd/ignore_integration_test.go b/cmd/ignore_integration_test.go index 1b75fd23..f3ea305b 100644 --- a/cmd/ignore_integration_test.go +++ b/cmd/ignore_integration_test.go @@ -6,7 +6,6 @@ package cmd // various database object types and ignore patterns including wildcards and negation. import ( - "context" "database/sql" "fmt" "os" @@ -28,11 +27,28 @@ func TestIgnoreIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() - // Setup PostgreSQL container - containerInfo := testutil.SetupTestPostgres(ctx, t) - defer containerInfo.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create containerInfo struct to match old API for minimal changes + containerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Create the test schema with various object types createTestSchema(t, containerInfo.Conn) @@ -68,9 +84,27 @@ func TestIgnoreIntegration(t *testing.T) { t.Run("apply", func(t *testing.T) { // Create a fresh container for apply test to avoid fingerprint conflicts - ctx := context.Background() - applyContainerInfo := testutil.SetupTestPostgres(ctx, t) - defer applyContainerInfo.Terminate(ctx, t) + applyEmbeddedPG := testutil.SetupPostgres(t) + defer applyEmbeddedPG.Stop() + applyConn, applyHost, applyPort, applyDbname, applyUser, applyPassword := testutil.ConnectToPostgres(t, applyEmbeddedPG) + defer applyConn.Close() + + // Create applyContainerInfo struct to match old API + applyContainerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: applyConn, + Host: applyHost, + Port: applyPort, + DBName: applyDbname, + User: applyUser, + Password: applyPassword, + } // Create the test schema in the fresh container createTestSchema(t, applyContainerInfo.Conn) @@ -266,7 +300,14 @@ patterns = ["seq_temp_*"] } // testIgnoreDump tests the dump command with ignore functionality -func testIgnoreDump(t *testing.T, containerInfo *testutil.TestPostgres) { +func testIgnoreDump(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -281,7 +322,14 @@ func testIgnoreDump(t *testing.T, containerInfo *testutil.TestPostgres) { // testIgnorePlanWithTriggerOnIgnoredTable tests that triggers can be defined on ignored tables // This tests the scenario where users manage triggers on externally-managed tables // without managing the table schema itself -func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *testutil.TestPostgres) { +func testIgnorePlanWithTriggerOnIgnoredTable(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file - temp_* pattern will ignore temp_external_users cleanup := createIgnoreFile(t) defer cleanup() @@ -348,7 +396,14 @@ CREATE TRIGGER on_external_user_created } // testIgnorePlan tests the plan command with ignore functionality -func testIgnorePlan(t *testing.T, containerInfo *testutil.TestPostgres) { +func testIgnorePlan(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -400,7 +455,14 @@ CREATE TABLE test_core_config ( // testIgnoreApply tests the apply command with ignore functionality // This test verifies that ignored objects are excluded from fingerprint calculation -func testIgnoreApply(t *testing.T, containerInfo *testutil.TestPostgres) { +func testIgnoreApply(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { // Create .pgschemaignore file cleanup := createIgnoreFile(t) defer cleanup() @@ -500,7 +562,14 @@ $$; } // executeIgnoreDumpCommand runs the dump command and returns the output -func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.TestPostgres) string { +func executeIgnoreDumpCommand(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) string { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", @@ -547,7 +616,14 @@ func executeIgnoreDumpCommand(t *testing.T, containerInfo *testutil.TestPostgres } // executeIgnorePlanCommand runs the plan command and returns the output -func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.TestPostgres, schemaFile string) string { +func executeIgnorePlanCommand(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, schemaFile string) string { // Create plan configuration with shared embedded postgres for performance config := &planCmd.PlanConfig{ Host: containerInfo.Host, @@ -571,7 +647,14 @@ func executeIgnorePlanCommand(t *testing.T, containerInfo *testutil.TestPostgres } // executeIgnoreApplyCommandWithError runs the apply command and returns any error -func executeIgnoreApplyCommandWithError(containerInfo *testutil.TestPostgres, schemaFile string) error { +func executeIgnoreApplyCommandWithError(containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, schemaFile string) error { rootCmd := &cobra.Command{ Use: "pgschema", } diff --git a/cmd/include_integration_test.go b/cmd/include_integration_test.go index 8a76daef..81122f79 100644 --- a/cmd/include_integration_test.go +++ b/cmd/include_integration_test.go @@ -8,7 +8,7 @@ package cmd // the same organized file structure. import ( - "context" + "database/sql" "fmt" "os" "path/filepath" @@ -26,11 +26,28 @@ func TestIncludeIntegration(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() - // Setup PostgreSQL container with specific database - containerInfo := testutil.SetupTestPostgres(ctx, t) - defer containerInfo.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create containerInfo struct to match old API for minimal changes + containerInfo := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Apply the include-based schema using the apply command applyIncludeSchema(t, containerInfo) @@ -47,7 +64,14 @@ func TestIncludeIntegration(t *testing.T) { } // applyIncludeSchema applies the testdata/include/main.sql schema using the apply command -func applyIncludeSchema(t *testing.T, containerInfo *testutil.TestPostgres) { +func applyIncludeSchema(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}) { mainSQLPath := "../testdata/include/main.sql" // Create a new root command with apply as subcommand @@ -81,7 +105,14 @@ func applyIncludeSchema(t *testing.T, containerInfo *testutil.TestPostgres) { } // executeMultiFileDump runs pgschema dump --multi-file using the CLI command -func executeMultiFileDump(t *testing.T, containerInfo *testutil.TestPostgres, outputPath string) { +func executeMultiFileDump(t *testing.T, containerInfo *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, outputPath string) { // Create a new root command with dump as subcommand rootCmd := &cobra.Command{ Use: "pgschema", diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index 1b95ae3c..ce8deaa8 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -34,7 +34,7 @@ func TestMain(m *testing.M) { // Create shared embedded postgres instance for all integration tests // This dramatically improves test performance (from ~60s to ~10s per test) - sharedEmbeddedPG = testutil.SetupPostgres(nil, testutil.WithShared()) + sharedEmbeddedPG = testutil.SetupPostgres(nil) defer sharedEmbeddedPG.Stop() // Run tests @@ -78,8 +78,27 @@ func TestPlanAndApply(t *testing.T) { testDataRoot := "../testdata/diff" // Start a single PostgreSQL container for all test cases - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Get test filter from environment variable testFilter := os.Getenv("PGSCHEMA_TEST_FILTER") @@ -178,7 +197,14 @@ type testCase struct { } // runPlanAndApplyTest executes a single plan and apply test case with test-specific database -func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *testutil.TestPostgres, tc testCase) { +func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, tc testCase) { containerHost := container.Host portMapped := container.Port // Create a unique database name for this test case (replace invalid chars) @@ -241,7 +267,14 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *testutil. } // testPlanOutputs tests all plan output formats against expected files -func testPlanOutputs(t *testing.T, container *testutil.TestPostgres, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { +func testPlanOutputs(t *testing.T, container *struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string +}, dbName, schemaFile, planSQLFile, planJSONFile, planTXTFile string) { containerHost := container.Host portMapped := container.Port // Set fixed timestamp for generate mode to ensure deterministic output diff --git a/cmd/plan/plan_integration_test.go b/cmd/plan/plan_integration_test.go index 5021f00c..5cc3a14e 100644 --- a/cmd/plan/plan_integration_test.go +++ b/cmd/plan/plan_integration_test.go @@ -2,6 +2,7 @@ package plan import ( "context" + "database/sql" "fmt" "os" "path/filepath" @@ -20,11 +21,29 @@ func TestPlanCommand_DatabaseIntegration(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with initial schema - conn := container.Conn initialSQL := ` CREATE TABLE users ( @@ -117,11 +136,29 @@ func TestPlanCommand_OutputFormats(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup simple database schema - conn := container.Conn simpleSQL := ` CREATE TABLE users ( @@ -212,11 +249,29 @@ func TestPlanCommand_SchemaFiltering(t *testing.T) { var err error // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Setup database with multiple schemas - conn := container.Conn multiSchemaSQL := ` CREATE SCHEMA app; @@ -302,12 +357,30 @@ func TestPlanCommand_EmptyDatabase(t *testing.T) { t.Skip("Skipping integration test in short mode") } - ctx := context.Background() var err error // Start PostgreSQL container with empty database - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() + + // Create container struct to match old API for minimal changes + container := &struct { + Conn *sql.DB + Host string + Port int + DBName string + User string + Password string + }{ + Conn: conn, + Host: host, + Port: port, + DBName: dbname, + User: user, + Password: password, + } // Create desired state schema file tmpDir := t.TempDir() diff --git a/cmd/schema_integration_test.go b/cmd/schema_integration_test.go index fe7054b7..58f8889d 100644 --- a/cmd/schema_integration_test.go +++ b/cmd/schema_integration_test.go @@ -24,11 +24,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { ctx := context.Background() - // Start PostgreSQL container - container := testutil.SetupTestPostgres(ctx, t) - defer container.Terminate(ctx, t) - - conn := container.Conn + // Start PostgreSQL + embeddedPG := testutil.SetupPostgres(t) + defer embeddedPG.Stop() + conn, host, port, dbname, user, password := testutil.ConnectToPostgres(t, embeddedPG) + defer conn.Close() // Test Case 1: Plan and Apply to tenant schema using CLI t.Run("cli_plan_and_apply_tenant_schema", func(t *testing.T) { @@ -61,11 +61,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 1: Generate plan using CLI planOutput, err := executePlanCommand( - container.Host, - container.Port, - container.DBName, - container.User, - container.Password, + host, + port, + dbname, + user, + password, "tenant", // Non-public schema desiredStateFile, ) @@ -82,11 +82,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 2: Apply changes using CLI err = executeApplyCommand( - container.Host, - container.Port, - container.DBName, - container.User, - container.Password, + host, + port, + dbname, + user, + password, "tenant", // Non-public schema desiredStateFile, ) @@ -169,11 +169,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Apply changes ONLY to app_a schema err = executeApplyCommand( - container.Host, - container.Port, - container.DBName, - container.User, - container.Password, + host, + port, + dbname, + user, + password, "app_a", // Target only app_a desiredStateFile, ) @@ -250,11 +250,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 1: Generate plan using CLI for mixed-case schema planOutput, err := executePlanCommand( - container.Host, - container.Port, - container.DBName, - container.User, - container.Password, + host, + port, + dbname, + user, + password, "MyApp", // Mixed-case schema desiredStateFile, ) @@ -271,11 +271,11 @@ func TestNonPublicSchemaOperations(t *testing.T) { // Step 2: Apply changes using CLI for mixed-case schema err = executeApplyCommand( - container.Host, - container.Port, - container.DBName, - container.User, - container.Password, + host, + port, + dbname, + user, + password, "MyApp", // Mixed-case schema desiredStateFile, ) diff --git a/internal/diff/diff_test.go b/internal/diff/diff_test.go index 7589b2c5..53ddeba6 100644 --- a/internal/diff/diff_test.go +++ b/internal/diff/diff_test.go @@ -17,7 +17,7 @@ var sharedTestPostgres *postgres.EmbeddedPostgres // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - sharedTestPostgres = testutil.SetupPostgres(nil, testutil.WithShared()) + sharedTestPostgres = testutil.SetupPostgres(nil) defer sharedTestPostgres.Stop() // Run tests diff --git a/internal/plan/plan_test.go b/internal/plan/plan_test.go index 79adc5b6..7b2d1239 100644 --- a/internal/plan/plan_test.go +++ b/internal/plan/plan_test.go @@ -23,7 +23,7 @@ var sharedTestPostgres *postgres.EmbeddedPostgres // TestMain sets up shared resources for all tests in this package func TestMain(m *testing.M) { // Create shared embedded postgres for all tests to dramatically improve performance - sharedTestPostgres = testutil.SetupPostgres(nil, testutil.WithShared()) + sharedTestPostgres = testutil.SetupPostgres(nil) defer sharedTestPostgres.Stop() // Run tests diff --git a/internal/postgres/embedded.go b/internal/postgres/embedded.go index 426d2c71..1fce0b9d 100644 --- a/internal/postgres/embedded.go +++ b/internal/postgres/embedded.go @@ -17,10 +17,6 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" ) -// ============================================================================ -// Type Definitions -// ============================================================================ - // PostgresVersion is an alias for the embedded-postgres version type. type PostgresVersion = embeddedpostgres.PostgresVersion @@ -46,10 +42,6 @@ type EmbeddedPostgresConfig struct { Password string } -// ============================================================================ -// Version Detection -// ============================================================================ - // DetectPostgresVersionFromDB connects to a database and detects its version // This is a convenience function that opens a connection, detects the version, and closes it func DetectPostgresVersionFromDB(host string, port int, database, user, password string) (PostgresVersion, error) { @@ -74,10 +66,6 @@ func DetectPostgresVersionFromDB(host string, port int, database, user, password return detectPostgresVersion(db) } -// ============================================================================ -// EmbeddedPostgres Lifecycle -// ============================================================================ - // StartEmbeddedPostgres starts a temporary embedded PostgreSQL instance func StartEmbeddedPostgres(config *EmbeddedPostgresConfig) (*EmbeddedPostgres, error) { // Create unique runtime path with timestamp (using nanoseconds for uniqueness) @@ -177,10 +165,6 @@ func (ep *EmbeddedPostgres) Stop() error { return nil } -// ============================================================================ -// EmbeddedPostgres Operations -// ============================================================================ - // GetConnectionDetails returns all connection details needed to connect to the embedded PostgreSQL instance func (ep *EmbeddedPostgres) GetConnectionDetails() (host string, port int, database, username, password string) { return ep.host, ep.port, ep.database, ep.username, ep.password @@ -217,10 +201,6 @@ func (ep *EmbeddedPostgres) ApplySchema(ctx context.Context, schema string, sql return nil } -// ============================================================================ -// Internal Helpers -// ============================================================================ - // findAvailablePort finds an available TCP port for PostgreSQL to use func findAvailablePort() (int, error) { listener, err := net.Listen("tcp", ":0") diff --git a/testutil/postgres.go b/testutil/postgres.go index 93e5e9bf..6ab6592a 100644 --- a/testutil/postgres.go +++ b/testutil/postgres.go @@ -13,39 +13,10 @@ import ( "github.com/pgschema/pgschema/ir" ) -// ============================================================================ -// PostgreSQL Setup -// ============================================================================ - -// PostgresOption is a functional option for configuring PostgreSQL setup -type PostgresOption func(*postgresConfig) - -// postgresConfig holds configuration for PostgreSQL setup -type postgresConfig struct { - shared bool // if true, use "shared" naming for runtime path -} - -// WithShared returns an option to create a shared PostgreSQL instance -// Shared instances are typically created once in TestMain and reused across tests -func WithShared() PostgresOption { - return func(c *postgresConfig) { - c.shared = true - } -} - // SetupPostgres creates a PostgreSQL instance for testing. // It uses the production postgres.EmbeddedPostgres implementation. // PostgreSQL version is determined from PGSCHEMA_POSTGRES_VERSION environment variable. -// -// Usage: -// - Per-test instance: testutil.SetupPostgres(t) -// - Shared instance: testutil.SetupPostgres(nil, testutil.WithShared()) -func SetupPostgres(t testing.TB, opts ...PostgresOption) *postgres.EmbeddedPostgres { - // Apply options - cfg := &postgresConfig{shared: false} - for _, opt := range opts { - opt(cfg) - } +func SetupPostgres(t testing.TB) *postgres.EmbeddedPostgres { // Determine PostgreSQL version from environment version := getPostgresVersion() @@ -71,10 +42,6 @@ func SetupPostgres(t testing.TB, opts ...PostgresOption) *postgres.EmbeddedPostg return embeddedPG } -// ============================================================================ -// Test Helpers -// ============================================================================ - // ParseSQLToIR is a test helper that parses SQL and returns its IR representation. // It applies the SQL to an embedded PostgreSQL instance, inspects it, and returns the IR. // The schema will be reset (dropped and recreated) to ensure clean state between test calls. @@ -164,10 +131,6 @@ func ConnectToPostgres(t testing.TB, embeddedPG *postgres.EmbeddedPostgres) (con return conn, host, port, dbname, user, password } -// ============================================================================ -// Internal Helpers -// ============================================================================ - // getPostgresVersion returns the PostgreSQL version to use for testing. // It reads from the PGSCHEMA_POSTGRES_VERSION environment variable, // defaulting to "17" if not set.