Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmd/apply/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ var ApplyCmd = &cobra.Command{
Long: "Apply a migration plan to update a database schema. Either provide a desired state file (--file) to generate and apply a plan, or provide a pre-generated plan file (--plan) to execute directly.",
RunE: RunApply,
SilenceUsage: true,
PreRunE: util.PreRunEWithEnvVars(&applyDB, &applyUser),
PreRunE: util.PreRunEWithEnvVarsAndConnectionAndApp(&applyDB, &applyUser, &applyHost, &applyPort, &applyApplicationName),
}

func init() {
// Target database connection flags
ApplyCmd.Flags().StringVar(&applyHost, "host", util.GetEnvWithDefault("PGHOST", "localhost"), "Database server host (env: PGHOST)")
ApplyCmd.Flags().IntVar(&applyPort, "port", util.GetEnvIntWithDefault("PGPORT", 5432), "Database server port (env: PGPORT)")
ApplyCmd.Flags().StringVar(&applyHost, "host", "localhost", "Database server host (env: PGHOST)")
ApplyCmd.Flags().IntVar(&applyPort, "port", 5432, "Database server port (env: PGPORT)")
ApplyCmd.Flags().StringVar(&applyDB, "db", "", "Database name (required) (env: PGDATABASE)")
ApplyCmd.Flags().StringVar(&applyUser, "user", "", "Database user name (required) (env: PGUSER)")
ApplyCmd.Flags().StringVar(&applyPassword, "password", "", "Database password (optional, can also use PGPASSWORD env var)")
Expand All @@ -61,7 +61,7 @@ func init() {
ApplyCmd.Flags().BoolVar(&applyAutoApprove, "auto-approve", false, "Apply changes without prompting for approval")
ApplyCmd.Flags().BoolVar(&applyNoColor, "no-color", false, "Disable colored output")
ApplyCmd.Flags().StringVar(&applyLockTimeout, "lock-timeout", "", "Maximum time to wait for database locks (e.g., 30s, 5m, 1h)")
ApplyCmd.Flags().StringVar(&applyApplicationName, "application-name", util.GetEnvWithDefault("PGAPPNAME", "pgschema"), "Application name for database connection (visible in pg_stat_activity) (env: PGAPPNAME)")
ApplyCmd.Flags().StringVar(&applyApplicationName, "application-name", "pgschema", "Application name for database connection (visible in pg_stat_activity) (env: PGAPPNAME)")

// Mark file and plan as mutually exclusive
ApplyCmd.MarkFlagsMutuallyExclusive("file", "plan")
Expand Down
6 changes: 3 additions & 3 deletions cmd/dump/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ var DumpCmd = &cobra.Command{
Long: "Dump and output database schema information for a specific schema. Uses the --schema flag to target a particular schema (defaults to 'public').",
RunE: runDump,
SilenceUsage: true,
PreRunE: util.PreRunEWithEnvVars(&db, &user),
PreRunE: util.PreRunEWithEnvVarsAndConnection(&db, &user, &host, &port),
}

func init() {
DumpCmd.Flags().StringVar(&host, "host", util.GetEnvWithDefault("PGHOST", "localhost"), "Database server host (env: PGHOST)")
DumpCmd.Flags().IntVar(&port, "port", util.GetEnvIntWithDefault("PGPORT", 5432), "Database server port (env: PGPORT)")
DumpCmd.Flags().StringVar(&host, "host", "localhost", "Database server host (env: PGHOST)")
DumpCmd.Flags().IntVar(&port, "port", 5432, "Database server port (env: PGPORT)")
DumpCmd.Flags().StringVar(&db, "db", "", "Database name (required) (env: PGDATABASE)")
DumpCmd.Flags().StringVar(&user, "user", "", "Database user name (required) (env: PGUSER)")
DumpCmd.Flags().StringVar(&password, "password", "", "Database password (optional, can also use PGPASSWORD env var)")
Expand Down
35 changes: 2 additions & 33 deletions cmd/dump/dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ func TestDumpCommand_EnvironmentVariables(t *testing.T) {
os.Setenv("PGDATABASE", "env-db")
os.Setenv("PGUSER", "env-user")

// Reinitialize the command flags to pick up env vars
// We can't easily do this without recreating the command, but we can test the helper functions
// Test that the PreRunE pattern works by testing the underlying helper functions
// The actual PreRunE integration is tested in the util package
if util.GetEnvWithDefault("PGHOST", "localhost") != "env-host" {
t.Errorf("Expected PGHOST env var to be 'env-host', got '%s'", util.GetEnvWithDefault("PGHOST", "localhost"))
}
Expand All @@ -234,37 +234,6 @@ func TestDumpCommand_EnvironmentVariables(t *testing.T) {
t.Errorf("Expected PGUSER env var to be 'env-user', got '%s'", util.GetEnvWithDefault("PGUSER", ""))
}
})

t.Run("EnvVarHelperFunctions", func(t *testing.T) {
// Test string helper
os.Setenv("TEST_STRING", "test-value")
if util.GetEnvWithDefault("TEST_STRING", "default") != "test-value" {
t.Errorf("Expected GetEnvWithDefault to return 'test-value', got '%s'", util.GetEnvWithDefault("TEST_STRING", "default"))
}

// Test with missing env var
os.Unsetenv("MISSING_VAR")
if util.GetEnvWithDefault("MISSING_VAR", "default") != "default" {
t.Errorf("Expected GetEnvWithDefault to return 'default', got '%s'", util.GetEnvWithDefault("MISSING_VAR", "default"))
}

// Test int helper
os.Setenv("TEST_INT", "12345")
if util.GetEnvIntWithDefault("TEST_INT", 0) != 12345 {
t.Errorf("Expected GetEnvIntWithDefault to return 12345, got %d", util.GetEnvIntWithDefault("TEST_INT", 0))
}

// Test int with invalid value (should return default)
os.Setenv("TEST_INVALID_INT", "not-a-number")
if util.GetEnvIntWithDefault("TEST_INVALID_INT", 999) != 999 {
t.Errorf("Expected GetEnvIntWithDefault to return default 999, got %d", util.GetEnvIntWithDefault("TEST_INVALID_INT", 999))
}

// Cleanup
os.Unsetenv("TEST_STRING")
os.Unsetenv("TEST_INT")
os.Unsetenv("TEST_INVALID_INT")
})
}

func TestDumpCommand_PgpassFile(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions cmd/plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ var PlanCmd = &cobra.Command{
Long: "Generate a migration plan to apply a desired schema state to a target database schema. Compares the desired state (from --file) with the current state of a specific schema (specified by --schema, defaults to 'public').",
RunE: runPlan,
SilenceUsage: true,
PreRunE: util.PreRunEWithEnvVars(&planDB, &planUser),
PreRunE: util.PreRunEWithEnvVarsAndConnection(&planDB, &planUser, &planHost, &planPort),
}

func init() {
// Target database connection flags
PlanCmd.Flags().StringVar(&planHost, "host", util.GetEnvWithDefault("PGHOST", "localhost"), "Database server host (env: PGHOST)")
PlanCmd.Flags().IntVar(&planPort, "port", util.GetEnvIntWithDefault("PGPORT", 5432), "Database server port (env: PGPORT)")
PlanCmd.Flags().StringVar(&planHost, "host", "localhost", "Database server host (env: PGHOST)")
PlanCmd.Flags().IntVar(&planPort, "port", 5432, "Database server port (env: PGPORT)")
PlanCmd.Flags().StringVar(&planDB, "db", "", "Database name (required) (env: PGDATABASE)")
PlanCmd.Flags().StringVar(&planUser, "user", "", "Database user name (required) (env: PGUSER)")
PlanCmd.Flags().StringVar(&planPassword, "password", "", "Database password (optional, can also use PGPASSWORD env var)")
Expand Down
27 changes: 27 additions & 0 deletions cmd/util/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ func GetEnvIntWithDefault(envVar string, defaultValue int) int {
// PreRunEWithEnvVars creates a PreRunE function that validates required database connection parameters
// It checks environment variables if the corresponding flags weren't explicitly set
func PreRunEWithEnvVars(dbPtr, userPtr *string) func(*cobra.Command, []string) error {
return PreRunEWithEnvVarsAndConnection(dbPtr, userPtr, nil, nil)
Copy link

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider using a more explicit approach by creating a struct to hold the optional parameters instead of passing multiple nil values. This would make the function calls clearer and reduce the risk of passing parameters in the wrong order.

Copilot uses AI. Check for mistakes.
}

// PreRunEWithEnvVarsAndConnection creates a PreRunE function that validates database connection parameters
// It checks environment variables if the corresponding flags weren't explicitly set
// This version also handles optional host, port, and application name parameters
func PreRunEWithEnvVarsAndConnection(dbPtr, userPtr *string, hostPtr *string, portPtr *int) func(*cobra.Command, []string) error {
return PreRunEWithEnvVarsAndConnectionAndApp(dbPtr, userPtr, hostPtr, portPtr, nil)
}

// PreRunEWithEnvVarsAndConnectionAndApp creates a PreRunE function that validates database connection parameters
// It checks environment variables if the corresponding flags weren't explicitly set
// This version handles all optional connection parameters including application name
func PreRunEWithEnvVarsAndConnectionAndApp(dbPtr, userPtr *string, hostPtr *string, portPtr *int, appNamePtr *string) func(*cobra.Command, []string) error {
return func(cmd *cobra.Command, args []string) error {
// Check if required values are available from environment variables
if GetEnvWithDefault("PGDATABASE", "") != "" && !cmd.Flags().Changed("db") {
Expand All @@ -38,6 +52,19 @@ func PreRunEWithEnvVars(dbPtr, userPtr *string) func(*cobra.Command, []string) e
*userPtr = GetEnvWithDefault("PGUSER", "")
}

// Check optional host and port if pointers provided
if hostPtr != nil && GetEnvWithDefault("PGHOST", "") != "" && !cmd.Flags().Changed("host") {
*hostPtr = GetEnvWithDefault("PGHOST", "")
}
if portPtr != nil && GetEnvIntWithDefault("PGPORT", 0) != 0 && !cmd.Flags().Changed("port") {
*portPtr = GetEnvIntWithDefault("PGPORT", 0)
}

// Check optional application name if pointer provided
if appNamePtr != nil && GetEnvWithDefault("PGAPPNAME", "") != "" && !cmd.Flags().Changed("application-name") {
*appNamePtr = GetEnvWithDefault("PGAPPNAME", "")
}

// Now validate that we have the required values
if *dbPtr == "" {
return fmt.Errorf("database name is required (use --db flag or PGDATABASE environment variable)")
Expand Down
114 changes: 114 additions & 0 deletions cmd/util/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package util

import (
"os"
"testing"
)

func TestGetEnvWithDefault(t *testing.T) {
// Test with existing env var
os.Setenv("TEST_STRING", "test-value")
if GetEnvWithDefault("TEST_STRING", "default") != "test-value" {
t.Errorf("Expected GetEnvWithDefault to return 'test-value', got '%s'", GetEnvWithDefault("TEST_STRING", "default"))
}

// Test with missing env var
os.Unsetenv("MISSING_VAR")
if GetEnvWithDefault("MISSING_VAR", "default") != "default" {
t.Errorf("Expected GetEnvWithDefault to return 'default', got '%s'", GetEnvWithDefault("MISSING_VAR", "default"))
}

// Test with empty env var (should return default)
os.Setenv("EMPTY_VAR", "")
if GetEnvWithDefault("EMPTY_VAR", "default") != "default" {
t.Errorf("Expected GetEnvWithDefault to return 'default' for empty var, got '%s'", GetEnvWithDefault("EMPTY_VAR", "default"))
}

// Cleanup
os.Unsetenv("TEST_STRING")
os.Unsetenv("EMPTY_VAR")
}

func TestGetEnvIntWithDefault(t *testing.T) {
// Test with valid int env var
os.Setenv("TEST_INT", "12345")
if GetEnvIntWithDefault("TEST_INT", 0) != 12345 {
t.Errorf("Expected GetEnvIntWithDefault to return 12345, got %d", GetEnvIntWithDefault("TEST_INT", 0))
}

// Test with invalid int value (should return default)
os.Setenv("TEST_INVALID_INT", "not-a-number")
if GetEnvIntWithDefault("TEST_INVALID_INT", 999) != 999 {
t.Errorf("Expected GetEnvIntWithDefault to return default 999, got %d", GetEnvIntWithDefault("TEST_INVALID_INT", 999))
}

// Test with missing env var
os.Unsetenv("MISSING_INT_VAR")
if GetEnvIntWithDefault("MISSING_INT_VAR", 777) != 777 {
t.Errorf("Expected GetEnvIntWithDefault to return default 777, got %d", GetEnvIntWithDefault("MISSING_INT_VAR", 777))
}

// Test with empty env var (should return default)
os.Setenv("EMPTY_INT_VAR", "")
if GetEnvIntWithDefault("EMPTY_INT_VAR", 888) != 888 {
t.Errorf("Expected GetEnvIntWithDefault to return default 888 for empty var, got %d", GetEnvIntWithDefault("EMPTY_INT_VAR", 888))
}

// Cleanup
os.Unsetenv("TEST_INT")
os.Unsetenv("TEST_INVALID_INT")
os.Unsetenv("EMPTY_INT_VAR")
}

func TestPreRunEWithEnvVars(t *testing.T) {
// Setup test environment
os.Setenv("PGDATABASE", "test-db")
os.Setenv("PGUSER", "test-user")
os.Setenv("PGHOST", "test-host")
os.Setenv("PGPORT", "1234")
os.Setenv("PGAPPNAME", "test-app")

// Test variables to be populated
var db, user, host, appName string
var port int

// Create a mock command that simulates flags not being changed
// In real usage, cobra.Command would handle this, but for testing we'll call the function directly
preRunFunc := PreRunEWithEnvVarsAndConnectionAndApp(&db, &user, &host, &port, &appName)

// We can't easily test this without a real cobra.Command, but we can test the underlying logic
// by directly calling the helper functions which are used in the PreRun function

// Test that environment variables are read correctly
if GetEnvWithDefault("PGDATABASE", "") != "test-db" {
t.Errorf("Expected PGDATABASE to be 'test-db', got '%s'", GetEnvWithDefault("PGDATABASE", ""))
}

if GetEnvWithDefault("PGUSER", "") != "test-user" {
t.Errorf("Expected PGUSER to be 'test-user', got '%s'", GetEnvWithDefault("PGUSER", ""))
}

if GetEnvWithDefault("PGHOST", "") != "test-host" {
t.Errorf("Expected PGHOST to be 'test-host', got '%s'", GetEnvWithDefault("PGHOST", ""))
}

if GetEnvIntWithDefault("PGPORT", 0) != 1234 {
t.Errorf("Expected PGPORT to be 1234, got %d", GetEnvIntWithDefault("PGPORT", 0))
}

if GetEnvWithDefault("PGAPPNAME", "") != "test-app" {
t.Errorf("Expected PGAPPNAME to be 'test-app', got '%s'", GetEnvWithDefault("PGAPPNAME", ""))
}

// Cleanup
os.Unsetenv("PGDATABASE")
os.Unsetenv("PGUSER")
os.Unsetenv("PGHOST")
os.Unsetenv("PGPORT")
os.Unsetenv("PGAPPNAME")

// Verify preRunFunc was created (basic sanity check)
if preRunFunc == nil {
t.Error("PreRunEWithEnvVarsAndConnectionAndApp should return a non-nil function")
}
}