diff --git a/backends/postgres/postgres_backend.go b/backends/postgres/postgres_backend.go index 9e4e8d9..247c4ba 100644 --- a/backends/postgres/postgres_backend.go +++ b/backends/postgres/postgres_backend.go @@ -7,7 +7,6 @@ import ( "fmt" "net/url" "os" - "strings" "sync" "time" @@ -34,6 +33,9 @@ import ( var migrationsFS embed.FS const ( + queryParamSSLMode = "sslmode" + queryParamMigrationsTable = "x-migrations-table" + JobQuery = `SELECT id,fingerprint,queue,status,deadline,payload,retries,max_retries,run_after,ran_at,created_at,error FROM neoq_jobs WHERE id = $1 @@ -324,6 +326,7 @@ func txFromContext(ctx context.Context) (t pgx.Tx, err error) { func (p *PgBackend) initializeDB() (err error) { migrations, err := iofs.New(migrationsFS, "migrations") if err != nil { + err = fmt.Errorf("unable to run migrations, error during iofs new: %w", err) p.logger.Error("unable to run migrations", slog.Any("error", err)) return } @@ -332,36 +335,16 @@ func (p *PgBackend) initializeDB() (err error) { // it with pgx-specific config params like `max_conn_count`. However, `go-migrate` uses `pq` under the hood, and // these `pgx` config params cause `pq` to throw an "unknown config parameter" error when they're encountered. // So we must first sanitize connection strings for pq - var pgxCfg *pgx.ConnConfig - pgxCfg, err = pgx.ParseConfig(p.config.ConnectionString) + pqConnectionString, err := GetPQConnectionString(p.config.ConnectionString) if err != nil { + err = fmt.Errorf("unable to run migrations, error parsing connection string: %w", err) p.logger.Error("unable to run migrations", slog.Any("error", err)) return } - // nil TLSConfig means "sslmode=disable" was set on the connection - sslMode := "verify-ca" - if pgxCfg.TLSConfig == nil { - sslMode = "disable" - } else if pgxCfg.TLSConfig.InsecureSkipVerify { - sslMode = "require" - } - if dbURL, err := url.Parse(pgxCfg.ConnString()); err == nil && - strings.HasPrefix(dbURL.Scheme, "postgres") { - val := dbURL.Query() - if v := val.Get("sslmode"); v != "" { - sslMode = v // set sslmode from existing connection string - } - } - - pqConnectionString := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s&x-migrations-table=neoq_schema_migrations", - pgxCfg.User, - url.QueryEscape(pgxCfg.Password), - pgxCfg.Host, - pgxCfg.Database, - sslMode) m, err := migrate.NewWithSourceInstance("iofs", migrations, pqConnectionString) if err != nil { + err = fmt.Errorf("unable to run migrations, could not create new source: %w", err) p.logger.Error("unable to run migrations", slog.Any("error", err)) return } @@ -370,6 +353,7 @@ func (p *PgBackend) initializeDB() (err error) { err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { + err = fmt.Errorf("unable to run migrations, could not apply up migration: %w", err) p.logger.Error("unable to run migrations", slog.Any("error", err)) return } @@ -1030,3 +1014,51 @@ func (p *PgBackend) acquire(ctx context.Context) (conn *pgxpool.Conn, err error) func withJobContext(ctx context.Context, j *jobs.Job) context.Context { return context.WithValue(ctx, internal.JobCtxVarKey, j) } + +func GetPQConnectionString(connectionString string) (string, error) { + pgxCfg, err := pgx.ParseConfig(connectionString) + if err != nil { + return "", fmt.Errorf("unable to parse connection string %s: %w", connectionString, err) + } + + dbURI, err := url.Parse(pgxCfg.ConnString()) + if err != nil { + return "", fmt.Errorf("unable to parse connection string %s: %w", connectionString, err) + } + + if dbURI.String() == "" { + return "", fmt.Errorf("connection string cannot be empty") + } + + scheme := dbURI.Scheme + if scheme == "" { + // This is probably a pq-style string, return it as-is + return connectionString, nil + } + + if scheme != "postgres" && scheme != "postgresql" { + // This isn't a postgresql URI-style string (postgres://hostname/db) + return "", fmt.Errorf("only postgres and postgresql scheme URIs are supported, invalid connection string: %s", connectionString) + } + + sslMode := "verify-ca" + if pgxCfg.TLSConfig == nil { + sslMode = "disable" + } else if pgxCfg.TLSConfig.InsecureSkipVerify { + sslMode = "require" + } + + // Prefer original sslmode if it was set + originalSSLMode := dbURI.Query().Get(queryParamSSLMode) + if originalSSLMode != "" { + sslMode = originalSSLMode + } + + // Clear out original query, use only query params that are pq compatible + query := url.Values{} + query.Set(queryParamSSLMode, sslMode) + query.Set(queryParamMigrationsTable, "neoq_schema_migrations") + dbURI.RawQuery = query.Encode() + + return dbURI.String(), nil +} diff --git a/backends/postgres/postgres_backend_test.go b/backends/postgres/postgres_backend_test.go index bedc70a..9e2bcb5 100644 --- a/backends/postgres/postgres_backend_test.go +++ b/backends/postgres/postgres_backend_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/url" "os" "strings" "sync" @@ -825,3 +826,176 @@ func TestFutureJobProcessing(t *testing.T) { t.Error("job ran before RunAfter") } } + +func TestGetPQConnectionString(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "standard input", + input: "postgres://username:password@hostname:5432/database", + want: "postgres://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "standard input with postgresql scheme", + input: "postgresql://username:password@hostname:5432/database", + want: "postgresql://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "no port number", + input: "postgres://username:password@hostname/database", + want: "postgres://username:password@hostname/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom port number", + input: "postgres://username:password@hostname:1234/database", + want: "postgres://username:password@hostname:1234/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=disable", + input: "postgres://username:password@hostname:5432/database?sslmode=disable", + want: "postgres://username:password@hostname:5432/database?sslmode=disable&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=allow", + input: "postgres://username:password@hostname:5432/database?sslmode=allow", + want: "postgres://username:password@hostname:5432/database?sslmode=allow&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=prefer", + input: "postgres://username:password@hostname:5432/database?sslmode=prefer", + want: "postgres://username:password@hostname:5432/database?sslmode=prefer&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=require", + input: "postgres://username:password@hostname:5432/database?sslmode=require", + want: "postgres://username:password@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=verify-ca", + input: "postgres://username:password@hostname:5432/database?sslmode=verify-ca", + want: "postgres://username:password@hostname:5432/database?sslmode=verify-ca&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom sslmode=verify-full", + input: "postgres://username:password@hostname:5432/database?sslmode=verify-full", + want: "postgres://username:password@hostname:5432/database?sslmode=verify-full&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "encoded password is preserved", + input: "postgres://username:pass%21%40%23$%25%5E&%2A%28%29%3A%2F%3Fword@hostname:5432/database", + want: fmt.Sprintf( + "postgres://%s@hostname:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + url.UserPassword("username", "pass!@#$%^&*():/?word").String(), + ), + wantErr: false, + }, + { + name: "multiple hostnames", + input: "postgres://username:password@hostname1,hostname2,hostname3:5432/database", + want: "postgres://username:password@hostname1,hostname2,hostname3:5432/database?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + + // Examples connstrings from https://www.postgresql.org/docs/16/libpq-connect.html + { + name: "valid empty postgresql scheme input", + input: "postgresql://", + want: "postgresql:?sslmode=disable&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "hostname localhost", + input: "postgresql://localhost", + want: "postgresql://localhost?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "hostname localhost with custom port", + input: "postgresql://localhost:5433", + want: "postgresql://localhost:5433?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "non-default database", + input: "postgresql://localhost/mydb", + want: "postgresql://localhost/mydb?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "username", + input: "postgresql://user@localhost", + want: "postgresql://user@localhost?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "username and password", + input: "postgresql://user:secret@localhost", + want: "postgresql://user:secret@localhost?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "custom params are ignored", + input: "postgresql://other@localhost/otherdb?connect_timeout=10&application_name=myapp", + want: "postgresql://other@localhost/otherdb?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "multiple hostnames and ports", + input: "postgresql://host1:123,host2:456/somedb?target_session_attrs=any&application_name=myapp", + want: "postgresql://host1:123,host2:456/somedb?sslmode=require&x-migrations-table=neoq_schema_migrations", + wantErr: false, + }, + { + name: "pq-style input is returned as-is", + input: "host=localhost port=5432 dbname=mydb connect_timeout=10", + want: "host=localhost port=5432 dbname=mydb connect_timeout=10", + wantErr: false, + }, + + // Inputs that cause errors + { + name: "non-postgres scheme returns error", + input: "https://user:password@example.com:443/path?query=true", + want: "", + wantErr: true, + }, + { + name: "empty input returns error", + input: "", + want: "", + wantErr: true, + }, + { + name: "custom bad sslmode=foo returns error", + input: "postgres://username:password@hostname:1234/database?sslmode=foo", + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := postgres.GetPQConnectionString(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("GetPQConnectionString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetPQConnectionString()\ngot = %v\nwant = %v", got, tt.want) + } + }) + } +}