diff --git a/go.mod b/go.mod index 32363b994..0765ee474 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/golang/protobuf v1.5.4 github.com/gomodule/redigo v1.9.2 github.com/gorilla/handlers v1.5.2 - github.com/lib/pq v1.10.9 + github.com/lib/pq v1.11.1 github.com/lomik/og-rek v0.0.0-20170411191824-628eefeb8d80 github.com/lomik/zapwriter v0.0.0-20210624082824-c1161d1eb463 github.com/maruel/natural v1.1.1 diff --git a/go.sum b/go.sum index d0445f0fa..880fdf7fb 100644 --- a/go.sum +++ b/go.sum @@ -117,8 +117,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.11.1 h1:wuChtj2hfsGmmx3nf1m7xC2XpK6OtelS2shMY+bGMtI= +github.com/lib/pq v1.11.1/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/lomik/og-rek v0.0.0-20170411191824-628eefeb8d80 h1:KVyDGUXjVOdHQt24wIgY4ZdGFXHtQHLWw0L/MAK3Kb0= github.com/lomik/og-rek v0.0.0-20170411191824-628eefeb8d80/go.mod h1:T7SQVaLtK7mcQIEVzveZVJzsDQpAtzTs2YoezrIBdvI= github.com/lomik/zapwriter v0.0.0-20210624082824-c1161d1eb463 h1:SN/0TEkyYpp8tit79JPUnecebCGZsXiYYPxN8i3I6Rk= diff --git a/vendor/github.com/lib/pq/.gitattributes b/vendor/github.com/lib/pq/.gitattributes new file mode 100644 index 000000000..dfdb8b771 --- /dev/null +++ b/vendor/github.com/lib/pq/.gitattributes @@ -0,0 +1 @@ +*.sh text eol=lf diff --git a/vendor/github.com/lib/pq/CHANGELOG.md b/vendor/github.com/lib/pq/CHANGELOG.md new file mode 100644 index 000000000..338b128bc --- /dev/null +++ b/vendor/github.com/lib/pq/CHANGELOG.md @@ -0,0 +1,143 @@ +v1.11.1 (2025-01-29) +-------------------- +This 
fixes two regressions present in the v1.11.0 release: + +- Fix build on 32bit systems, Windows, and Plan 9 ([#1253]). + +- Named []byte types and pointers to []byte (e.g. `*[]byte`, `json.RawMessage`) + would be treated as an array instead of bytea ([#1252]). + +[#1252]: https://github.com/lib/pq/pull/1252 +[#1253]: https://github.com/lib/pq/pull/1253 + +v1.11.0 (2025-01-28) +-------------------- +This version of pq requires Go 1.21 or newer. + +pq now supports only maintained PostgreSQL releases, which is PostgreSQL 14 and +newer. Previously PostgreSQL 8.4 and newer were supported. + +### Features + +- The `pq.Error.Error()` text includes the position of the error (if reported + by PostgreSQL) and SQLSTATE code ([#1219], [#1224]): + + pq: column "columndoesntexist" does not exist at column 8 (42703) + pq: syntax error at or near ")" at position 2:71 (42601) + +- The `pq.Error.ErrorWithDetail()` method prints a more detailed multiline + message, with the Detail, Hint, and error position (if any) ([#1219]): + + ERROR: syntax error at or near ")" (42601) + CONTEXT: line 12, column 1: + + 10 | name varchar, + 11 | version varchar, + 12 | ); + ^ + +- Add `Config`, `NewConfig()`, and `NewConnectorConfig()` to supply connection + details in a more structured way ([#1240]). + +- Support `hostaddr` and `$PGHOSTADDR` ([#1243]). + +- Support multiple values in `host`, `port`, and `hostaddr`, which are each + tried in order, or randomly if `load_balance_hosts=random` is set ([#1246]). + +- Support `target_session_attrs` connection parameter ([#1246]). + +- Support [`sslnegotiation`] to use SSL without negotiation ([#1180]). + +- Allow using a custom `tls.Config`, for example for encrypted keys ([#1228]). + +- Add `PQGO_DEBUG=1` print the communication with PostgreSQL to stderr, to aid + in debugging, testing, and bug reports ([#1223]). + +- Add support for NamedValueChecker interface ([#1125], [#1238]). 
+ + +### Fixes + +- Match HOME directory lookup logic with libpq: prefer $HOME over /etc/passwd, + ignore ENOTDIR errors, and use APPDATA on Windows ([#1214]). + +- Fix `sslmode=verify-ca` verifying the hostname anyway when connecting to a DNS + name (rather than IP) ([#1226]). + +- Correctly detect pre-protocol errors such as the server not being able to fork + or running out of memory ([#1248]). + +- Fix build with wasm ([#1184]), appengine ([#745]), and Plan 9 ([#1133]). + +- Deprecate and type alias `pq.NullTime` to `sql.NullTime` ([#1211]). + +- Enforce integer limits of the Postgres wire protocol ([#1161]). + +- Accept the `passfile` connection parameter to override `PGPASSFILE` ([#1129]). + +- Fix connecting to socket on Windows systems ([#1179]). + +- Don't perform a permission check on the .pgpass file on Windows ([#595]). + +- Warn about incorrect .pgpass permissions ([#595]). + +- Don't set extra_float_digits ([#1212]). + +- Decode bpchar into a string ([#949]). + +- Fix panic in Ping() by not requiring CommandComplete or EmptyQueryResponse in + simpleQuery() ([#1234]) + +- Recognize bit/varbit ([#743]) and float types ([#1166]) in ColumnTypeScanType(). + +- Accept `PGGSSLIB` and `PGKRBSRVNAME` environment variables ([#1143]). + +- Handle ErrorResponse in readReadyForQuery and return proper error ([#1136]). + +- CopyIn() and CopyInSchema() now work if the list of columns is empty, in which + case it will copy all columns ([#1239]). + +- Treat nil []byte in query parameters as nil/NULL rather than `""` ([#838]). + +- Accept multiple authentication methods before checking AuthOk, which improves + compatibility with PgPool-II ([#1188]). 
+ +[`sslnegotiation`]: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLNEGOTIATION +[#595]: https://github.com/lib/pq/pull/595 +[#745]: https://github.com/lib/pq/pull/745 +[#743]: https://github.com/lib/pq/pull/743 +[#838]: https://github.com/lib/pq/pull/838 +[#949]: https://github.com/lib/pq/pull/949 +[#1125]: https://github.com/lib/pq/pull/1125 +[#1129]: https://github.com/lib/pq/pull/1129 +[#1133]: https://github.com/lib/pq/pull/1133 +[#1136]: https://github.com/lib/pq/pull/1136 +[#1143]: https://github.com/lib/pq/pull/1143 +[#1161]: https://github.com/lib/pq/pull/1161 +[#1166]: https://github.com/lib/pq/pull/1166 +[#1179]: https://github.com/lib/pq/pull/1179 +[#1180]: https://github.com/lib/pq/pull/1180 +[#1184]: https://github.com/lib/pq/pull/1184 +[#1188]: https://github.com/lib/pq/pull/1188 +[#1211]: https://github.com/lib/pq/pull/1211 +[#1212]: https://github.com/lib/pq/pull/1212 +[#1214]: https://github.com/lib/pq/pull/1214 +[#1219]: https://github.com/lib/pq/pull/1219 +[#1223]: https://github.com/lib/pq/pull/1223 +[#1224]: https://github.com/lib/pq/pull/1224 +[#1226]: https://github.com/lib/pq/pull/1226 +[#1228]: https://github.com/lib/pq/pull/1228 +[#1234]: https://github.com/lib/pq/pull/1234 +[#1238]: https://github.com/lib/pq/pull/1238 +[#1239]: https://github.com/lib/pq/pull/1239 +[#1240]: https://github.com/lib/pq/pull/1240 +[#1243]: https://github.com/lib/pq/pull/1243 +[#1246]: https://github.com/lib/pq/pull/1246 +[#1248]: https://github.com/lib/pq/pull/1248 + + +v1.10.9 (2023-04-26) +-------------------- +- Fixes backwards incompat bug with 1.13. + +- Fixes pgpass issue diff --git a/vendor/github.com/lib/pq/LICENSE b/vendor/github.com/lib/pq/LICENSE new file mode 100644 index 000000000..6a77dc4fb --- /dev/null +++ b/vendor/github.com/lib/pq/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2011-2013, 'pq' Contributors. 
Portions Copyright (c) 2011 Blake Mizerany + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/lib/pq/LICENSE.md b/vendor/github.com/lib/pq/LICENSE.md deleted file mode 100644 index 5773904a3..000000000 --- a/vendor/github.com/lib/pq/LICENSE.md +++ /dev/null @@ -1,8 +0,0 @@ -Copyright (c) 2011-2013, 'pq' Contributors -Portions Copyright (C) 2011 Blake Mizerany - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index 126ee5d35..7abcf7241 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -1,13 +1,18 @@ -# pq - A pure Go postgres driver for Go's database/sql package +pq is a Go PostgreSQL driver for database/sql. -[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://pkg.go.dev/github.com/lib/pq?tab=doc) +All [maintained versions of PostgreSQL] are supported. Older versions may work, +but this is not tested. -## Install +API docs: https://pkg.go.dev/github.com/lib/pq - go get github.com/lib/pq +Install with: -## Features + go get github.com/lib/pq@latest +[maintained versions of PostgreSQL]: https://www.postgresql.org/support/versioning + +Features +-------- * SSL * Handles bad connections for `database/sql` * Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) @@ -21,16 +26,67 @@ * pgpass support * GSS (Kerberos) auth -## Tests +Running Tests +------------- +Tests need to be run against a PostgreSQL database; you can use Docker compose +to start one: + + docker compose up -d + +This starts the latest PostgreSQL; use `docker compose up -d pg«v»` to start a +different version. + +In addition, your `/etc/hosts` currently needs an entry: + + 127.0.0.1 postgres postgres-invalid + +Or you can use any other PostgreSQL instance; see +`testdata/init/docker-entrypoint-initdb.d` for the required setup. 
You can use +the standard `PG*` environment variables to control the connection details; it +uses the following defaults: + + PGHOST=localhost + PGDATABASE=pqgo + PGUSER=pqgo + PGSSLMODE=disable + PGCONNECT_TIMEOUT=20 + +`PQTEST_BINARY_PARAMETERS` can be used to add `binary_parameters=yes` to all +connection strings: + + PQTEST_BINARY_PARAMETERS=1 go test + +Tests can be run against pgbouncer with: + + docker compose up -d pgbouncer pg18 + PGPORT=6432 go test ./... + +and pgpool with: -`go test` is used for testing. See [TESTS.md](TESTS.md) for more details. + docker compose up -d pgpool pg18 + PGPORT=7432 go test ./... -## Status +You can use PQGO_DEBUG=1 to make the driver print the communication with +PostgreSQL to stderr; this works anywhere (test or applications) and can be +useful to debug protocol problems. -This package is currently in maintenance mode, which means: -1. It generally does not accept new features. -2. It does accept bug fixes and version compatability changes provided by the community. -3. Maintainers usually do not resolve reported issues. -4. Community members are encouraged to help each other with reported issues. +For example: -For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development. + % PQGO_DEBUG=1 go test -run TestSimpleQuery + CLIENT → Startup 69 "\x00\x03\x00\x00database\x00pqgo\x00user [..]" + SERVER ← (R) AuthRequest 4 "\x00\x00\x00\x00" + SERVER ← (S) ParamStatus 19 "in_hot_standby\x00off\x00" + [..] 
+ SERVER ← (Z) ReadyForQuery 1 "I" + START conn.query + START conn.simpleQuery + CLIENT → (Q) Query 9 "select 1\x00" + SERVER ← (T) RowDescription 29 "\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17\x00\x04\xff\xff\xff\xff\x00\x00" + SERVER ← (D) DataRow 7 "\x00\x01\x00\x00\x00\x011" + END conn.simpleQuery + END conn.query + SERVER ← (C) CommandComplete 9 "SELECT 1\x00" + SERVER ← (Z) ReadyForQuery 1 "I" + CLIENT → (X) Terminate 0 "" + PASS + ok github.com/lib/pq 0.010s diff --git a/vendor/github.com/lib/pq/TESTS.md b/vendor/github.com/lib/pq/TESTS.md deleted file mode 100644 index f05021115..000000000 --- a/vendor/github.com/lib/pq/TESTS.md +++ /dev/null @@ -1,33 +0,0 @@ -# Tests - -## Running Tests - -`go test` is used for testing. A running PostgreSQL -server is required, with the ability to log in. The -database to connect to test with is "pqgotest," on -"localhost" but these can be overridden using [environment -variables](https://www.postgresql.org/docs/9.3/static/libpq-envars.html). - -Example: - - PGHOST=/run/postgresql go test - -## Benchmarks - -A benchmark suite can be run as part of the tests: - - go test -bench . - -## Example setup (Docker) - -Run a postgres container: - -``` -docker run --expose 5432:5432 postgres -``` - -Run tests: - -``` -PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test -``` diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index 39c8f7e2e..910f335e1 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -19,14 +19,15 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // slice of any dimension. 
// // For example: -// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) // -// var x []sql.NullInt64 -// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) // // Scanning multi-dimensional arrays is not supported. Arrays where the lower // bound is not one (such as `[0:0]={1}') are not supported. -func Array(a interface{}) interface { +func Array(a any) interface { driver.Valuer sql.Scanner } { @@ -76,7 +77,7 @@ type ArrayDelimiter interface { type BoolArray []bool // Scan implements the sql.Scanner interface. -func (a *BoolArray) Scan(src interface{}) error { +func (a *BoolArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -150,7 +151,7 @@ func (a BoolArray) Value() (driver.Value, error) { type ByteaArray [][]byte // Scan implements the sql.Scanner interface. -func (a *ByteaArray) Scan(src interface{}) error { +func (a *ByteaArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -176,7 +177,7 @@ func (a *ByteaArray) scanBytes(src []byte) error { for i, v := range elems { b[i], err = parseBytea(v) if err != nil { - return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + return fmt.Errorf("could not parse bytea array index %d: %w", i, err) } } *a = b @@ -222,7 +223,7 @@ func (a ByteaArray) Value() (driver.Value, error) { type Float64Array []float64 // Scan implements the sql.Scanner interface. 
-func (a *Float64Array) Scan(src interface{}) error { +func (a *Float64Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -246,8 +247,9 @@ func (a *Float64Array) scanBytes(src []byte) error { } else { b := make(Float64Array, len(elems)) for i, v := range elems { - if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + b[i], err = strconv.ParseFloat(string(v), 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } *a = b @@ -284,7 +286,7 @@ func (a Float64Array) Value() (driver.Value, error) { type Float32Array []float32 // Scan implements the sql.Scanner interface. -func (a *Float32Array) Scan(src interface{}) error { +func (a *Float32Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -308,9 +310,9 @@ func (a *Float32Array) scanBytes(src []byte) error { } else { b := make(Float32Array, len(elems)) for i, v := range elems { - var x float64 - if x, err = strconv.ParseFloat(string(v), 32); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + x, err := strconv.ParseFloat(string(v), 32) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } b[i] = float32(x) } @@ -345,7 +347,7 @@ func (a Float32Array) Value() (driver.Value, error) { // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. 
-type GenericArray struct{ A interface{} } +type GenericArray struct{ A any } func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { var assign func([]byte, reflect.Value) error @@ -354,7 +356,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { - if reflect.PtrTo(rt).Implements(typeSQLScanner) { + if reflect.PointerTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) @@ -383,7 +385,7 @@ FoundType: } // Scan implements the sql.Scanner interface. -func (a GenericArray) Scan(src interface{}) error { +func (a GenericArray) Scan(src any) error { dpv := reflect.ValueOf(a.A) switch { case dpv.Kind() != reflect.Ptr: @@ -449,8 +451,9 @@ func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) for i, e := range elems { - if err := assign(e, values.Index(i)); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + err := assign(e, values.Index(i)) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } @@ -502,7 +505,7 @@ func (a GenericArray) Value() (driver.Value, error) { type Int64Array []int64 // Scan implements the sql.Scanner interface. 
-func (a *Int64Array) Scan(src interface{}) error { +func (a *Int64Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -526,8 +529,9 @@ func (a *Int64Array) scanBytes(src []byte) error { } else { b := make(Int64Array, len(elems)) for i, v := range elems { - if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + b[i], err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } } *a = b @@ -563,7 +567,7 @@ func (a Int64Array) Value() (driver.Value, error) { type Int32Array []int32 // Scan implements the sql.Scanner interface. -func (a *Int32Array) Scan(src interface{}) error { +func (a *Int32Array) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -589,7 +593,7 @@ func (a *Int32Array) scanBytes(src []byte) error { for i, v := range elems { x, err := strconv.ParseInt(string(v), 10, 32) if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) } b[i] = int32(x) } @@ -626,7 +630,7 @@ func (a Int32Array) Value() (driver.Value, error) { type StringArray []string // Scan implements the sql.Scanner interface. -func (a *StringArray) Scan(src interface{}) error { +func (a *StringArray) Scan(src any) error { switch src := src.(type) { case []byte: return a.scanBytes(src) @@ -683,10 +687,10 @@ func (a StringArray) Value() (driver.Value, error) { return "{}", nil } -// appendArray appends rv to the buffer, returning the extended buffer and -// the delimiter used between elements. +// appendArray appends rv to the buffer, returning the extended buffer and the +// delimiter used between elements. // -// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. 
+// Returns an error when n <= 0 or rv is not a reflect.Array or reflect.Slice. func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { var del string var err error @@ -728,7 +732,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { var del = "," var err error - var iv interface{} = rv.Interface() + var iv any = rv.Interface() if ad, ok := iv.(ArrayDelimiter); ok { del = ad.ArrayDelimiter() @@ -769,7 +773,11 @@ func appendArrayQuotedBytes(b, v []byte) []byte { } func appendValue(b []byte, v driver.Value) ([]byte, error) { - return append(b, encode(nil, v, 0)...), nil + enc, err := encode(v, 0) + if err != nil { + return nil, err + } + return append(b, enc...), nil } // parseArray extracts the dimensions and elements of an array represented in diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go index 4b0a0a8f7..67ca60cc6 100644 --- a/vendor/github.com/lib/pq/buf.go +++ b/vendor/github.com/lib/pq/buf.go @@ -3,7 +3,10 @@ package pq import ( "bytes" "encoding/binary" + "errors" + "fmt" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" ) @@ -31,7 +34,7 @@ func (b *readBuf) int16() (n int) { func (b *readBuf) string() string { i := bytes.IndexByte(*b, 0) if i < 0 { - errorf("invalid message format; expected string terminator") + panic(errors.New("pq: invalid message format; expected string terminator")) } s := (*b)[:i] *b = (*b)[i+1:] @@ -69,8 +72,8 @@ func (b *writeBuf) string(s string) { b.buf = append(append(b.buf, s...), '\000') } -func (b *writeBuf) byte(c byte) { - b.buf = append(b.buf, c) +func (b *writeBuf) byte(c proto.RequestCode) { + b.buf = append(b.buf, byte(c)) } func (b *writeBuf) bytes(v []byte) { @@ -79,13 +82,19 @@ func (b *writeBuf) bytes(v []byte) { func (b *writeBuf) wrap() []byte { p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } binary.BigEndian.PutUint32(p, uint32(len(p))) return 
b.buf } -func (b *writeBuf) next(c byte) { +func (b *writeBuf) next(c proto.RequestCode) { p := b.buf[b.pos:] + if len(p) > proto.MaxUint32 { + panic(fmt.Errorf("pq: message too large (%d > math.MaxUint32)", len(p))) + } binary.BigEndian.PutUint32(p, uint32(len(p))) b.pos = len(b.buf) + 1 - b.buf = append(b.buf, c, 0, 0, 0, 0) + b.buf = append(b.buf, byte(c), 0, 0, 0, 0) } diff --git a/vendor/github.com/lib/pq/compose.yaml b/vendor/github.com/lib/pq/compose.yaml new file mode 100644 index 000000000..1027d4def --- /dev/null +++ b/vendor/github.com/lib/pq/compose.yaml @@ -0,0 +1,77 @@ +name: 'pqgo' + +services: + pgbouncer: + profiles: ['pgbouncer'] + image: 'cleanstart/pgbouncer:1.24' + ports: ['127.0.0.1:6432:6432'] + command: ['/init/pgbouncer.ini'] + volumes: ['./testdata/init:/init'] + environment: + 'PGBOUNCER_DATABASE': 'pqgo' + + pgpool: + profiles: ['pgpool'] + image: 'pgpool/pgpool:4.4.3' + ports: ['127.0.0.1:7432:7432'] + volumes: ['./testdata/init:/init'] + entrypoint: '/init/entry-pgpool.sh' + environment: + 'PGPOOL_PARAMS_PORT': '7432' + 'PGPOOL_PARAMS_BACKEND_HOSTNAME0': 'pg18' + + pg18: + image: 'postgres:18' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg17: + profiles: ['pg17'] + image: 'postgres:17' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + user: 'root' + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg16: + profiles: ['pg16'] + image: 'postgres:16' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg15: + profiles: ['pg15'] + image: 'postgres:15' + ports: 
['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' + pg14: + profiles: ['pg14'] + image: 'postgres:14' + ports: ['127.0.0.1:5432:5432'] + entrypoint: '/init/entry.sh' + volumes: ['./testdata/init:/init'] + shm_size: '128mb' + environment: + 'POSTGRES_DATABASE': 'pqgo' + 'POSTGRES_USER': 'pqgo' + 'POSTGRES_PASSWORD': 'unused' diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index da4ff9de6..5e7ce20da 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -2,7 +2,6 @@ package pq import ( "bufio" - "bytes" "context" "crypto/md5" "crypto/sha256" @@ -12,30 +11,31 @@ import ( "errors" "fmt" "io" + "math" "net" "os" - "os/user" - "path" - "path/filepath" + "reflect" "strconv" "strings" "sync" "time" - "unicode" + "github.com/lib/pq/internal/pgpass" + "github.com/lib/pq/internal/pqsql" + "github.com/lib/pq/internal/pqutil" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" "github.com/lib/pq/scram" ) // Common error types var ( - ErrNotSupported = errors.New("pq: Unsupported command") - ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") + ErrNotSupported = errors.New("pq: unsupported command") + ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") - - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. 
Please provide one explicitly") + ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly") + ErrSSLKeyUnknownOwnership = pqutil.ErrSSLKeyUnknownOwnership + ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") @@ -44,9 +44,32 @@ var ( // Compile time validation that our types implement the expected interfaces var ( - _ driver.Driver = Driver{} + _ driver.Driver = Driver{} + _ driver.ConnBeginTx = (*conn)(nil) + _ driver.ConnPrepareContext = (*conn)(nil) + _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.ExecerContext = (*conn)(nil) + _ driver.NamedValueChecker = (*conn)(nil) + _ driver.Pinger = (*conn)(nil) + _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x + _ driver.QueryerContext = (*conn)(nil) + _ driver.SessionResetter = (*conn)(nil) + _ driver.Validator = (*conn)(nil) + _ driver.StmtExecContext = (*stmt)(nil) + _ driver.StmtQueryContext = (*stmt)(nil) ) +func init() { + sql.Register("postgres", &Driver{}) +} + +var debugProto = func() bool { + // Check for exactly "1" (rather than mere existence) so we can add + // options/flags in the future. I don't know if we ever want that, but it's + // nice to leave the option open. + return os.Getenv("PQGO_DEBUG") == "1" +}() + // Driver is the Postgres database driver. type Driver struct{} @@ -57,19 +80,27 @@ func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } -func init() { - sql.Register("postgres", &Driver{}) +// Parameters sent by PostgreSQL on startup. 
+type parameterStatus struct { + serverVersion int + currentLocation *time.Location + inHotStandby, defaultTransactionReadOnly sql.NullBool } -type parameterStatus struct { - // server version in the same format as server_version_num, or 0 if - // unavailable - serverVersion int +type format int - // the current location based on the TimeZone value of the session, if - // available - currentLocation *time.Location -} +const ( + formatText format = 0 + formatBinary format = 1 +) + +var ( + // One result-column format code with the value 1 (i.e. all binary). + colFmtDataAllBinary = []byte{0, 1, 0, 1} + + // No result-column format codes (i.e. all text). + colFmtDataAllText = []byte{0, 0} +) type transactionStatus byte @@ -88,10 +119,8 @@ func (s transactionStatus) String() string { case txnStatusInFailedTransaction: return "in a failed transaction" default: - errorf("unknown transactionStatus %d", s) + panic(fmt.Sprintf("pq: unknown transactionStatus %d", s)) } - - panic("not reached") } // Dialer is the dialer interface. It can be used to obtain more control over @@ -113,13 +142,13 @@ type defaultDialer struct { func (d defaultDialer) Dial(network, address string) (net.Conn, error) { return d.d.Dial(network, address) } -func (d defaultDialer) DialTimeout( - network, address string, timeout time.Duration, -) (net.Conn, error) { + +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return d.DialContext(ctx, network, address) } + func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.d.DialContext(ctx, network, address) } @@ -133,16 +162,11 @@ type conn struct { txnFinish func() // Save connection arguments to use during CancelRequest. - dialer Dialer - opts values - - // Cancellation key data for use with CancelRequest messages. 
- processID int - secretKey int - + dialer Dialer + cfg Config parameterStatus parameterStatus - saveMessageType byte + saveMessageType proto.ResponseCode saveMessageBuffer []byte // If an error is set, this connection is bad and all public-facing @@ -150,26 +174,11 @@ type conn struct { // (ErrBadConn) or getForNext(). err syncErr - // If set, this connection should never use the binary format when - // receiving query results from prepared statements. Only provided for - // debugging. - disablePreparedBinaryResult bool - - // Whether to always send []byte parameters over as binary. Enables single - // round-trip mode for non-prepared Query calls. - binaryParameters bool - - // If true this connection is in the middle of a COPY - inCopy bool - - // If not nil, notices will be synchronously sent here - noticeHandler func(*Error) - - // If not nil, notifications will be synchronously sent here - notificationHandler func(*Notification) - - // GSSAPI context - gss GSS + processID, secretKey int // Cancellation key data for use with CancelRequest messages. + inCopy bool // If true this connection is in the middle of a COPY + noticeHandler func(*Error) // If not nil, notices will be synchronously sent here + notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here + gss GSS // GSSAPI context } type syncErr struct { @@ -206,125 +215,16 @@ func (e *syncErr) set(err error) { } } -// Handle driver-side settings in parsed connection string. 
-func (cn *conn) handleDriverSettings(o values) (err error) { - boolSetting := func(key string, val *bool) error { - if value, ok := o[key]; ok { - if value == "yes" { - *val = true - } else if value == "no" { - *val = false - } else { - return fmt.Errorf("unrecognized value %q for %s", value, key) - } - } - return nil - } - - err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) - if err != nil { - return err - } - return boolSetting("binary_parameters", &cn.binaryParameters) -} - -func (cn *conn) handlePgpass(o values) { - // if a password was supplied, do not process .pgpass - if _, ok := o["password"]; ok { - return - } - filename := os.Getenv("PGPASSFILE") - if filename == "" { - // XXX this code doesn't work on Windows where the default filename is - // XXX %APPDATA%\postgresql\pgpass.conf - // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 - userHome := os.Getenv("HOME") - if userHome == "" { - user, err := user.Current() - if err != nil { - return - } - userHome = user.HomeDir - } - filename = filepath.Join(userHome, ".pgpass") - } - fileinfo, err := os.Stat(filename) - if err != nil { - return - } - mode := fileinfo.Mode() - if mode&(0x77) != 0 { - // XXX should warn about incorrect .pgpass permissions as psql does - return - } - file, err := os.Open(filename) - if err != nil { - return - } - defer file.Close() - scanner := bufio.NewScanner(io.Reader(file)) - // From: https://github.com/tg/pgpass/blob/master/reader.go - for scanner.Scan() { - if scanText(scanner.Text(), o) { - break - } - } -} - -// GetFields is a helper function for scanText. 
-func getFields(s string) []string { - fs := make([]string, 0, 5) - f := make([]rune, 0, len(s)) - - var esc bool - for _, c := range s { - switch { - case esc: - f = append(f, c) - esc = false - case c == '\\': - esc = true - case c == ':': - fs = append(fs, string(f)) - f = f[:0] - default: - f = append(f, c) - } - } - return append(fs, string(f)) -} - -// ScanText assists HandlePgpass in it's objective. -func scanText(line string, o values) bool { - hostname := o["host"] - ntw, _ := network(o) - port := o["port"] - db := o["dbname"] - username := o["user"] - if len(line) == 0 || line[0] == '#' { - return false - } - split := getFields(line) - if len(split) != 5 { - return false - } - if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { - o["password"] = split[4] - return true - } - return false -} - -func (cn *conn) writeBuf(b byte) *writeBuf { - cn.scratch[0] = b +func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf { + cn.scratch[0] = byte(b) return &writeBuf{ buf: cn.scratch[:5], pos: 1, } } -// Open opens a new connection to the database. dsn is a connection string. -// Most users should only use it through database/sql package from the standard +// Open opens a new connection to the database. dsn is a connection string. Most +// users should only use it through database/sql package from the standard // library. func Open(dsn string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, dsn) @@ -340,86 +240,200 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { return c.open(context.Background()) } -func (c *Connector) open(ctx context.Context) (cn *conn, err error) { - // Handle any panics during connection initialization. 
Note that we - // specifically do *not* want to use errRecover(), as that would turn any - // connection errors into ErrBadConns, hiding the real error message from - // the user. - defer errRecoverNoErrBadConn(&err) +func (c *Connector) open(ctx context.Context) (*conn, error) { + tsa := c.cfg.TargetSessionAttrs +restart: + var ( + errs []error + app = func(err error, cfg Config) bool { + if err != nil { + if debugProto { + fmt.Println("CONNECT (error)", err) + } + errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) + } + return err != nil + } + ) + for _, cfg := range c.cfg.hosts() { + if debugProto { + fmt.Println("CONNECT ", cfg.string()) + } + + cn := &conn{cfg: cfg, dialer: c.dialer} + cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, + cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password")) + + var err error + cn.c, err = dial(ctx, c.dialer, cn.cfg) + if app(err, cfg) { + continue + } + + err = cn.ssl(cn.cfg) + if app(err, cfg) { + if cn.c != nil { + _ = cn.c.Close() + } + continue + } + + cn.buf = bufio.NewReader(cn.c) + err = cn.startup(cn.cfg) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + + // Reset the deadline, in case one was set (see dial) + if cn.cfg.ConnectTimeout > 0 { + err := cn.c.SetDeadline(time.Time{}) + if app(err, cfg) { + _ = cn.c.Close() + continue + } + } + + err = cn.checkTSA(tsa) + if app(err, cfg) { + _ = cn.c.Close() + continue + } - // Create a new values map (copy). This makes it so maps in different - // connections do not reference the same underlying data structure, so it - // is safe for multiple connections to concurrently write to their opts. - o := make(values) - for k, v := range c.opts { - o[k] = v + return cn, nil } - cn = &conn{ - opts: o, - dialer: c.dialer, + // target_session_attrs=prefer-standby is treated as standby in checkTSA; we + // ran out of hosts so none are on standby. 
Clear the setting and try again. + if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { + tsa = TargetSessionAttrsAny + goto restart } - err = cn.handleDriverSettings(o) - if err != nil { - return nil, err + + if len(c.cfg.Multi) == 0 { + // Remove the "connecting to [..]" when we have just one host, so the + // error is identical to what we had before. + return nil, errors.Unwrap(errs[0]) } - cn.handlePgpass(o) + return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...)) +} - cn.c, err = dial(ctx, c.dialer, o) +func (cn *conn) getBool(query string) (bool, error) { + res, err := cn.simpleQuery(query) if err != nil { - return nil, err + return false, err } + defer res.Close() - err = cn.ssl(o) + v := make([]driver.Value, 1) + err = res.Next(v) if err != nil { - if cn.c != nil { - cn.c.Close() - } - return nil, err + return false, err } - // cn.startup panics on error. Make sure we don't leak cn.c. - panicking := true - defer func() { - if panicking { - cn.c.Close() + switch vv := v[0].(type) { + default: + return false, fmt.Errorf("parseBool: unknown type %T: %[1]v", v[0]) + case bool: + return vv, nil + case string: + vv, ok := v[0].(string) + if !ok { + return false, err } - }() - - cn.buf = bufio.NewReader(cn.c) - cn.startup(o) - - // reset the deadline, in case one was set (see dial) - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - err = cn.c.SetDeadline(time.Time{}) + return vv == "on", nil } - panicking = false - return cn, err } -func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { - network, address := network(o) +func (cn *conn) checkTSA(tsa TargetSessionAttrs) error { + var ( + geths = func() (hs bool, err error) { + hs = cn.parameterStatus.inHotStandby.Bool + if !cn.parameterStatus.inHotStandby.Valid { + hs, err = cn.getBool("select pg_catalog.pg_is_in_recovery()") + } + return hs, err + } + getro = func() (ro bool, err error) { + ro = 
cn.parameterStatus.defaultTransactionReadOnly.Bool + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + ro, err = cn.getBool("show transaction_read_only") + } + return ro, err + } + ) - // Zero or not specified means wait indefinitely. - if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { - seconds, err := strconv.ParseInt(timeout, 10, 0) + switch tsa { + default: + panic("unreachable") + case "", TargetSessionAttrsAny: + return nil + case TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly: + readonly, err := getro() if err != nil { - return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) + return err } - duration := time.Duration(seconds) * time.Second + if !cn.parameterStatus.defaultTransactionReadOnly.Valid { + var err error + readonly, err = cn.getBool("show transaction_read_only") + if err != nil { + return err + } + } + switch { + case tsa == TargetSessionAttrsReadOnly && !readonly: + return errors.New("session is not read-only") + case tsa == TargetSessionAttrsReadWrite: + if readonly { + return errors.New("session is read-only") + } + hs, err := geths() + if err != nil { + return err + } + if hs { + return errors.New("server is in hot standby mode") + } + return nil + default: + return nil + } + case TargetSessionAttrsPrimary, TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby: + hs, err := geths() + if err != nil { + return err + } + switch { + case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs: + return errors.New("server is not in hot standby mode") + case tsa == TargetSessionAttrsPrimary && hs: + return errors.New("server is in hot standby mode") + default: + return nil + } + } +} + +func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) { + network, address := cfg.network() + // Zero or not specified means wait indefinitely. 
+ if cfg.ConnectTimeout > 0 { // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection - // establishment and set a deadline for doing the initial handshake. - // The deadline is then reset after startup() is done. - deadline := time.Now().Add(duration) - var conn net.Conn + // establishment and set a deadline for doing the initial handshake. The + // deadline is then reset after startup() is done. + var ( + deadline = time.Now().Add(cfg.ConnectTimeout) + conn net.Conn + err error + ) if dctx, ok := d.(DialerContext); ok { - ctx, cancel := context.WithTimeout(ctx, duration) + ctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) defer cancel() conn, err = dctx.DialContext(ctx, network, address) } else { - conn, err = d.DialTimeout(network, address, duration) + conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout) } if err != nil { return nil, err @@ -433,140 +447,17 @@ func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { return d.Dial(network, address) } -func network(o values) (string, string) { - host := o["host"] - - if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o["port"]) - return "unix", sockPath - } - - return "tcp", net.JoinHostPort(host, o["port"]) -} - -type values map[string]string - -// scanner implements a tokenizer for libpq-style option strings. -type scanner struct { - s []rune - i int -} - -// newScanner returns a new scanner initialized with the option string s. -func newScanner(s string) *scanner { - return &scanner{[]rune(s), 0} -} - -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) Next() (rune, bool) { - if s.i >= len(s.s) { - return 0, false - } - r := s.s[s.i] - s.i++ - return r, true -} - -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. 
-func (s *scanner) SkipSpaces() (rune, bool) { - r, ok := s.Next() - for unicode.IsSpace(r) && ok { - r, ok = s.Next() - } - return r, ok -} - -// parseOpts parses the options from name and adds them to the values. -// -// The parsing code is based on conninfo_parse from libpq's fe-connect.c -func parseOpts(name string, o values) error { - s := newScanner(name) - - for { - var ( - keyRunes, valRunes []rune - r rune - ok bool - ) - - if r, ok = s.SkipSpaces(); !ok { - break - } - - // Scan the key - for !unicode.IsSpace(r) && r != '=' { - keyRunes = append(keyRunes, r) - if r, ok = s.Next(); !ok { - break - } - } - - // Skip any whitespace if we're not at the = yet - if r != '=' { - r, ok = s.SkipSpaces() - } - - // The current character should be = - if r != '=' || !ok { - return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) - } - - // Skip any whitespace after the = - if r, ok = s.SkipSpaces(); !ok { - // If we reach the end here, the last value is just an empty string as per libpq. 
- o[string(keyRunes)] = "" - break - } - - if r != '\'' { - for !unicode.IsSpace(r) { - if r == '\\' { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`missing character after backslash`) - } - } - valRunes = append(valRunes, r) - - if r, ok = s.Next(); !ok { - break - } - } - } else { - quote: - for { - if r, ok = s.Next(); !ok { - return fmt.Errorf(`unterminated quoted string literal in connection string`) - } - switch r { - case '\'': - break quote - case '\\': - r, _ = s.Next() - fallthrough - default: - valRunes = append(valRunes, r) - } - } - } - - o[string(keyRunes)] = string(valRunes) - } - - return nil -} - func (cn *conn) isInTransaction() bool { return cn.txnStatus == txnStatusIdleInTransaction || cn.txnStatus == txnStatusInFailedTransaction } -func (cn *conn) checkIsInTransaction(intxn bool) { +func (cn *conn) checkIsInTransaction(intxn bool) error { if cn.isInTransaction() != intxn { cn.err.set(driver.ErrBadConn) - errorf("unexpected transaction status %v", cn.txnStatus) + return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus) } + return nil } func (cn *conn) Begin() (_ driver.Tx, err error) { @@ -577,12 +468,13 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) + if err := cn.checkIsInTransaction(false); err != nil { + return nil, err + } - cn.checkIsInTransaction(false) _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { - return nil, err + return nil, cn.handleError(err) } if commandTag != "BEGIN" { cn.err.set(driver.ErrBadConn) @@ -601,14 +493,15 @@ func (cn *conn) closeTxn() { } } -func (cn *conn) Commit() (err error) { +func (cn *conn) Commit() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } - defer cn.errRecover(&err) + if err := cn.checkIsInTransaction(true); err != nil { + return err + } - cn.checkIsInTransaction(true) // We don't want the client to think that everything is okay if it 
tries // to commit a failed transaction. However, no matter what we return, // database/sql will release this connection back into the free connection @@ -627,27 +520,33 @@ func (cn *conn) Commit() (err error) { if cn.isInTransaction() { cn.err.set(driver.ErrBadConn) } - return err + return cn.handleError(err) } if commandTag != "COMMIT" { cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected command tag %s", commandTag) } - cn.checkIsInTransaction(false) - return nil + return cn.checkIsInTransaction(false) } -func (cn *conn) Rollback() (err error) { +func (cn *conn) Rollback() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } - defer cn.errRecover(&err) - return cn.rollback() + + err := cn.rollback() + if err != nil { + return cn.handleError(err) + } + return nil } func (cn *conn) rollback() (err error) { - cn.checkIsInTransaction(true) + if err := cn.checkIsInTransaction(true); err != nil { + return err + } + _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { @@ -658,8 +557,7 @@ func (cn *conn) rollback() (err error) { if commandTag != "ROLLBACK" { return fmt.Errorf("unexpected command tag %s", commandTag) } - cn.checkIsInTransaction(false) - return nil + return cn.checkIsInTransaction(false) } func (cn *conn) gname() string { @@ -667,126 +565,136 @@ func (cn *conn) gname() string { return strconv.FormatInt(int64(cn.namei), 10) } -func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { - b := cn.writeBuf('Q') +func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { + if debugProto { + fmt.Fprintf(os.Stderr, " START conn.simpleExec\n") + defer fmt.Fprintf(os.Stderr, " END conn.simpleExec\n") + } + + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, "", err + } for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } switch t { - 
case 'C': - res, commandTag = cn.parseComplete(r.string()) - case 'Z': + case proto.CommandComplete: + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady + if res == nil && resErr == nil { + resErr = errUnexpectedReady } - // done - return - case 'E': - err = parseError(r) - case 'I': + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, q) + case proto.EmptyQueryResponse: res = emptyRows - case 'T', 'D': + case proto.RowDescription, proto.DataRow: // ignore any results default: cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) + return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t) } } } -func (cn *conn) simpleQuery(q string) (res *rows, err error) { - defer cn.errRecover(&err) +func (cn *conn) simpleQuery(q string) (*rows, error) { + if debugProto { + fmt.Fprintf(os.Stderr, " START conn.simpleQuery\n") + defer fmt.Fprintf(os.Stderr, " END conn.simpleQuery\n") + } - b := cn.writeBuf('Q') + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, cn.handleError(err, q) + } + var ( + res *rows + resErr error + ) for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, cn.handleError(err, q) + } switch t { - case 'C', 'I': + case proto.CommandComplete, proto.EmptyQueryResponse: // We allow queries which don't return any results through Query as - // well as Exec. We still have to give database/sql a rows object + // well as Exec. We still have to give database/sql a rows object // the user can close, though, to avoid connections from being - // leaked. A "rows" with done=true works fine for that purpose. - if err != nil { + // leaked. A "rows" with done=true works fine for that purpose. 
+ if resErr != nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q in simple query execution", t) + return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t) } if res == nil { - res = &rows{ - cn: cn, - } + res = &rows{cn: cn} } // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' { - res.result, res.tag = cn.parseComplete(r.string()) + if t == proto.CommandComplete { + res.result, res.tag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, cn.handleError(err, q) + } if res.colNames != nil { - return + return res, cn.handleError(resErr, q) } } res.done = true - case 'Z': + case proto.ReadyForQuery: cn.processReadyForQuery(r) - // done - return - case 'E': + if err == nil && res == nil { + res = &rows{done: true} + } + return res, cn.handleError(resErr, q) // done + case proto.ErrorResponse: res = nil - err = parseError(r) - case 'D': + resErr = parseError(r, q) + case proto.DataRow: if res == nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow in simple query execution") + return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution") } - // the query didn't fail; kick off to Next - cn.saveMessage(t, r) - return - case 'T': + return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next + case proto.RowDescription: // res might be non-nil here if we received a previous - // CommandComplete, but that's fine; just overwrite it - res = &rows{cn: cn} - res.rowsHeader = parsePortalRowDescribe(r) + // CommandComplete, but that's fine and just overwrite it. + res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)} // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. 
default: cn.err.set(driver.ErrBadConn) - errorf("unknown response for simple query: %q", t) + return nil, fmt.Errorf("pq: unknown response for simple query: %q", t) } } } -type noRows struct{} - -var emptyRows noRows - -var _ driver.Result = noRows{} - -func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertID -} - -func (noRows) RowsAffected() (int64, error) { - return 0, errNoRowsAffected -} - // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats( - colTyps []fieldDesc, forceText bool, -) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) { if len(colTyps) == 0 { - return nil, colFmtDataAllText + return nil, colFmtDataAllText, nil } colFmts = make([]format, len(colTyps)) if forceText { - return colFmts, colFmtDataAllText + return colFmts, colFmtDataAllText, nil } allBinary := true @@ -807,95 +715,164 @@ func decideColumnFormats( case oid.T_uuid: colFmts[i] = formatBinary allText = false - default: allBinary = false } } if allBinary { - return colFmts, colFmtDataAllBinary + return colFmts, colFmtDataAllBinary, nil } else if allText { - return colFmts, colFmtDataAllText + return colFmts, colFmtDataAllText, nil } else { colFmtData = make([]byte, 2+len(colFmts)*2) + if len(colFmts) > math.MaxUint16 { + return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts)) + } binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) for i, v := range colFmts { binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) } - return colFmts, colFmtData + return colFmts, colFmtData, nil } } -func (cn *conn) prepareTo(q, stmtName string) *stmt { +func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { + if debugProto { + fmt.Fprintf(os.Stderr, " START conn.prepareTo\n") + defer fmt.Fprintf(os.Stderr, " END conn.prepareTo\n") + 
} + st := &stmt{cn: cn, name: stmtName} - b := cn.writeBuf('P') + b := cn.writeBuf(proto.Parse) b.string(st.name) b.string(q) b.int16(0) - b.next('D') - b.byte('S') + b.next(proto.Describe) + b.byte(proto.Sync) b.string(st.name) - b.next('S') - cn.send(b) + b.next(proto.Sync) + err := cn.send(b) + if err != nil { + return nil, err + } - cn.readParseResponse() - st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() - st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) - cn.readReadyForQuery() - return st + err = cn.readParseResponse() + if err != nil { + return nil, err + } + st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() + if err != nil { + return nil, err + } + st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult) + if err != nil { + return nil, err + } + + err = cn.readReadyForQuery() + if err != nil { + return nil, err + } + return st, nil } -func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { +func (cn *conn) Prepare(q string) (driver.Stmt, error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) - if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { + if pqsql.StartsWithCopy(q) { s, err := cn.prepareCopyIn(q) if err == nil { cn.inCopy = true } - return s, err + return s, cn.handleError(err, q) } - return cn.prepareTo(q, cn.gname()), nil + s, err := cn.prepareTo(q, cn.gname()) + if err != nil { + return nil, cn.handleError(err, q) + } + return s, nil } -func (cn *conn) Close() (err error) { - // Skip cn.bad return here because we always want to close a connection. - defer cn.errRecover(&err) - - // Ensure that cn.c.Close is always run. Since error handling is done with - // panics and cn.errRecover, the Close must be in a defer. 
- defer func() { - cerr := cn.c.Close() - if err == nil { - err = cerr - } - }() - +func (cn *conn) Close() error { // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - return cn.sendSimpleMessage('X') + err := cn.sendSimpleMessage(proto.Terminate) + if err != nil { + _ = cn.c.Close() // Ensure that cn.c.Close is always run. + return cn.handleError(err) + } + return cn.c.Close() +} + +func toNamedValue(v []driver.Value) []driver.NamedValue { + v2 := make([]driver.NamedValue, len(v)) + for i := range v { + v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} + } + return v2 +} + +// CheckNamedValue implements [driver.NamedValueChecker]. +func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { + // Ignore Valuer, for backward compatibility with pq.Array(). + if _, ok := nv.Value.(driver.Valuer); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if !v.IsValid() { + return driver.ErrSkip + } + t := v.Type() + for t.Kind() == reflect.Ptr { + t, v = t.Elem(), v.Elem() + } + + // Ignore []byte and related types: *[]byte, json.RawMessage, etc. 
+ if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return driver.ErrSkip + } + + switch v.Kind() { + default: + return driver.ErrSkip + case reflect.Slice: + var err error + nv.Value, err = Array(v.Interface()).Value() + return err + case reflect.Uint64: + value := v.Uint() + if value >= math.MaxInt64 { + nv.Value = strconv.FormatUint(value, 10) + } else { + nv.Value = int64(value) + } + return nil + } } // Implement the "Queryer" interface func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { - return cn.query(query, args) + return cn.query(query, toNamedValue(args)) } -func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { +func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { + if debugProto { + fmt.Fprintf(os.Stderr, " START conn.query\n") + defer fmt.Fprintf(os.Stderr, " END conn.query\n") + } if err := cn.err.get(); err != nil { return nil, err } if cn.inCopy { return nil, errCopyInProgress } - defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec @@ -903,18 +880,40 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { return cn.simpleQuery(query) } - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) - - cn.readParseResponse() - cn.readBindResponse() - rows := &rows{cn: cn} - rows.rowsHeader = cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - return rows, nil + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, args) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + + rows := &rows{cn: cn} + rows.rowsHeader, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + 
err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } + return rows, nil + } + + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) + } + err = st.exec(args) + if err != nil { + return nil, cn.handleError(err, query) } - st := cn.prepareTo(query, "") - st.exec(args) return &rows{ cn: cn, rowsHeader: st.rowsHeader, @@ -922,69 +921,100 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { } // Implement the optional "Execer" interface for one-shot queries -func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { +func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err := cn.err.get(); err != nil { return nil, err } - defer cn.errRecover(&err) - // Check to see if we can use the "simpleExec" interface, which is - // *much* faster than going through prepare/exec + // Check to see if we can use the "simpleExec" interface, which is *much* + // faster than going through prepare/exec if len(args) == 0 { // ignore commandTag, our caller doesn't care r, _, err := cn.simpleExec(query) - return r, err + return r, cn.handleError(err, query) } - if cn.binaryParameters { - cn.sendBinaryModeQuery(query, args) + if cn.cfg.BinaryParameters { + err := cn.sendBinaryModeQuery(query, toNamedValue(args)) + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readParseResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.readBindResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + + _, err = cn.readPortalDescribeResponse() + if err != nil { + return nil, cn.handleError(err, query) + } + err = cn.postExecuteWorkaround() + if err != nil { + return nil, cn.handleError(err, query) + } + res, _, err := cn.readExecuteResponse("Execute") + return res, cn.handleError(err, query) + } - cn.readParseResponse() - cn.readBindResponse() - 
cn.readPortalDescribeResponse() - cn.postExecuteWorkaround() - res, _, err = cn.readExecuteResponse("Execute") - return res, err + // Use the unnamed statement to defer planning until bind time, or else + // value-based selectivity estimates cannot be used. + st, err := cn.prepareTo(query, "") + if err != nil { + return nil, cn.handleError(err, query) } - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") r, err := st.Exec(args) if err != nil { - panic(err) + return nil, cn.handleError(err, query) } - return r, err + return r, nil } -type safeRetryError struct { - Err error -} +type safeRetryError struct{ Err error } -func (se *safeRetryError) Error() string { - return se.Err.Error() -} +func (se *safeRetryError) Error() string { return se.Err.Error() } -func (cn *conn) send(m *writeBuf) { - n, err := cn.c.Write(m.wrap()) - if err != nil { - if n == 0 { - err = &safeRetryError{Err: err} +func (cn *conn) send(m *writeBuf) error { + if debugProto { + w := m.wrap() + for len(w) > 0 { // Can contain multiple messages. + c := proto.RequestCode(w[0]) + l := int(binary.BigEndian.Uint32(w[1:5])) - 4 + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5]) + w = w[l+5:] } - panic(err) } + + n, err := cn.c.Write(m.wrap()) + if err != nil && n == 0 { + err = &safeRetryError{Err: err} + } + return err } func (cn *conn) sendStartupPacket(m *writeBuf) error { + if debugProto { + w := m.wrap() + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", + "Startup", + int(binary.BigEndian.Uint32(w[1:5]))-4, + w[5:]) + } _, err := cn.c.Write((m.wrap())[1:]) return err } -// Send a message of type typ to the server on the other end of cn. The -// message should have no payload. This method does not use the scratch -// buffer. 
-func (cn *conn) sendSimpleMessage(typ byte) (err error) { - _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) +// Send a message of type typ to the server on the other end of cn. The message +// should have no payload. This method does not use the scratch buffer. +func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { + if debugProto { + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", + proto.RequestCode(typ), 0, []byte{}) + } + _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) return err } @@ -993,18 +1023,19 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // method is useful in cases where you have to see what the next message is // going to be (e.g. to see whether it's an error or not) but you can't handle // the message yourself. -func (cn *conn) saveMessage(typ byte, buf *readBuf) { +func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error { if cn.saveMessageType != 0 { cn.err.set(driver.ErrBadConn) - errorf("unexpected saveMessageType %d", cn.saveMessageType) + return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ cn.saveMessageBuffer = *buf + return nil } // recvMessage receives any message from the backend, or returns an error if // a problem occurred while reading the message. -func (cn *conn) recvMessage(r *readBuf) (byte, error) { +func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { // workaround for a QueryRow bug, see exec if cn.saveMessageType != 0 { t := cn.saveMessageType @@ -1020,9 +1051,21 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { return 0, err } - // read the type and length of the message that follows - t := x[0] + // Read the type and length of the message that follows. 
+ t := proto.ResponseCode(x[0]) n := int(binary.BigEndian.Uint32(x[1:])) - 4 + + // When PostgreSQL cannot start a backend (e.g., an external process limit), + // it sends plain text like "Ecould not fork new process [..]", which + // doesn't use the standard encoding for the Error message. + // + // libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)", + // but check < 4 since n represents bytes remaining to be read after length. + if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) { + msg, _ := cn.buf.ReadString('\x00') + return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00")) + } + var y []byte if n <= len(cn.scratch) { y = cn.scratch[:n] @@ -1034,75 +1077,80 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { return 0, err } *r = y + if debugProto { + fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) + } return t, nil } -// recv receives a message from the backend, but if an error happened while -// reading the message or the received message was an ErrorResponse, it panics. -// NoticeResponses are ignored. This function should generally be used only +// recv receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// NoticeResponses are ignored. This function should generally be used only // during the startup sequence. 
-func (cn *conn) recv() (t byte, r *readBuf) { +func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) { for { - var err error - r = &readBuf{} - t, err = cn.recvMessage(r) + r := new(readBuf) + t, err := cn.recvMessage(r) if err != nil { - panic(err) + return 0, nil, err } switch t { - case 'E': - panic(parseError(r)) - case 'N': + case proto.ErrorResponse: + return 0, nil, parseError(r, "") + case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } - case 'A': + case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } default: - return + return t, r, nil } } } // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by // the caller to avoid an allocation. -func (cn *conn) recv1Buf(r *readBuf) byte { +func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) { for { t, err := cn.recvMessage(r) if err != nil { - panic(err) + return 0, err } switch t { - case 'A': + case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } - case 'N': + case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } - case 'S': + case proto.ParameterStatus: cn.processParameterStatus(r) default: - return t + return t, nil } } } -// recv1 receives a message from the backend, panicking if an error occurs -// while attempting to read it. All asynchronous messages are ignored, with -// the exception of ErrorResponse. -func (cn *conn) recv1() (t byte, r *readBuf) { - r = &readBuf{} - t = cn.recv1Buf(r) - return t, r +// recv1 receives a message from the backend, returning an error if an error +// happened while reading the message or the received message an ErrorResponse. +// All asynchronous messages are ignored, with the exception of ErrorResponse. 
+func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { + r := new(readBuf) + t, err := cn.recv1Buf(r) + if err != nil { + return 0, nil, err + } + return t, r, nil } -func (cn *conn) ssl(o values) error { - upgrade, err := ssl(o) +func (cn *conn) ssl(cfg Config) error { + upgrade, err := ssl(cfg) if err != nil { return err } @@ -1112,367 +1160,230 @@ func (cn *conn) ssl(o values) error { return nil } - w := cn.writeBuf(0) - w.int32(80877103) - if err = cn.sendStartupPacket(w); err != nil { - return err - } + // Only negotiate the ssl handshake if requested (which is the default). + // sllnegotiation=direct is supported by pg17 and above. + if cfg.SSLNegotiation != SSLNegotiationDirect { + w := cn.writeBuf(0) + w.int32(proto.NegotiateSSLCode) + if err = cn.sendStartupPacket(w); err != nil { + return err + } - b := cn.scratch[:1] - _, err = io.ReadFull(cn.c, b) - if err != nil { - return err - } + b := cn.scratch[:1] + _, err = io.ReadFull(cn.c, b) + if err != nil { + return err + } - if b[0] != 'S' { - return ErrSSLNotSupported + if b[0] != 'S' { + return ErrSSLNotSupported + } } cn.c, err = upgrade(cn.c) return err } -// isDriverSetting returns true iff a setting is purely for configuring the -// driver's options and should not be sent to the server in the connection -// startup packet. 
-func isDriverSetting(key string) bool { - switch key { - case "host", "port": - return true - case "password": - return true - case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": - return true - case "fallback_application_name": - return true - case "connect_timeout": - return true - case "disable_prepared_binary_result": - return true - case "binary_parameters": - return true - case "krbsrvname": - return true - case "krbspn": - return true - default: - return false - } -} - -func (cn *conn) startup(o values) { +func (cn *conn) startup(cfg Config) error { w := cn.writeBuf(0) - w.int32(196608) - // Send the backend the name of the database we want to connect to, and the - // user we want to connect as. Additionally, we send over any run-time - // parameters potentially included in the connection string. If the server - // doesn't recognize any of them, it will reply with an error. - for k, v := range o { - if isDriverSetting(k) { - // skip options which can't be run-time parameters - continue - } - // The protocol requires us to supply the database name as "database" - // instead of "dbname". - if k == "dbname" { - k = "database" - } + w.int32(proto.ProtocolVersion30) + + w.string("user") + w.string(cfg.User) + w.string("database") + w.string(cfg.Database) + // w.string("replication") // Sent by libpq, but we don't support that. 
+ w.string("options") + w.string(cfg.Options) + if cfg.ApplicationName != "" { + w.string("application_name") + w.string(cfg.ApplicationName) + } + w.string("client_encoding") + w.string(cfg.ClientEncoding) + + for k, v := range cfg.Runtime { w.string(k) w.string(v) } + w.string("") if err := cn.sendStartupPacket(w); err != nil { - panic(err) + return err } for { - t, r := cn.recv() + t, r, err := cn.recv() + if err != nil { + return err + } switch t { - case 'K': + case proto.BackendKeyData: cn.processBackendKeyData(r) - case 'S': + case proto.ParameterStatus: cn.processParameterStatus(r) - case 'R': - cn.auth(r, o) - case 'Z': + case proto.AuthenticationRequest: + err := cn.auth(r, cfg) + if err != nil { + return err + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - return + return nil default: - errorf("unknown response for startup: %q", t) + return fmt.Errorf("pq: unknown response for startup: %q", t) } } } -func (cn *conn) auth(r *readBuf, o values) { - switch code := r.int32(); code { - case 0: - // OK - case 3: - w := cn.writeBuf('p') - w.string(o["password"]) - cn.send(w) +func (cn *conn) auth(r *readBuf, cfg Config) error { + switch code := proto.AuthCode(r.int32()); code { + default: + return fmt.Errorf("pq: unknown authentication response: %s", code) + case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: + return fmt.Errorf("pq: unsupported authentication method: %s", code) + case proto.AuthReqOk: + return nil - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } + case proto.AuthReqPassword: + w := cn.writeBuf(proto.PasswordMessage) + w.string(cfg.Password) + // Don't need to check AuthOk response here; auth() is called in a loop, + // which catches the errors and AuthReqOk responses. 
+ return cn.send(w) - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 5: + case proto.AuthReqMD5: s := string(r.next(4)) - w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) - cn.send(w) - - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) - } + w := cn.writeBuf(proto.PasswordMessage) + w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) + // Same here. + return cn.send(w) - if r.int32() != 0 { - errorf("unexpected authentication response: %q", t) - } - case 7: // GSSAPI, startup + case proto.AuthReqGSS: // GSSAPI, startup if newGss == nil { - errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") + return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)") } cli, err := newGss() if err != nil { - errorf("kerberos error: %s", err.Error()) + return fmt.Errorf("pq: kerberos error: %w", err) } var token []byte - - if spn, ok := o["krbspn"]; ok { + if cfg.isset("krbspn") { // Use the supplied SPN if provided.. 
- token, err = cli.GetInitTokenFromSpn(spn) + token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn) } else { // Allow the kerberos service name to be overridden service := "postgres" - if val, ok := o["krbsrvname"]; ok { - service = val + if cfg.isset("krbsrvname") { + service = cfg.KrbSrvname } - - token, err = cli.GetInitToken(o["host"], service) + token, err = cli.GetInitToken(cfg.Host, service) } - if err != nil { - errorf("failed to get Kerberos ticket: %q", err) + return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err) } - w := cn.writeBuf('p') + w := cn.writeBuf(proto.GSSResponse) w.bytes(token) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } // Store for GSSAPI continue message cn.gss = cli + return nil - case 8: // GSSAPI continue - + case proto.AuthReqGSSCont: // GSSAPI continue if cn.gss == nil { - errorf("GSSAPI protocol error") + return errors.New("pq: GSSAPI protocol error") } - b := []byte(*r) - - done, tokOut, err := cn.gss.Continue(b) + done, tokOut, err := cn.gss.Continue([]byte(*r)) if err == nil && !done { - w := cn.writeBuf('p') + w := cn.writeBuf(proto.SASLInitialResponse) w.bytes(tokOut) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } } - // Errors fall through and read the more detailed message - // from the server.. + // Errors fall through and read the more detailed message from the + // server. 
+ return nil - case 10: - sc := scram.NewClient(sha256.New, o["user"], o["password"]) + case proto.AuthReqSASL: + sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) sc.Step(nil) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut := sc.Out() - w := cn.writeBuf('p') + w := cn.writeBuf(proto.SASLResponse) w.string("SCRAM-SHA-256") w.int32(len(scOut)) w.bytes(scOut) - cn.send(w) + err := cn.send(w) + if err != nil { + return err + } - t, r := cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) + t, r, err := cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return fmt.Errorf("pq: unexpected password response: %q", t) } - if r.int32() != 11 { - errorf("unexpected authentication response: %q", t) + if r.int32() != int(proto.AuthReqSASLCont) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep := r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut = sc.Out() - w = cn.writeBuf('p') + w = cn.writeBuf(proto.SASLResponse) w.bytes(scOut) - cn.send(w) + err = cn.send(w) + if err != nil { + return err + } - t, r = cn.recv() - if t != 'R' { - errorf("unexpected password response: %q", t) + t, r, err = cn.recv() + if err != nil { + return err + } + if t != proto.AuthenticationRequest { + return fmt.Errorf("pq: unexpected password response: %q", t) } - if r.int32() != 12 { - errorf("unexpected authentication response: %q", t) + if r.int32() != int(proto.AuthReqSASLFin) { + return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep = r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { - errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } - default: - errorf("unknown authentication response: 
%d", code) - } -} - -type format int - -const formatText format = 0 -const formatBinary format = 1 - -// One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary = []byte{0, 1, 0, 1} - -// No result-column format codes (i.e. all text). -var colFmtDataAllText = []byte{0, 0} - -type stmt struct { - cn *conn - name string - rowsHeader - colFmtData []byte - paramTyps []oid.Oid - closed bool -} - -func (st *stmt) Close() (err error) { - if st.closed { return nil } - if err := st.cn.err.get(); err != nil { - return err - } - defer st.cn.errRecover(&err) - - w := st.cn.writeBuf('C') - w.byte('S') - w.string(st.name) - st.cn.send(w) - - st.cn.send(st.cn.writeBuf('S')) - - t, _ := st.cn.recv1() - if t != '3' { - st.cn.err.set(driver.ErrBadConn) - errorf("unexpected close response: %q", t) - } - st.closed = true - - t, r := st.cn.recv1() - if t != 'Z' { - st.cn.err.set(driver.ErrBadConn) - errorf("expected ready for query, but got: %q", t) - } - st.cn.processReadyForQuery(r) - - return nil -} - -func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - return st.query(v) -} - -func (st *stmt) query(v []driver.Value) (r *rows, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - return &rows{ - cn: st.cn, - rowsHeader: st.rowsHeader, - }, nil -} - -func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if err := st.cn.err.get(); err != nil { - return nil, err - } - defer st.cn.errRecover(&err) - - st.exec(v) - res, _, err = st.cn.readExecuteResponse("simple query") - return res, err -} - -func (st *stmt) exec(v []driver.Value) { - if len(v) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) - } - if len(v) != len(st.paramTyps) { - errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) - } - - cn := st.cn - w := cn.writeBuf('B') - w.byte(0) // unnamed portal - 
w.string(st.name) - - if cn.binaryParameters { - cn.sendBinaryParameters(w, v) - } else { - w.int16(0) - w.int16(len(v)) - for i, x := range v { - if x == nil { - w.int32(-1) - } else { - b := encode(&cn.parameterStatus, x, st.paramTyps[i]) - w.int32(len(b)) - w.bytes(b) - } - } - } - w.bytes(st.colFmtData) - - w.next('E') - w.byte(0) - w.int32(0) - - w.next('S') - cn.send(w) - - cn.readBindResponse() - cn.postExecuteWorkaround() - -} - -func (st *stmt) NumInput() int { - return len(st.paramTyps) } // parseComplete parses the "command tag" from a CommandComplete message, and -// returns the number of rows affected (if applicable) and a string -// identifying only the command that was executed, e.g. "ALTER TABLE". If the -// command tag could not be parsed, parseComplete panics. -func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { +// returns the number of rows affected (if applicable) and a string identifying +// only the command that was executed, e.g. "ALTER TABLE". Returns an error if +// the command can cannot be parsed. +func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) { commandsWithAffectedRows := []string{ "SELECT ", // INSERT is handled below @@ -1492,218 +1403,29 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { break } } - // INSERT also includes the oid of the inserted row in its command tag. - // Oids in user tables are deprecated, and the oid is only returned when - // exactly one row is inserted, so it's unlikely to be of value to any - // real-world application and we can ignore it. + // INSERT also includes the oid of the inserted row in its command tag. Oids + // in user tables are deprecated, and the oid is only returned when exactly + // one row is inserted, so it's unlikely to be of value to any real-world + // application and we can ignore it. 
if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { cn.err.set(driver.ErrBadConn) - errorf("unexpected INSERT command tag %s", commandTag) + return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] commandTag = "INSERT" } // There should be no affected rows attached to the tag, just return it if affectedRows == nil { - return driver.RowsAffected(0), commandTag + return driver.RowsAffected(0), commandTag, nil } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { cn.err.set(driver.ErrBadConn) - errorf("could not parse commandTag: %s", err) - } - return driver.RowsAffected(n), commandTag -} - -type rowsHeader struct { - colNames []string - colTyps []fieldDesc - colFmts []format -} - -type rows struct { - cn *conn - finish func() - rowsHeader - done bool - rb readBuf - result driver.Result - tag string - - next *rowsHeader -} - -func (rs *rows) Close() error { - if finish := rs.finish; finish != nil { - defer finish() - } - // no need to look at cn.bad as Next() will - for { - err := rs.Next(nil) - switch err { - case nil: - case io.EOF: - // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row - // description, used with HasNextResultSet). We need to fetch messages until - // we hit a 'Z', which is done by waiting for done to be set. 
- if rs.done { - return nil - } - default: - return err - } - } -} - -func (rs *rows) Columns() []string { - return rs.colNames -} - -func (rs *rows) Result() driver.Result { - if rs.result == nil { - return emptyRows - } - return rs.result -} - -func (rs *rows) Tag() string { - return rs.tag -} - -func (rs *rows) Next(dest []driver.Value) (err error) { - if rs.done { - return io.EOF - } - - conn := rs.cn - if err := conn.err.getForNext(); err != nil { - return err - } - defer conn.errRecover(&err) - - for { - t := conn.recv1Buf(&rs.rb) - switch t { - case 'E': - err = parseError(&rs.rb) - case 'C', 'I': - if t == 'C' { - rs.result, rs.tag = conn.parseComplete(rs.rb.string()) - } - continue - case 'Z': - conn.processReadyForQuery(&rs.rb) - rs.done = true - if err != nil { - return err - } - return io.EOF - case 'D': - n := rs.rb.int16() - if err != nil { - conn.err.set(driver.ErrBadConn) - errorf("unexpected DataRow after error %s", err) - } - if n < len(dest) { - dest = dest[:n] - } - for i := range dest { - l := rs.rb.int32() - if l == -1 { - dest[i] = nil - continue - } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) - } - return - case 'T': - next := parsePortalRowDescribe(&rs.rb) - rs.next = &next - return io.EOF - default: - errorf("unexpected message after execute: %q", t) - } - } -} - -func (rs *rows) HasNextResultSet() bool { - hasNext := rs.next != nil && !rs.done - return hasNext -} - -func (rs *rows) NextResultSet() error { - if rs.next == nil { - return io.EOF - } - rs.rowsHeader = *rs.next - rs.next = nil - return nil -} - -// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be -// used as part of an SQL statement. For example: -// -// tblname := "my_table" -// data := "my_data" -// quoted := pq.QuoteIdentifier(tblname) -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) -// -// Any double quotes in name will be escaped. 
The quoted identifier will be -// case sensitive when used in a query. If the input string contains a zero -// byte, the result will be truncated immediately before it. -func QuoteIdentifier(name string) string { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - return `"` + strings.Replace(name, `"`, `""`, -1) + `"` -} - -// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a -// byte buffer. -func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { - end := strings.IndexRune(name, 0) - if end > -1 { - name = name[:end] - } - buffer.WriteRune('"') - buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) - buffer.WriteRune('"') -} - -// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal -// to DDL and other statements that do not accept parameters) to be used as part -// of an SQL statement. For example: -// -// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") -// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) -// -// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be -// replaced by two backslashes (i.e. "\\") and the C-style escape identifier -// that PostgreSQL provides ('E') will be prepended to the string. -func QuoteLiteral(literal string) string { - // This follows the PostgreSQL internal algorithm for handling quoted literals - // from libpq, which can be found in the "PQEscapeStringInternal" function, - // which is found in the libpq/fe-exec.c source file: - // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c - // - // substitute any single-quotes (') with two single-quotes ('') - literal = strings.Replace(literal, `'`, `''`, -1) - // determine if the string has any backslashes (\) in it. - // if it does, replace any backslashes (\) with two backslashes (\\) - // then, we need to wrap the entire string with a PostgreSQL - // C-style escape. 
Per how "PQEscapeStringInternal" handles this case, we - // also add a space before the "E" - if strings.Contains(literal, `\`) { - literal = strings.Replace(literal, `\`, `\\`, -1) - literal = ` E'` + literal + `'` - } else { - // otherwise, we can just wrap the literal with a pair of single quotes - literal = `'` + literal + `'` + return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err) } - return literal + return driver.RowsAffected(n), commandTag, nil } func md5s(s string) string { @@ -1712,13 +1434,12 @@ func md5s(s string) string { return fmt.Sprintf("%x", h.Sum(nil)) } -func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { - // Do one pass over the parameters to see if we're going to send any of - // them over in binary. If we are, create a paramFormats array at the - // same time. +func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error { + // Do one pass over the parameters to see if we're going to send any of them + // over in binary. If we are, create a paramFormats array at the same time. 
var paramFormats []int for i, x := range args { - _, ok := x.([]byte) + _, ok := x.Value.([]byte) if ok { if paramFormats == nil { paramFormats = make([]int, len(args)) @@ -1737,64 +1458,81 @@ func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { b.int16(len(args)) for _, x := range args { - if x == nil { + if x.Value == nil { + b.int32(-1) + } else if xx, ok := x.Value.([]byte); ok && xx == nil { b.int32(-1) } else { - datum := binaryEncode(&cn.parameterStatus, x) + datum, err := binaryEncode(x.Value) + if err != nil { + return err + } b.int32(len(datum)) b.bytes(datum) } } + return nil } -func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { +func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error { if len(args) >= 65536 { - errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } - b := cn.writeBuf('P') + b := cn.writeBuf(proto.Parse) b.byte(0) // unnamed statement b.string(query) b.int16(0) - b.next('B') + b.next(proto.Bind) b.int16(0) // unnamed portal and statement - cn.sendBinaryParameters(b, args) + err := cn.sendBinaryParameters(b, args) + if err != nil { + return err + } b.bytes(colFmtDataAllText) - b.next('D') - b.byte('P') + b.next(proto.Describe) + b.byte(proto.Parse) b.byte(0) // unnamed portal - b.next('E') + b.next(proto.Execute) b.byte(0) b.int32(0) - b.next('S') - cn.send(b) + b.next(proto.Sync) + return cn.send(b) } func (cn *conn) processParameterStatus(r *readBuf) { - var err error - - param := r.string() - switch param { + switch r.string() { + default: + // ignore case "server_version": - var major1 int - var major2 int - _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) + var major1, major2 int + _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) if err == nil { cn.parameterStatus.serverVersion = major1*10000 + major2*100 } - case 
"TimeZone": + var err error cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { cn.parameterStatus.currentLocation = nil } - - default: - // ignore + // Use sql.NullBool so we can distinguish between false and not sent. If + // it's not sent we use a query to get the value – I don't know when these + // parameters are not sent, but this is what libpq does. + case "in_hot_standby": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: b} + } + case "default_transaction_read_only": + b, err := pqutil.ParseBool(r.string()) + if err == nil { + cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: b} + } } } @@ -1802,15 +1540,22 @@ func (cn *conn) processReadyForQuery(r *readBuf) { cn.txnStatus = transactionStatus(r.byte()) } -func (cn *conn) readReadyForQuery() { - t, r := cn.recv1() +func (cn *conn) readReadyForQuery() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case 'Z': + case proto.ReadyForQuery: cn.processReadyForQuery(r) - return + return nil + case proto.ErrorResponse: + err := parseError(r, "") + cn.err.set(driver.ErrBadConn) + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected message %q; expected ReadyForQuery", t) + return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t) } } @@ -1819,85 +1564,92 @@ func (cn *conn) processBackendKeyData(r *readBuf) { cn.secretKey = r.int32() } -func (cn *conn) readParseResponse() { - t, r := cn.recv1() +func (cn *conn) readParseResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case '1': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.ParseComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Parse response 
%q", t) + return fmt.Errorf("pq: unexpected Parse response %q", t) } } -func (cn *conn) readStatementDescribeResponse() ( - paramTyps []oid.Oid, - colNames []string, - colTyps []fieldDesc, -) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) { for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, nil, nil, err + } switch t { - case 't': + case proto.ParameterDescription: nparams := r.int16() paramTyps = make([]oid.Oid, nparams) for i := range paramTyps { paramTyps[i] = r.oid() } - case 'n': - return paramTyps, nil, nil - case 'T': + case proto.NoData: + return paramTyps, nil, nil, nil + case proto.RowDescription: colNames, colTyps = parseStatementRowDescribe(r) - return paramTyps, colNames, colTyps - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + return paramTyps, colNames, colTyps, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return nil, nil, nil, err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe statement response %q", t) + return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t) } } } -func (cn *conn) readPortalDescribeResponse() rowsHeader { - t, r := cn.recv1() +func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) { + t, r, err := cn.recv1() + if err != nil { + return rowsHeader{}, err + } switch t { - case 'T': - return parsePortalRowDescribe(r) - case 'n': - return rowsHeader{} - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.RowDescription: + return parsePortalRowDescribe(r), nil + case proto.NoData: + return rowsHeader{}, nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return rowsHeader{}, err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Describe response %q", t) + return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response 
%q", t) } - panic("not reached") } -func (cn *conn) readBindResponse() { - t, r := cn.recv1() +func (cn *conn) readBindResponse() error { + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case '2': - return - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) + case proto.BindComplete: + return nil + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err default: cn.err.set(driver.ErrBadConn) - errorf("unexpected Bind response %q", t) + return fmt.Errorf("pq: unexpected Bind response %q", t) } } -func (cn *conn) postExecuteWorkaround() { +func (cn *conn) postExecuteWorkaround() error { // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores // any errors from rows.Next, which masks errors that happened during the // execution of the query. To avoid the problem in common cases, we wait @@ -1908,56 +1660,62 @@ func (cn *conn) postExecuteWorkaround() { // However, if it's an error, we wait until ReadyForQuery and then return // the error to our caller. 
for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return err + } switch t { - case 'E': - err := parseError(r) - cn.readReadyForQuery() - panic(err) - case 'C', 'D', 'I': + case proto.ErrorResponse: + err := parseError(r, "") + _ = cn.readReadyForQuery() + return err + case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse: // the query didn't fail, but we can't process this message - cn.saveMessage(t, r) - return + return cn.saveMessage(t, r) default: cn.err.set(driver.ErrBadConn) - errorf("unexpected message during extended query execution: %q", t) + return fmt.Errorf("pq: unexpected message during extended query execution: %q", t) } } } // Only for Exec(), since we ignore the returned data -func (cn *conn) readExecuteResponse( - protocolState string, -) (res driver.Result, commandTag string, err error) { +func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) { for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, "", err + } switch t { - case 'C': - if err != nil { + case proto.CommandComplete: + if resErr != nil { cn.err.set(driver.ErrBadConn) - errorf("unexpected CommandComplete after error %s", err) + return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %s", resErr) } - res, commandTag = cn.parseComplete(r.string()) - case 'Z': + res, commandTag, err = cn.parseComplete(r.string()) + if err != nil { + return nil, "", err + } + case proto.ReadyForQuery: cn.processReadyForQuery(r) - if res == nil && err == nil { - err = errUnexpectedReady + if res == nil && resErr == nil { + resErr = errUnexpectedReady } - return res, commandTag, err - case 'E': - err = parseError(r) - case 'T', 'D', 'I': - if err != nil { + return res, commandTag, resErr + case proto.ErrorResponse: + resErr = parseError(r, "") + case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse: + if resErr != nil { cn.err.set(driver.ErrBadConn) - 
errorf("unexpected %q after error %s", t, err) + return nil, "", fmt.Errorf("pq: unexpected %q after error %s", t, resErr) } - if t == 'I' { + if t == proto.EmptyQueryResponse { res = emptyRows } // ignore any results default: cn.err.set(driver.ErrBadConn) - errorf("unknown %s response: %q", protocolState, t) + return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t) } } } @@ -1998,108 +1756,6 @@ func parsePortalRowDescribe(r *readBuf) rowsHeader { } } -// parseEnviron tries to mimic some of libpq's environment handling -// -// To ease testing, it does not directly reference os.Environ, but is -// designed to accept its output. -// -// Environment-set connection information is intended to have a higher -// precedence than a library default but lower than any explicitly -// passed information (such as in the URL or connection string). -func parseEnviron(env []string) (out map[string]string) { - out = make(map[string]string) - - for _, v := range env { - parts := strings.SplitN(v, "=", 2) - - accrue := func(keyname string) { - out[keyname] = parts[1] - } - unsupported := func() { - panic(fmt.Sprintf("setting %v not supported", parts[0])) - } - - // The order of these is the same as is seen in the - // PostgreSQL 9.1 manual. Unsupported but well-defined - // keys cause a panic; these should be unset prior to - // execution. Options which pq expects to be set to a - // certain value are allowed, but must be set to that - // value if present (they can, of course, be absent). 
- switch parts[0] { - case "PGHOST": - accrue("host") - case "PGHOSTADDR": - unsupported() - case "PGPORT": - accrue("port") - case "PGDATABASE": - accrue("dbname") - case "PGUSER": - accrue("user") - case "PGPASSWORD": - accrue("password") - case "PGSERVICE", "PGSERVICEFILE", "PGREALM": - unsupported() - case "PGOPTIONS": - accrue("options") - case "PGAPPNAME": - accrue("application_name") - case "PGSSLMODE": - accrue("sslmode") - case "PGSSLCERT": - accrue("sslcert") - case "PGSSLKEY": - accrue("sslkey") - case "PGSSLROOTCERT": - accrue("sslrootcert") - case "PGSSLSNI": - accrue("sslsni") - case "PGREQUIRESSL", "PGSSLCRL": - unsupported() - case "PGREQUIREPEER": - unsupported() - case "PGKRBSRVNAME", "PGGSSLIB": - unsupported() - case "PGCONNECT_TIMEOUT": - accrue("connect_timeout") - case "PGCLIENTENCODING": - accrue("client_encoding") - case "PGDATESTYLE": - accrue("datestyle") - case "PGTZ": - accrue("timezone") - case "PGGEQO": - accrue("geqo") - case "PGSYSCONFDIR", "PGLOCALEDIR": - unsupported() - } - } - - return out -} - -// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". -func isUTF8(name string) bool { - // Recognize all sorts of silly things as "UTF-8", like Postgres does - s := strings.Map(alnumLowerASCII, name) - return s == "utf8" || s == "unicode" -} - -func alnumLowerASCII(ch rune) rune { - if 'A' <= ch && ch <= 'Z' { - return ch + ('a' - 'A') - } - if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { - return ch - } - return -1 // discard -} - -// The database/sql/driver package says: -// All Conn implementations should implement the following interfaces: Pinger, SessionResetter, and Validator. 
-var _ driver.Pinger = &conn{} -var _ driver.SessionResetter = &conn{} - func (cn *conn) ResetSession(ctx context.Context) error { // Ensure bad connections are reported: From database/sql/driver: // If a connection is never returned to the connection pool but immediately reused, then diff --git a/vendor/github.com/lib/pq/conn_go115.go b/vendor/github.com/lib/pq/conn_go115.go deleted file mode 100644 index f4ef030f9..000000000 --- a/vendor/github.com/lib/pq/conn_go115.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build go1.15 -// +build go1.15 - -package pq - -import "database/sql/driver" - -var _ driver.Validator = &conn{} diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index 63d4ca6aa..23a10aeec 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -6,22 +6,17 @@ import ( "database/sql/driver" "fmt" "io" - "io/ioutil" "time" -) -const ( - watchCancelDialContextTimeout = time.Second * 10 + "github.com/lib/pq/internal/proto" ) +const watchCancelDialContextTimeout = 10 * time.Second + // Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := cn.watchCancel(ctx) - r, err := cn.query(query, list) + r, err := cn.query(query, args) if err != nil { if finish != nil { finish() @@ -57,7 +52,6 @@ func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, // Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { var mode string - switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: // Don't touch mode: use the server's default @@ -72,7 +66,6 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, default: return nil, fmt.Errorf("pq: isolation level not supported: 
%d", opts.Isolation) } - if opts.ReadOnly { mode += " READ ONLY" } else { @@ -95,7 +88,7 @@ func (cn *conn) Ping(ctx context.Context) error { if err != nil { return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger } - rows.Close() + _ = rows.Close() return nil } @@ -131,7 +124,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() { select { case <-finished: cn.err.set(ctx.Err()) - cn.Close() + _ = cn.Close() case finished <- struct{}{}: } } @@ -140,55 +133,39 @@ func (cn *conn) watchCancel(ctx context.Context) func() { } func (cn *conn) cancel(ctx context.Context) error { - // Create a new values map (copy). This makes sure the connection created - // in this method cannot write to the same underlying data, which could - // cause a concurrent map write panic. This is necessary because cancel - // is called from a goroutine in watchCancel. - o := make(values) - for k, v := range cn.opts { - o[k] = v - } + // Use a copy since a new connection is created here. This is necessary + // because cancel is called from a goroutine in watchCancel. + cfg := cn.cfg.Clone() - c, err := dial(ctx, cn.dialer, o) + c, err := dial(ctx, cn.dialer, cfg) if err != nil { return err } - defer c.Close() + defer func() { _ = c.Close() }() - { - can := conn{ - c: c, - } - err = can.ssl(o) - if err != nil { - return err - } - - w := can.writeBuf(0) - w.int32(80877102) // cancel request code - w.int32(cn.processID) - w.int32(cn.secretKey) - - if err := can.sendStartupPacket(w); err != nil { - return err - } + cn2 := conn{c: c} + err = cn2.ssl(cfg) + if err != nil { + return err } - // Read until EOF to ensure that the server received the cancel. - { - _, err := io.Copy(ioutil.Discard, c) + w := cn2.writeBuf(0) + w.int32(proto.CancelRequestCode) + w.int32(cn.processID) + w.int32(cn.secretKey) + if err := cn2.sendStartupPacket(w); err != nil { return err } + + // Read until EOF to ensure that the server received the cancel. 
+ _, err = io.Copy(io.Discard, c) + return err } // Implement the "StmtQueryContext" interface func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } finish := st.watchCancel(ctx) - r, err := st.query(list) + r, err := st.query(args) if err != nil { if finish != nil { finish() @@ -201,16 +178,19 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri // Implement the "StmtExecContext" interface func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - list := make([]driver.Value, len(args)) - for i, nv := range args { - list[i] = nv.Value - } - if finish := st.watchCancel(ctx); finish != nil { defer finish() } + if err := st.cn.err.get(); err != nil { + return nil, err + } - return st.Exec(list) + err := st.exec(args) + if err != nil { + return nil, st.cn.handleError(err) + } + res, _, err := st.cn.readExecuteResponse("simple query") + return res, st.cn.handleError(err) } // watchCancel is implemented on stmt in order to not mark the parent conn as bad @@ -220,10 +200,9 @@ func (st *stmt) watchCancel(ctx context.Context) func() { go func() { select { case <-done: - // At this point the function level context is canceled, - // so it must not be used for the additional network - // request to cancel the query. - // Create a new context to pass into the dial. + // At this point the function level context is canceled, so it + // must not be used for the additional network request to cancel + // the query. Create a new context to pass into the dial. 
 				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
 				defer cancel()
diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go
index 1145e1225..4c318662a 100644
--- a/vendor/github.com/lib/pq/connector.go
+++ b/vendor/github.com/lib/pq/connector.go
@@ -3,81 +3,458 @@ package pq
 import (
 	"context"
 	"database/sql/driver"
-	"errors"
 	"fmt"
+	"math/rand"
+	"net"
+	"net/netip"
+	neturl "net/url"
 	"os"
+	"path/filepath"
+	"reflect"
+	"slices"
+	"sort"
+	"strconv"
 	"strings"
+	"time"
+	"unicode"
+
+	"github.com/lib/pq/internal/pqutil"
+)
+
+type (
+	// SSLMode is a sslmode setting.
+	SSLMode string
+
+	// SSLNegotiation is a sslnegotiation setting.
+	SSLNegotiation string
+
+	// TargetSessionAttrs is a target_session_attrs setting.
+	TargetSessionAttrs string
+
+	// LoadBalanceHosts is a load_balance_hosts setting.
+	LoadBalanceHosts string
+)
+
+// Values for [SSLMode] that pq supports.
+const (
+	// disable: No SSL
+	SSLModeDisable = SSLMode("disable")
+
+	// require: require SSL, but skip verification.
+	SSLModeRequire = SSLMode("require")
+
+	// verify-ca: require SSL and verify that the certificate was signed by a
+	// trusted CA.
+	SSLModeVerifyCA = SSLMode("verify-ca")
+
+	// verify-full: require SSL and verify that the certificate was signed by a
+	// trusted CA and the server host name matches the one in the certificate.
+	SSLModeVerifyFull = SSLMode("verify-full")
+)
+
+var sslModes = []SSLMode{SSLModeDisable, SSLModeRequire, SSLModeVerifyFull, SSLModeVerifyCA}
+
+// Values for [SSLNegotiation] that pq supports.
+const (
+	// Negotiate whether SSL should be used. This is the default.
+	SSLNegotiationPostgres = SSLNegotiation("postgres")
+
+	// Always use SSL, don't try to negotiate.
+	SSLNegotiationDirect = SSLNegotiation("direct")
+)
+
+var sslNegotiations = []SSLNegotiation{SSLNegotiationPostgres, SSLNegotiationDirect}
+
+// Values for [TargetSessionAttrs] that pq supports. 
+const ( + // Any successful connection is acceptable. This is the default. + TargetSessionAttrsAny = TargetSessionAttrs("any") + + // Session must accept read-write transactions by default: the server must + // not be in hot standby mode and default_transaction_read_only must be + // off. + TargetSessionAttrsReadWrite = TargetSessionAttrs("read-write") + + // Session must not accept read-write transactions by default. + TargetSessionAttrsReadOnly = TargetSessionAttrs("read-only") + + // Server must not be in hot standby mode. + TargetSessionAttrsPrimary = TargetSessionAttrs("primary") + + // Server must be in hot standby mode. + TargetSessionAttrsStandby = TargetSessionAttrs("standby") + + // First try to find a standby server, but if none of the listed hosts is a + // standby server, try again in any mode. + TargetSessionAttrsPreferStandby = TargetSessionAttrs("prefer-standby") +) + +var targetSessionAttrs = []TargetSessionAttrs{TargetSessionAttrsAny, + TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly, TargetSessionAttrsPrimary, + TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby} + +// Values for [LoadBalanceHosts] that pq supports. +const ( + // Don't load balance; try hosts in the order in which they're provided. + // This is the default. + LoadBalanceHostsDisable = LoadBalanceHosts("disable") + + // Hosts are tried in random order to balance connections across multiple + // PostgreSQL servers. + // + // When using this value it's recommended to also configure a reasonable + // value for connect_timeout. Because then, if one of the nodes that are + // used for load balancing is not responding, a new node will be tried. + LoadBalanceHostsRandom = LoadBalanceHosts("random") ) +var loadBalanceHosts = []LoadBalanceHosts{LoadBalanceHostsDisable, LoadBalanceHostsRandom} + // Connector represents a fixed configuration for the pq driver with a given -// name. 
Connector satisfies the database/sql/driver Connector interface and -// can be used to create any number of DB Conn's via the database/sql OpenDB -// function. -// -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. +// dsn. Connector satisfies the [database/sql/driver.Connector] interface and +// can be used to create any number of DB Conn's via [sql.OpenDB]. type Connector struct { - opts values + cfg Config dialer Dialer } -// Connect returns a connection to the database using the fixed configuration -// of this Connector. Context is not used. -func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { - return c.open(ctx) +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// [sql.OpenDB]. +func NewConnector(dsn string) (*Connector, error) { + cfg, err := NewConfig(dsn) + if err != nil { + return nil, err + } + return NewConnectorConfig(cfg) +} + +// NewConnectorConfig returns a connector for the pq driver in a fixed +// configuration with the given [Config]. The returned connector can be used to +// create any number of equivalent Conn's. The returned connector is intended to +// be used with [sql.OpenDB]. +func NewConnectorConfig(cfg Config) (*Connector, error) { + return &Connector{cfg: cfg, dialer: defaultDialer{}}, nil } +// Connect returns a connection to the database using the fixed configuration of +// this Connector. Context is not used. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } + // Dialer allows change the dialer used to open connections. -func (c *Connector) Dialer(dialer Dialer) { - c.dialer = dialer -} +func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } // Driver returns the underlying driver of this Connector. 
-func (c *Connector) Driver() driver.Driver { - return &Driver{} -} +func (c *Connector) Driver() driver.Driver { return &Driver{} } -// NewConnector returns a connector for the pq driver in a fixed configuration -// with the given dsn. The returned connector can be used to create any number -// of equivalent Conn's. The returned connector is intended to be used with -// database/sql.OpenDB. +// Config holds options pq supports when connecting to PostgreSQL. // -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. -func NewConnector(dsn string) (*Connector, error) { - var err error - o := make(values) +// The postgres struct tag is used for the value from the DSN (e.g. +// "dbname=abc"), and the env struct tag is used for the environment variable +// (e.g. "PGDATABASE=abc") +type Config struct { + // The host to connect to. Absolute paths and values that start with @ are + // for unix domain sockets. Defaults to localhost. + // + // A comma-separated list of host names is also accepted, in which case each + // host name in the list is tried in order or randomly if load_balance_hosts + // is set. An empty item selects the default of localhost. The + // target_session_attrs option controls properties the host must have to be + // considered acceptable. + Host string `postgres:"host" env:"PGHOST"` - // A number of defaults are applied here, in this order: + // IPv4 or IPv6 address to connect to. Using hostaddr allows the application + // to avoid a host name lookup, which might be important in applications + // with time constraints. A hostname is required for sslmode=verify-full and + // the GSSAPI or SSPI authentication methods. 
//
+	// The following rules are used:
+	//
+	// - If host is given without hostaddr, a host name lookup occurs.
+	//
+	// - If hostaddr is given without host, the value for hostaddr gives the
+	//   server network address. The connection attempt will fail if the
+	//   authentication method requires a host name.
+	//
+	// - If both host and hostaddr are given, the value for hostaddr gives the
+	//   server network address. The value for host is ignored unless the
+	//   authentication method requires it, in which case it will be used as the
+	//   host name.
+	//
+	// A comma-separated list of hostaddr values is also accepted, in which case
+	// each host in the list is tried in order or randomly if load_balance_hosts
+	// is set. An empty item causes the corresponding host name to be used, or
+	// the default host name if that is empty as well. The target_session_attrs
+	// option controls properties the host must have to be considered
+	// acceptable.
+	Hostaddr netip.Addr `postgres:"hostaddr" env:"PGHOSTADDR"`
+
+	// The port to connect to. Defaults to 5432.
+	//
+	// If multiple hosts were given in the host or hostaddr parameters, this
+	// parameter may specify a comma-separated list of ports of the same length
+	// as the host list, or it may specify a single port number to be used for
+	// all hosts. An empty string, or an empty item in a comma-separated list,
+	// specifies the default of 5432.
+	Port uint16 `postgres:"port" env:"PGPORT"`
+
+	// The name of the database to connect to.
+	Database string `postgres:"dbname" env:"PGDATABASE"`
+
+	// The user to sign in as. 
Defaults to the current user.
+	User string `postgres:"user" env:"PGUSER"`
+
+	// The user's password.
+	Password string `postgres:"password" env:"PGPASSWORD"`
+
+	// Path to [pgpass] file to store passwords; overrides Password.
+	//
+	// [pgpass]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html
+	Passfile string `postgres:"passfile" env:"PGPASSFILE"`
+
+	// Commandline options to send to the server at connection start.
+	Options string `postgres:"options" env:"PGOPTIONS"`
+
+	// Application name, displayed in pg_stat_activity and log entries.
+	ApplicationName string `postgres:"application_name" env:"PGAPPNAME"`
+
+	// Used if application_name is not given. Specifying a fallback name is
+	// useful in generic utility programs that wish to set a default application
+	// name but allow it to be overridden by the user.
+	FallbackApplicationName string `postgres:"fallback_application_name" env:"-"`
+
+	// Whether to use SSL. Defaults to "require" (different from libpq's default
+	// of "prefer").
+	//
+	// [RegisterTLSConfig] can be used to register a custom [tls.Config], which
+	// can be used by setting sslmode=pqgo-«key» in the connection string.
+	SSLMode SSLMode `postgres:"sslmode" env:"PGSSLMODE"`
+
+	// When set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only).
+	SSLNegotiation SSLNegotiation `postgres:"sslnegotiation" env:"PGSSLNEGOTIATION"`
+
+	// Cert file location. The file must contain PEM encoded data.
+	SSLCert string `postgres:"sslcert" env:"PGSSLCERT"`
+
+	// Key file location. The file must contain PEM encoded data.
+	SSLKey string `postgres:"sslkey" env:"PGSSLKEY"`
+
+	// The location of the root certificate file. The file must contain PEM encoded data.
+	SSLRootCert string `postgres:"sslrootcert" env:"PGSSLROOTCERT"`
+
+	// By default SNI is on, any value which is not starting with "1" disables
+	// SNI. 
+	SSLSNI bool `postgres:"sslsni" env:"PGSSLSNI"`
+
+	// Interpret sslcert and sslkey as PEM encoded data, rather than a path to a
+	// PEM file. This is a pq extension, not supported in libpq.
+	SSLInline bool `postgres:"sslinline" env:"-"`
+
+	// GSS (Kerberos) service name when constructing the SPN (default is
+	// postgres). This will be combined with the host to form the full SPN:
+	// krbsrvname/host.
+	KrbSrvname string `postgres:"krbsrvname" env:"PGKRBSRVNAME"`
+
+	// GSS (Kerberos) SPN. This takes priority over krbsrvname if present. This
+	// is a pq extension, not supported in libpq.
+	KrbSpn string `postgres:"krbspn" env:"-"`
+
+	// Maximum time to wait while connecting, in seconds. Zero, negative, or not
+	// specified means wait indefinitely
+	ConnectTimeout time.Duration `postgres:"connect_timeout" env:"PGCONNECT_TIMEOUT"`
+
+	// Whether to always send []byte parameters over as binary. Enables single
+	// round-trip mode for non-prepared Query calls. This is a pq extension, not
+	// supported in libpq.
+	BinaryParameters bool `postgres:"binary_parameters" env:"-"`
+
+	// This connection should never use the binary format when receiving query
+	// results from prepared statements. Only provided for debugging. This is a
+	// pq extension, not supported in libpq.
+	DisablePreparedBinaryResult bool `postgres:"disable_prepared_binary_result" env:"-"`
+
+	// Client encoding; pq only supports UTF8 and this must be blank or "UTF8".
+	ClientEncoding string `postgres:"client_encoding" env:"PGCLIENTENCODING"`
+
+	// Date/time representation to use; pq only supports "ISO, MDY" and this
+	// must be blank or "ISO, MDY".
+	Datestyle string `postgres:"datestyle" env:"PGDATESTYLE"`
+
+	// Default time zone.
+	TZ string `postgres:"tz" env:"PGTZ"`
+
+	// Default mode for the genetic query optimizer.
+	Geqo string `postgres:"geqo" env:"PGGEQO"`
+
+	// Determine whether the session must have certain properties to be
+	// acceptable. 
It's typically used in combination with multiple host names + // to select the first acceptable alternative among several hosts. + TargetSessionAttrs TargetSessionAttrs `postgres:"target_session_attrs" env:"PGTARGETSESSIONATTRS"` + + // Controls the order in which the client tries to connect to the available + // hosts. Once a connection attempt is successful no other hosts will be + // tried. This parameter is typically used in combination with multiple host + // names. + // + // This parameter can be used in combination with target_session_attrs to, + // for example, load balance over standby servers only. Once successfully + // connected, subsequent queries on the returned connection will all be sent + // to the same server. + LoadBalanceHosts LoadBalanceHosts `postgres:"load_balance_hosts" env:"PGLOADBALANCEHOSTS"` + + // Runtime parameters: any unrecognized parameter in the DSN will be added + // to this and sent to PostgreSQL during startup. + Runtime map[string]string `postgres:"-" env:"-"` + + // Multi contains additional connection details. The first value is + // available in [Config.Host], [Config.Hostaddr], and [Config.Port], and + // additional ones (if any) are available here. + Multi []ConfigMultihost + + // Record which parameters were given, so we can distinguish between an + // empty string "not given at all". + // + // The alternative is to use pointers or sql.Null[..], but that's more + // awkward to use. + set []string `env:"set"` + + multiHost []string + multiHostaddr []netip.Addr + multiPort []uint16 +} + +// ConfigMultihost specifies an additional server to try to connect to. +type ConfigMultihost struct { + Host string + Hostaddr netip.Addr + Port uint16 +} + +// NewConfig creates a new [Config] from the current environment and given DSN. +// +// A subset of the connection parameters supported by PostgreSQL are supported +// by pq; see the [Config] struct fields for supported parameters. 
pq also lets +// you specify any [run-time parameter] (such as search_path or work_mem) +// directly in the connection string. This is different from libpq, which does +// not allow run-time parameters in the connection string, instead requiring you +// to supply them in the options parameter. +// +// # key=value connection strings +// +// For key=value strings, use single quotes for values that contain whitespace +// or empty values. A backslash will escape the next character: +// +// "user=pqgo password='with spaces'" +// "user=''" +// "user=space\ man password='it\'s valid'" +// +// # URL connection strings +// +// pq supports URL-style postgres:// or postgresql:// connection strings in the +// form: +// +// postgres[ql]://[user[:pwd]@][net-location][:port][/dbname][?param1=value1&...] +// +// Go's [net/url.Parse] is more strict than PostgreSQL's URL parser and will +// (correctly) reject %2F in the host part. This means that unix-socket URLs: +// +// postgres://[user[:pwd]@][unix-socket][:port[/dbname]][?param1=value1&...] +// postgres://%2Ftmp%2Fpostgres/db +// +// will not work. You will need to use "host=/tmp/postgres dbname=db". +// +// Similarly, multiple ports also won't work, but ?port= will: +// +// postgres://host1,host2:5432,6543/dbname Doesn't work +// postgres://host1,host2/dbname?port=5432,6543 Works +// +// # Environment +// +// Most [PostgreSQL environment variables] are supported by pq. Environment +// variables have a lower precedence than explicitly provided connection +// parameters. pq will return an error if environment variables it does not +// support are set. Environment variables have a lower precedence than +// explicitly provided connection parameters. 
+// +// [run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html +// [PostgreSQL environment variables]: http://www.postgresql.org/docs/current/static/libpq-envars.html +func NewConfig(dsn string) (Config, error) { + return newConfig(dsn, os.Environ()) +} + +// Clone returns a copy of the [Config]. +func (cfg Config) Clone() Config { + rt := make(map[string]string) + for k, v := range cfg.Runtime { + rt[k] = v } + c := cfg + c.Runtime = rt + c.set = append([]string{}, cfg.set...) + return c +} - if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { - dsn, err = ParseURL(dsn) - if err != nil { - return nil, err - } +// hosts returns a slice of copies of this config, one for each host. +func (cfg Config) hosts() []Config { + cfgs := make([]Config, 1, len(cfg.Multi)+1) + cfgs[0] = cfg.Clone() + for _, m := range cfg.Multi { + c := cfg.Clone() + c.Host, c.Hostaddr, c.Port = m.Host, m.Hostaddr, m.Port + cfgs = append(cfgs, c) } - if err := parseOpts(dsn, o); err != nil { - return nil, err + if cfg.LoadBalanceHosts == LoadBalanceHostsRandom { + rand.Shuffle(len(cfgs), func(i, j int) { cfgs[i], cfgs[j] = cfgs[j], cfgs[i] }) } - // Use the "fallback" application name if necessary - if fallback, ok := o["fallback_application_name"]; ok { - if _, ok := o["application_name"]; !ok { - o["application_name"] = fallback + return cfgs +} + +func newConfig(dsn string, env []string) (Config, error) { + cfg := Config{Host: "localhost", Port: 5432, SSLSNI: true} + if err := cfg.fromEnv(env); err != nil { + return Config{}, err + } + if err := cfg.fromDSN(dsn); err != nil { + return Config{}, err + } + + // Need to have exactly the same number of host and hostaddr, or only specify one. 
+ if cfg.isset("host") && cfg.Host != "" && cfg.Hostaddr != (netip.Addr{}) && len(cfg.multiHost) != len(cfg.multiHostaddr) { + return Config{}, fmt.Errorf("pq: could not match %d host names to %d hostaddr values", + len(cfg.multiHost)+1, len(cfg.multiHostaddr)+1) + } + // Need one port that applies to all or exactly the same number of ports as hosts. + l, ll := max(len(cfg.multiHost), len(cfg.multiHostaddr)), len(cfg.multiPort) + if l > 0 && ll > 0 && l != ll { + return Config{}, fmt.Errorf("pq: could not match %d port numbers to %d hosts", ll+1, l+1) + } + + // Populate Multi + if len(cfg.multiHostaddr) > len(cfg.multiHost) { + cfg.multiHost = make([]string, len(cfg.multiHostaddr)) + } + for i, h := range cfg.multiHost { + p := cfg.Port + if len(cfg.multiPort) > 0 { + p = cfg.multiPort[i] + } + var addr netip.Addr + if len(cfg.multiHostaddr) > 0 { + addr = cfg.multiHostaddr[i] } + cfg.Multi = append(cfg.Multi, ConfigMultihost{ + Host: h, + Port: p, + Hostaddr: addr, + }) + } + + // Use the "fallback" application name if necessary + if cfg.isset("fallback_application_name") && !cfg.isset("application_name") { + cfg.ApplicationName = cfg.FallbackApplicationName } // We can't work with any client_encoding other than UTF-8 currently. @@ -87,34 +464,488 @@ func NewConnector(dsn string) (*Connector, error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") + if cfg.isset("client_encoding") && !isUTF8(cfg.ClientEncoding) { + return Config{}, fmt.Errorf(`pq: unsupported client_encoding %q: must be absent or "UTF8"`, cfg.ClientEncoding) } - o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. 
- if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) + if cfg.isset("datestyle") && cfg.Datestyle != "ISO, MDY" { + return Config{}, fmt.Errorf(`pq: unsupported datestyle %q: must be absent or "ISO, MDY"`, cfg.Datestyle) + } + cfg.ClientEncoding, cfg.Datestyle = "UTF8", "ISO, MDY" + + // Set default user if not explicitly provided. + if !cfg.isset("user") { + u, err := pqutil.User() + if err != nil { + return Config{}, err } - } else { - o["datestyle"] = "ISO, MDY" + cfg.User = u + } + + // SSL is not necessary or supported over UNIX domain sockets. + if nw, _ := cfg.network(); nw == "unix" { + cfg.SSLMode = SSLModeDisable + } + + return cfg, nil +} + +func (cfg Config) network() (string, string) { + if cfg.Hostaddr != (netip.Addr{}) { + return "tcp", net.JoinHostPort(cfg.Hostaddr.String(), strconv.Itoa(int(cfg.Port))) } + // UNIX domain sockets are either represented by an (absolute) file system + // path or they live in the abstract name space (starting with an @). 
+ if filepath.IsAbs(cfg.Host) || strings.HasPrefix(cfg.Host, "@") { + sockPath := filepath.Join(cfg.Host, ".s.PGSQL."+strconv.Itoa(int(cfg.Port))) + return "unix", sockPath + } + return "tcp", net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port))) +} + +func (cfg *Config) fromEnv(env []string) error { + e := make(map[string]string) + for _, v := range env { + k, v, ok := strings.Cut(v, "=") + if !ok { + continue + } + switch k { + case "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM", + "PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", + "PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION", + "PGGSSENCMODE", "PGGSSDELEGATION", "PGMINPROTOCOLVERSION", "PGMAXPROTOCOLVERSION", "PGGSSLIB": + return fmt.Errorf("pq: environment variable $%s is not supported", k) + case "PGKRBSRVNAME": + if newGss == nil { + return fmt.Errorf("pq: environment variable $%s is not supported as Kerberos is not enabled", k) + } + } + e[k] = v + } + return cfg.setFromTag(e, "env") +} - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { - u, err := userCurrent() +// parseOpts parses the options from name and adds them to the values. 
+// +// The parsing code is based on conninfo_parse from libpq's fe-connect.c +func (cfg *Config) fromDSN(dsn string) error { + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + var err error + dsn, err = convertURL(dsn) if err != nil { - return nil, err + return err + } + } + + var ( + opt = make(map[string]string) + s = []rune(dsn) + i int + next = func() (rune, bool) { + if i >= len(s) { + return 0, false + } + r := s[i] + i++ + return r, true + } + skipSpaces = func() (rune, bool) { + r, ok := next() + for unicode.IsSpace(r) && ok { + r, ok = next() + } + return r, ok + } + ) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = skipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = skipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = skipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. 
+ opt[string(keyRunes)] = "" + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } } - o["user"] = u + + opt[string(keyRunes)] = string(valRunes) + } + + return cfg.setFromTag(opt, "postgres") +} + +func (cfg *Config) setFromTag(o map[string]string, tag string) error { + f := "pq: wrong value for %q: " + if tag == "env" { + f = "pq: wrong value for $%s: " + } + var ( + types = reflect.TypeOf(cfg).Elem() + values = reflect.ValueOf(cfg).Elem() + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get(tag) + connectTimeout = (tag == "postgres" && k == "connect_timeout") || (tag == "env" && k == "PGCONNECT_TIMEOUT") + host = (tag == "postgres" && k == "host") || (tag == "env" && k == "PGHOST") + hostaddr = (tag == "postgres" && k == "hostaddr") || (tag == "env" && k == "PGHOSTADDR") + port = (tag == "postgres" && k == "port") || (tag == "env" && k == "PGPORT") + sslmode = (tag == "postgres" && k == "sslmode") || (tag == "env" && k == "PGSSLMODE") + sslnegotiation = (tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION") + targetsessionattrs = (tag == "postgres" && k == "target_session_attrs") || (tag == "env" && k == "PGTARGETSESSIONATTRS") + loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") + ) + if k == "" || k == "-" { + continue + } + + v, ok := o[k] + delete(o, k) + if ok { + if t, ok := rt.Tag.Lookup("postgres"); ok && t != "" && t != "-" { + cfg.set = 
append(cfg.set, t) + } + switch rt.Type.Kind() { + default: + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type.Kind()) + case reflect.Struct: + if rt.Type == reflect.TypeOf(netip.Addr{}) { + if hostaddr { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + cfg.multiHostaddr = append(cfg.multiHostaddr, netip.Addr{}) + } else { + ip, err := netip.ParseAddr(vvv) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiHostaddr = append(cfg.multiHostaddr, ip) + } + } + } + ip, err := netip.ParseAddr(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.Set(reflect.ValueOf(ip)) + } else { + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type) + } + case reflect.String: + if sslmode && !slices.Contains(sslModes, SSLMode(v)) && !(strings.HasPrefix(v, "pqgo-") && hasTLSConfig(v[5:])) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslModes)) + } + if sslnegotiation && !slices.Contains(sslNegotiations, SSLNegotiation(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslNegotiations)) + } + if targetsessionattrs && !slices.Contains(targetSessionAttrs, TargetSessionAttrs(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(targetSessionAttrs)) + } + if loadbalancehosts && !slices.Contains(loadBalanceHosts, LoadBalanceHosts(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(loadBalanceHosts)) + } + if host { + vv := strings.Split(v, ",") + v = vv[0] + for i, vvv := range vv[1:] { + if vvv == "" { + vv[i+1] = "localhost" + } + } + cfg.multiHost = append(cfg.multiHost, vv[1:]...) 
+ } + rv.SetString(v) + case reflect.Int64: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + if connectTimeout { + n = int64(time.Duration(n) * time.Second) + } + rv.SetInt(n) + case reflect.Uint16: + if port { + vv := strings.Split(v, ",") + v = vv[0] + for _, vvv := range vv[1:] { + if vvv == "" { + vvv = "5432" + } + n, err := strconv.ParseUint(vvv, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + cfg.multiPort = append(cfg.multiPort, uint16(n)) + } + } + n, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetUint(n) + case reflect.Bool: + b, err := pqutil.ParseBool(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetBool(b) + } + } + } + + // Set run-time; we delete map keys as they're set in the struct. + if tag == "postgres" { + // Make sure database= sets dbname=; in startup() we send database for + // dbname, and if we have both set it's inconsistent as the loop order + // is a map. + if d, ok := o["database"]; ok { + delete(o, "database") + if o["dbname"] == "" { + o["dbname"] = d + } + } + cfg.Runtime = o + } + + return nil +} + +func (cfg Config) isset(name string) bool { + return slices.Contains(cfg.set, name) +} + +// Convert to a map; used only in tests. 
+func (cfg Config) tomap() map[string]string { + var ( + o = make(map[string]string) + values = reflect.ValueOf(cfg) + types = reflect.TypeOf(cfg) + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get("postgres") + ) + if k == "" || k == "-" { + continue + } + if !rv.IsZero() || slices.Contains(cfg.set, k) { + switch rt.Type.Kind() { + default: + if s, ok := rv.Interface().(fmt.Stringer); ok { + o[k] = s.String() + } else { + o[k] = rv.String() + } + case reflect.Uint16: + n := rv.Uint() + o[k] = strconv.FormatUint(n, 10) + case reflect.Int64: + n := rv.Int() + if k == "connect_timeout" { + n = int64(time.Duration(n) / time.Second) + } + o[k] = strconv.FormatInt(n, 10) + case reflect.Bool: + if rv.Bool() { + o[k] = "yes" + } else { + o[k] = "no" + } + } + } + } + for k, v := range cfg.Runtime { + o[k] = v + } + return o +} + +// Create DSN for this config; used only in tests. +func (cfg Config) string() string { + var ( + m = cfg.tomap() + keys = make([]string, 0, len(m)) + ) + for k := range m { + switch k { + case "datestyle", "client_encoding": + continue + case "host", "port", "user", "sslsni": + if !cfg.isset(k) { + continue + } + } + if k == "host" && len(cfg.multiHost) > 0 { + m[k] += "," + strings.Join(cfg.multiHost, ",") + } + if k == "hostaddr" && len(cfg.multiHostaddr) > 0 { + for _, ha := range cfg.multiHostaddr { + m[k] += "," + if ha != (netip.Addr{}) { + m[k] += ha.String() + } + } + } + if k == "port" && len(cfg.multiPort) > 0 { + for _, p := range cfg.multiPort { + m[k] += "," + strconv.Itoa(int(p)) + } + } + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for i, k := range keys { + if i > 0 { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + var ( + v = m[k] + nv = make([]rune, 0, len(v)+2) + quote = v == "" + ) + for _, c := range v { + if c == ' ' { + quote = true + } + if c == '\'' { + nv = append(nv, '\\') + } + nv = append(nv, c) + } + if 
quote { + b.WriteByte('\'') + } + b.WriteString(string(nv)) + if quote { + b.WriteByte('\'') + } + } + return b.String() +} + +// Recognize all sorts of silly things as "UTF-8", like Postgres does +func isUTF8(name string) bool { + s := strings.Map(func(c rune) rune { + if 'A' <= c && c <= 'Z' { + return c + ('a' - 'A') + } + if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { + return c + } + return -1 // discard + }, name) + return s == "utf8" || s == "unicode" +} + +func convertURL(url string) (string, error) { + u, err := neturl.Parse(url) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") + } + } + + if u.User != nil { + pw, _ := u.User.Password() + accrue("user", u.User.Username()) + accrue("password", pw) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) } - // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { - o["sslmode"] = "disable" + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) } - return &Connector{opts: o, dialer: defaultDialer{}}, nil + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil } diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index a8f16b2b2..c72f6ceb7 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "sync" + + "github.com/lib/pq/internal/proto" ) var ( @@ -18,38 +20,41 @@ var ( errCopyInProgress = errors.New("pq: COPY in progress") ) -// CopyIn creates a COPY 
FROM statement which can be prepared with -// Tx.Prepare(). The target table should be visible in search_path. +// CopyIn creates a COPY FROM statement which can be prepared with Tx.Prepare(). +// The target table should be visible in search_path. +// +// It copies all columns if the list of columns is empty. func CopyIn(table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) - return buffer.String() + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() +} + +// CopyInSchema creates a COPY FROM statement which can be prepared with +// Tx.Prepare(). +func CopyInSchema(schema, table string, columns ...string) string { + b := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(schema, b) + b.WriteRune('.') + BufferQuoteIdentifier(table, b) + makeStmt(b, columns...) + return b.String() } -// MakeStmt makes the stmt string for CopyIn and CopyInSchema. -func makeStmt(buffer *bytes.Buffer, columns ...string) { - //s := bytes.NewBufferString() +func makeStmt(b *bytes.Buffer, columns ...string) { + if len(columns) == 0 { + b.WriteString(" FROM STDIN") + return + } + b.WriteString(" (") for i, col := range columns { if i != 0 { - buffer.WriteString(", ") + b.WriteString(", ") } - BufferQuoteIdentifier(col, buffer) + BufferQuoteIdentifier(col, b) } - buffer.WriteString(") FROM STDIN") -} - -// CopyInSchema creates a COPY FROM statement which can be prepared with -// Tx.Prepare(). -func CopyInSchema(schema, table string, columns ...string) string { - buffer := bytes.NewBufferString("COPY ") - BufferQuoteIdentifier(schema, buffer) - buffer.WriteRune('.') - BufferQuoteIdentifier(table, buffer) - buffer.WriteString(" (") - makeStmt(buffer, columns...) 
- return buffer.String() + b.WriteString(") FROM STDIN") } type copyin struct { @@ -72,7 +77,7 @@ const ciBufferSize = 64 * 1024 // flush buffer before the buffer is filled up and needs reallocation const ciBufferFlushSize = 63 * 1024 -func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { +func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, resErr error) { if !cn.isInTransaction() { return nil, errCopyNotSupportedOutsideTxn } @@ -84,69 +89,83 @@ func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { done: make(chan bool, 1), } // add CopyData identifier + 4 bytes for message length - ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0) + ci.buffer = append(ci.buffer, byte(proto.CopyDataRequest), 0, 0, 0, 0) - b := cn.writeBuf('Q') + b := cn.writeBuf(proto.Query) b.string(q) - cn.send(b) + err := cn.send(b) + if err != nil { + return nil, err + } awaitCopyInResponse: for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, err + } switch t { - case 'G': + case proto.CopyInResponse: if r.byte() != 0 { - err = errBinaryCopyNotSupported + resErr = errBinaryCopyNotSupported break awaitCopyInResponse } go ci.resploop() return ci, nil - case 'H': - err = errCopyToNotSupported + case proto.CopyOutResponse: + resErr = errCopyToNotSupported break awaitCopyInResponse - case 'E': - err = parseError(r) - case 'Z': - if err == nil { + case proto.ErrorResponse: + resErr = parseError(r, q) + case proto.ReadyForQuery: + if resErr == nil { ci.setBad(driver.ErrBadConn) - errorf("unexpected ReadyForQuery in response to COPY") + return nil, fmt.Errorf("pq: unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) - return nil, err + return nil, resErr default: ci.setBad(driver.ErrBadConn) - errorf("unknown response for copy query: %q", t) + return nil, fmt.Errorf("pq: unknown response for copy query: %q", t) } } // something went wrong, abort COPY before we return - b = cn.writeBuf('f') - b.string(err.Error()) - 
cn.send(b) + b = cn.writeBuf(proto.CopyFail) + b.string(resErr.Error()) + err = cn.send(b) + if err != nil { + return nil, err + } for { - t, r := cn.recv1() + t, r, err := cn.recv1() + if err != nil { + return nil, err + } + switch t { - case 'c', 'C', 'E': - case 'Z': + case proto.CopyDoneResponse, proto.CommandComplete, proto.ErrorResponse: + case proto.ReadyForQuery: // correctly aborted, we're done cn.processReadyForQuery(r) - return nil, err + return nil, resErr default: ci.setBad(driver.ErrBadConn) - errorf("unknown response for CopyFail: %q", t) + return nil, fmt.Errorf("pq: unknown response for CopyFail: %q", t) } } } -func (ci *copyin) flush(buf []byte) { +func (ci *copyin) flush(buf []byte) error { + if len(buf)-1 > proto.MaxUint32 { + return errors.New("pq: too many columns") + } // set message length (without message identifier) binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) _, err := ci.cn.c.Write(buf) - if err != nil { - panic(err) - } + return err } func (ci *copyin) resploop() { @@ -160,20 +179,23 @@ func (ci *copyin) resploop() { return } switch t { - case 'C': + case proto.CommandComplete: // complete - res, _ := ci.cn.parseComplete(r.string()) + res, _, err := ci.cn.parseComplete(r.string()) + if err != nil { + panic(err) + } ci.setResult(res) - case 'N': + case proto.NoticeResponse: if n := ci.cn.noticeHandler; n != nil { - n(parseError(&r)) + n(parseError(&r, "")) } - case 'Z': + case proto.ReadyForQuery: ci.cn.processReadyForQuery(&r) ci.done <- true return - case 'E': - err := parseError(&r) + case proto.ErrorResponse: + err := parseError(&r, "") ci.setError(err) default: ci.setBad(driver.ErrBadConn) @@ -240,16 +262,13 @@ func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { // You need to call Exec(nil) to sync the COPY stream and to get any // errors from pending data, since Stmt.Close() doesn't return errors // to the user. 
-func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { +func (ci *copyin) Exec(v []driver.Value) (driver.Result, error) { if ci.closed { return nil, errCopyInClosed } - if err := ci.getBad(); err != nil { return nil, err } - defer ci.cn.errRecover(&err) - if err := ci.err(); err != nil { return nil, err } @@ -258,13 +277,18 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { if err := ci.Close(); err != nil { return driver.RowsAffected(0), err } - return ci.getResult(), nil } - numValues := len(v) + var ( + numValues = len(v) + err error + ) for i, value := range v { - ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value) + ci.buffer, err = appendEncodedText(ci.buffer, value) + if err != nil { + return nil, ci.cn.handleError(err) + } if i < numValues-1 { ci.buffer = append(ci.buffer, '\t') } @@ -273,7 +297,10 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { ci.buffer = append(ci.buffer, '\n') if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } // reset buffer, keep bytes for message identifier and length ci.buffer = ci.buffer[:5] } @@ -288,20 +315,16 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { // You need to call Exec(nil) to sync the COPY stream and to get any // errors from pending data, since Stmt.Close() doesn't return errors // to the user. 
-func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) { +func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, error) { if ci.closed { return nil, errCopyInClosed } - if finish := ci.cn.watchCancel(ctx); finish != nil { defer finish() } - if err := ci.getBad(); err != nil { return nil, err } - defer ci.cn.errRecover(&err) - if err := ci.err(); err != nil { return nil, err } @@ -310,7 +333,11 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, e ci.buffer = append(ci.buffer, '\n') if len(ci.buffer) > ciBufferFlushSize { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return nil, ci.cn.handleError(err) + } + // reset buffer, keep bytes for message identifier and length ci.buffer = ci.buffer[:5] } @@ -318,7 +345,7 @@ func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, e return driver.RowsAffected(0), nil } -func (ci *copyin) Close() (err error) { +func (ci *copyin) Close() error { if ci.closed { // Don't do anything, we're already closed return nil } @@ -327,15 +354,17 @@ func (ci *copyin) Close() (err error) { if err := ci.getBad(); err != nil { return err } - defer ci.cn.errRecover(&err) if len(ci.buffer) > 0 { - ci.flush(ci.buffer) + err := ci.flush(ci.buffer) + if err != nil { + return ci.cn.handleError(err) + } } // Avoid touching the scratch buffer as resploop could be using it. - err = ci.cn.sendSimpleMessage('c') + err := ci.cn.sendSimpleMessage(proto.CopyDoneRequest) if err != nil { - return err + return ci.cn.handleError(err) } <-ci.done diff --git a/vendor/github.com/lib/pq/deprecated.go b/vendor/github.com/lib/pq/deprecated.go new file mode 100644 index 000000000..0def49de9 --- /dev/null +++ b/vendor/github.com/lib/pq/deprecated.go @@ -0,0 +1,59 @@ +package pq + +// PGError is an interface used by previous versions of pq. +// +// Deprecated: use the Error type. This is never used. 
+type PGError interface { + Error() string + Fatal() bool + Get(k byte) (v string) +} + +// Get implements the legacy PGError interface. +// +// Deprecated: new code should use the fields of the Error struct directly. +func (e *Error) Get(k byte) (v string) { + switch k { + case 'S': + return e.Severity + case 'C': + return string(e.Code) + case 'M': + return e.Message + case 'D': + return e.Detail + case 'H': + return e.Hint + case 'P': + return e.Position + case 'p': + return e.InternalPosition + case 'q': + return e.InternalQuery + case 'W': + return e.Where + case 's': + return e.Schema + case 't': + return e.Table + case 'c': + return e.Column + case 'd': + return e.DataTypeName + case 'n': + return e.Constraint + case 'F': + return e.File + case 'L': + return e.Line + case 'R': + return e.Routine + } + return "" +} + +// ParseURL converts a url to a connection string for driver.Open. +// +// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...") +// now works, and calling this manually is no longer required. +func ParseURL(url string) (string, error) { return convertURL(url) } diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index b57184801..103940733 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -1,8 +1,8 @@ /* -Package pq is a pure Go Postgres driver for the database/sql package. +Package pq is a Go PostgreSQL driver for database/sql. -In most cases clients will use the database/sql package instead of -using this package directly. For example: +Most clients will use the database/sql package instead of using this package +directly. For example: import ( "database/sql" @@ -11,157 +11,75 @@ using this package directly. 
For example: ) func main() { - connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "user=pqgo dbname=pqgo sslmode=verify-full" + db, err := sql.Open("postgres", dsn) if err != nil { log.Fatal(err) } age := 21 - rows, err := db.Query("SELECT name FROM users WHERE age = $1", age) - … + rows, err := db.Query("select name from users where age = $1", age) + // … } -You can also connect to a database using a URL. For example: +You can also connect with an URL: - connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "postgres://pqgo:password@localhost/pqgo?sslmode=verify-full" + db, err := sql.Open("postgres", dsn) +# Connection String Parameters -Connection String Parameters +See [NewConfig]. +# Queries -Similarly to libpq, when establishing a connection using pq you are expected to -supply a connection string containing zero or more parameters. -A subset of the connection parameters supported by libpq are also supported by pq. -Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem) -directly in the connection string. This is different from libpq, which does not allow -run-time parameters in the connection string, instead requiring you to supply -them in the options parameter. +database/sql does not dictate any specific format for parameter placeholders, +and pq uses the PostgreSQL-native ordinal markers ($1, $2, etc.). The same +placeholder can be used more than once: -For compatibility with libpq, the following special connection parameters are -supported: + rows, err := db.Query( + `select * from users where name = $1 or age between $2 and $2 + 3`, + "Duck", 64) - * dbname - The name of the database to connect to - * user - The user to sign in as - * password - The user's password - * host - The host to connect to. Values that start with / are for unix - domain sockets. 
(default is localhost) - * port - The port to bind to. (default is 5432) - * sslmode - Whether or not to use SSL (default is require, this is not - the default for libpq) - * fallback_application_name - An application_name to fall back to if one isn't provided. - * connect_timeout - Maximum wait for connection, in seconds. Zero or - not specified means wait indefinitely. - * sslcert - Cert file location. The file must contain PEM encoded data. - * sslkey - Key file location. The file must contain PEM encoded data. - * sslrootcert - The location of the root certificate file. The file - must contain PEM encoded data. +pq does not support [sql.Result.LastInsertId]. Use the RETURNING clause with a +Query or QueryRow call instead to return the identifier: -Valid values for sslmode are: - - * disable - No SSL - * require - Always SSL (skip verification) - * verify-ca - Always SSL (verify that the certificate presented by the - server was signed by a trusted CA) - * verify-full - Always SSL (verify that the certification presented by - the server was signed by a trusted CA and the server host name - matches the one in the certificate) - -See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING -for more information about connection string parameters. - -Use single quotes for values that contain whitespace: - - "user=pqgotest password='with spaces'" - -A backslash will escape the next character in values: - - "user=space\ man password='it\'s valid'" - -Note that the connection parameter client_encoding (which sets the -text encoding for the connection) may be set but must be "UTF8", -matching with the same rules as Postgres. It is an error to provide -any other value. - -In addition to the parameters listed above, any run-time parameter that can be -set at backend start time can be set in the connection string. For more -information, see -http://www.postgresql.org/docs/current/static/runtime-config.html. 
- -Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html -supported by libpq are also supported by pq. If any of the environment -variables not supported by pq are set, pq will panic during connection -establishment. Environment variables have a lower precedence than explicitly -provided connection parameters. - -The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html -is supported, but on Windows PGPASSFILE must be specified explicitly. - - -Queries - - -database/sql does not dictate any specific format for parameter -markers in query strings, and pq uses the Postgres-native ordinal markers, -as shown above. The same marker can be reused for the same parameter: - - rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 - OR age BETWEEN $2 AND $2 + 3`, "orange", 64) - -pq does not support the LastInsertId() method of the Result type in database/sql. -To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres -RETURNING clause with a standard Query or QueryRow call: + row := db.QueryRow(`insert into users(name, age) values('Scrooge McDuck', 93) returning id`) var userid int - err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) - VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) - -For more details on RETURNING, see the Postgres documentation: - - http://www.postgresql.org/docs/current/static/sql-insert.html - http://www.postgresql.org/docs/current/static/sql-update.html - http://www.postgresql.org/docs/current/static/sql-delete.html + err := row.Scan(&userid) -For additional instructions on querying see the documentation for the database/sql package. +# Data Types - -Data Types - - -Parameters pass through driver.DefaultParameterConverter before they are handled -by this package. 
When the binary_parameters connection option is enabled, -[]byte values are sent directly to the backend as data in binary format. +Parameters pass through [driver.DefaultParameterConverter] before they are handled +by this package. When the binary_parameters connection option is enabled, []byte +values are sent directly to the backend as data in binary format. This package returns the following types for values from the PostgreSQL backend: - - integer types smallint, integer, and bigint are returned as int64 - - floating-point types real and double precision are returned as float64 - - character types char, varchar, and text are returned as string - - temporal types date, time, timetz, timestamp, and timestamptz are - returned as time.Time - - the boolean type is returned as bool - - the bytea type is returned as []byte + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte All other types are returned directly from the backend as []byte values in text format. +# Errors -Errors - - -pq may return errors of type *pq.Error which can be interrogated for error details: - - if err, ok := err.(*pq.Error); ok { - fmt.Println("pq error:", err.Code.Name()) - } - -See the pq.Error type for details. +pq may return errors of type [*pq.Error] which contain error details: + pqErr := new(pq.Error) + if errors.As(err, &pqErr) { + fmt.Println("pq error:", pqErr.Code.Name()) + } -Bulk imports +# Bulk imports -You can perform bulk imports by preparing a statement returned by pq.CopyIn (or -pq.CopyInSchema) in an explicit transaction (sql.Tx). 
The returned statement +You can perform bulk imports by preparing a statement returned by [CopyIn] (or +[CopyInSchema]) in an explicit transaction ([sql.Tx]). The returned statement handle can then be repeatedly "executed" to copy data into the target table. After all data has been processed you should call Exec() once with no arguments to flush all buffered data. Any call to Exec() might return an error which @@ -172,78 +90,35 @@ failed. CopyIn uses COPY FROM internally. It is not possible to COPY outside of an explicit transaction in pq. -Usage example: - - txn, err := db.Begin() - if err != nil { - log.Fatal(err) - } - - stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) - if err != nil { - log.Fatal(err) - } - - for _, user := range users { - _, err = stmt.Exec(user.Name, int64(user.Age)) - if err != nil { - log.Fatal(err) - } - } +# Notifications - _, err = stmt.Exec() - if err != nil { - log.Fatal(err) - } +PostgreSQL supports a simple publish/subscribe model using PostgreSQL's [NOTIFY] mechanism. - err = stmt.Close() - if err != nil { - log.Fatal(err) - } - - err = txn.Commit() - if err != nil { - log.Fatal(err) - } - - -Notifications - - -PostgreSQL supports a simple publish/subscribe model over database -connections. See http://www.postgresql.org/docs/current/static/sql-notify.html -for more information about the general mechanism. - -To start listening for notifications, you first have to open a new connection -to the database by calling NewListener. This connection can not be used for -anything other than LISTEN / NOTIFY. Calling Listen will open a "notification +To start listening for notifications, you first have to open a new connection to +the database by calling [NewListener]. This connection can not be used for +anything other than LISTEN / NOTIFY. Calling Listen will open a "notification channel"; once a notification channel is open, a notification generated on that -channel will effect a send on the Listener.Notify channel. 
A notification +channel will effect a send on the Listener.Notify channel. A notification channel will remain open until Unlisten is called, though connection loss might -result in some notifications being lost. To solve this problem, Listener sends -a nil pointer over the Notify channel any time the connection is re-established -following a connection loss. The application can get information about the -state of the underlying connection by setting an event callback in the call to +result in some notifications being lost. To solve this problem, Listener sends a +nil pointer over the Notify channel any time the connection is re-established +following a connection loss. The application can get information about the state +of the underlying connection by setting an event callback in the call to NewListener. -A single Listener can safely be used from concurrent goroutines, which means +A single [Listener] can safely be used from concurrent goroutines, which means that there is often no need to create more than one Listener in your -application. However, a Listener is always connected to a single database, so +application. However, a Listener is always connected to a single database, so you will need to create a new Listener instance for every database you want to receive notifications in. The channel name in both Listen and Unlisten is case sensitive, and can contain -any characters legal in an identifier (see -http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -for more information). Note that the channel name will be truncated to 63 -bytes by the PostgreSQL server. - -You can find a complete, working example of Listener usage at -https://godoc.org/github.com/lib/pq/example/listen. +any characters legal in an [identifier]. Note that the channel name will be +truncated to 63 bytes by the PostgreSQL server. +You can find a complete, working example of Listener usage at [cmd/pqlisten]. 
-Kerberos Support - +# Kerberos Support If you need support for Kerberos authentication, add the following to your main package: @@ -254,15 +129,11 @@ package: pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) } -This package is in a separate module so that users who don't need Kerberos -don't have to download unnecessary dependencies. - -When imported, additional connection string parameters are supported: +This package is in a separate module so that users who don't need Kerberos don't +have to add unnecessary dependencies. - * krbsrvname - GSS (Kerberos) service name when constructing the - SPN (default is `postgres`). This will be combined with the host - to form the full SPN: `krbsrvname/host`. - * krbspn - GSS (Kerberos) SPN. This takes priority over - `krbsrvname` if present. +[cmd/pqlisten]: https://github.com/lib/pq/tree/master/cmd/pqlisten +[identifier]: http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +[NOTIFY]: http://www.postgresql.org/docs/current/static/sql-notify.html */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index bffe6096a..e43fc93d6 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -2,7 +2,7 @@ package pq import ( "bytes" - "database/sql/driver" + "database/sql" "encoding/binary" "encoding/hex" "errors" @@ -19,144 +19,154 @@ import ( var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`) -func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { +func binaryEncode(x any) ([]byte, error) { switch v := x.(type) { case []byte: - return v + return v, nil default: - return encode(parameterStatus, x, oid.T_unknown) + return encode(x, oid.T_unknown) } } -func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { +func encode(x any, pgtypOid oid.Oid) ([]byte, error) { switch v := x.(type) { case int64: - return 
strconv.AppendInt(nil, v, 10) + return strconv.AppendInt(nil, v, 10), nil case float64: - return strconv.AppendFloat(nil, v, 'f', -1, 64) + return strconv.AppendFloat(nil, v, 'f', -1, 64), nil case []byte: + if v == nil { + return nil, nil + } if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, v) + return encodeBytea(v), nil } - - return v + return v, nil case string: if pgtypOid == oid.T_bytea { - return encodeBytea(parameterStatus.serverVersion, []byte(v)) + return encodeBytea([]byte(v)), nil } - - return []byte(v) + return []byte(v), nil case bool: - return strconv.AppendBool(nil, v) + return strconv.AppendBool(nil, v), nil case time.Time: - return formatTs(v) - + return formatTS(v), nil default: - errorf("encode: unknown type for %T", v) + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } - - panic("not reached") } -func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { +func decode(ps *parameterStatus, s []byte, typ oid.Oid, f format) (any, error) { switch f { case formatBinary: - return binaryDecode(parameterStatus, s, typ) + return binaryDecode(s, typ) case formatText: - return textDecode(parameterStatus, s, typ) + return textDecode(ps, s, typ) default: - panic("not reached") + panic("unreachable") } } -func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func binaryDecode(s []byte, typ oid.Oid) (any, error) { switch typ { case oid.T_bytea: - return s + return s, nil case oid.T_int8: - return int64(binary.BigEndian.Uint64(s)) + return int64(binary.BigEndian.Uint64(s)), nil case oid.T_int4: - return int64(int32(binary.BigEndian.Uint32(s))) + return int64(int32(binary.BigEndian.Uint32(s))), nil case oid.T_int2: - return int64(int16(binary.BigEndian.Uint16(s))) + return int64(int16(binary.BigEndian.Uint16(s))), nil case oid.T_uuid: b, err := decodeUUIDBinary(s) if err != nil { - panic(err) + err = errors.New("pq: " + err.Error()) } - return b - 
+ return b, err default: - errorf("don't know how to decode binary parameter of type %d", uint32(typ)) + return nil, fmt.Errorf("pq: don't know how to decode binary parameter of type %d", uint32(typ)) + } + +} + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) } - panic("not reached") + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil } -func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { +func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { switch typ { - case oid.T_char, oid.T_varchar, oid.T_text: - return string(s) + case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text: + return string(s), nil case oid.T_bytea: b, err := parseBytea(s) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return b + return b, err case oid.T_timestamptz: - return parseTs(parameterStatus.currentLocation, string(s)) + return parseTS(ps.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: - return parseTs(nil, string(s)) + return parseTS(nil, string(s)) case oid.T_time: - return mustParse("15:04:05", typ, s) + return parseTime("15:04:05", typ, s) case oid.T_timetz: - return mustParse("15:04:05-07", typ, s) + return parseTime("15:04:05-07", typ, s) case oid.T_bool: - return s[0] == 't' + return s[0] == 't', nil case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return i + return i, err case oid.T_float4, oid.T_float8: // We always use 64 bit parsing, regardless of 
whether the input text is for // a float4 or float8, because clients expect float64s for all float datatypes // and returning a 32-bit parsed float64 produces lossy results. f, err := strconv.ParseFloat(string(s), 64) if err != nil { - errorf("%s", err) + err = errors.New("pq: " + err.Error()) } - return f + return f, err } - - return s + return s, nil } // appendEncodedText encodes item in text format as required by COPY // and appends to buf -func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { +func appendEncodedText(buf []byte, x any) ([]byte, error) { switch v := x.(type) { case int64: - return strconv.AppendInt(buf, v, 10) + return strconv.AppendInt(buf, v, 10), nil case float64: - return strconv.AppendFloat(buf, v, 'f', -1, 64) + return strconv.AppendFloat(buf, v, 'f', -1, 64), nil case []byte: - encodedBytea := encodeBytea(parameterStatus.serverVersion, v) - return appendEscapedText(buf, string(encodedBytea)) + encodedBytea := encodeBytea(v) + return appendEscapedText(buf, string(encodedBytea)), nil case string: - return appendEscapedText(buf, v) + return appendEscapedText(buf, v), nil case bool: - return strconv.AppendBool(buf, v) + return strconv.AppendBool(buf, v), nil case time.Time: - return append(buf, formatTs(v)...) + return append(buf, formatTS(v)...), nil case nil: - return append(buf, "\\N"...) + return append(buf, "\\N"...), nil default: - errorf("encode: unknown type for %T", v) + return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } - - panic("not reached") } func appendEscapedText(buf []byte, text string) []byte { @@ -197,7 +207,7 @@ func appendEscapedText(buf []byte, text string) []byte { return result } -func mustParse(f string, typ oid.Oid, s []byte) time.Time { +func parseTime(f string, typ oid.Oid, s []byte) (time.Time, error) { str := string(s) // Check for a minute and second offset in the timezone. 
@@ -227,12 +237,12 @@ func mustParse(f string, typ oid.Oid, s []byte) time.Time { } t, err := time.Parse(f, str) if err != nil { - errorf("decode: %s", err) + return time.Time{}, errors.New("pq: " + err.Error()) } if is2400Time { t = t.Add(24 * time.Hour) } - return t + return t, nil } var errInvalidTimestamp = errors.New("invalid timestamp") @@ -303,13 +313,15 @@ func (c *locationCache) getLocation(offset int) *time.Location { return location } -var infinityTsEnabled = false -var infinityTsNegative time.Time -var infinityTsPositive time.Time +var ( + infinityTSEnabled = false + infinityTSNegative time.Time + infinityTSPositive time.Time +) const ( - infinityTsEnabledAlready = "pq: infinity timestamp enabled already" - infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" + infinityTSEnabledAlready = "pq: infinity timestamp enabled already" + infinityTSNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" ) // EnableInfinityTs controls the handling of Postgres' "-infinity" and @@ -333,46 +345,44 @@ const ( // undefined behavior. If EnableInfinityTs is called more than once, it will // panic. 
func EnableInfinityTs(negative time.Time, positive time.Time) { - if infinityTsEnabled { - panic(infinityTsEnabledAlready) + if infinityTSEnabled { + panic(infinityTSEnabledAlready) } if !negative.Before(positive) { - panic(infinityTsNegativeMustBeSmaller) + panic(infinityTSNegativeMustBeSmaller) } - infinityTsEnabled = true - infinityTsNegative = negative - infinityTsPositive = positive + infinityTSEnabled = true + infinityTSNegative = negative + infinityTSPositive = positive } -/* - * Testing might want to toggle infinityTsEnabled - */ -func disableInfinityTs() { - infinityTsEnabled = false +// Testing might want to toggle infinityTSEnabled +func disableInfinityTS() { + infinityTSEnabled = false } // This is a time function specific to the Postgres default DateStyle // setting ("ISO, MDY"), the only one we currently support. This // accounts for the discrepancies between the parsing available with // time.Parse and the Postgres date formatting quirks. -func parseTs(currentLocation *time.Location, str string) interface{} { +func parseTS(currentLocation *time.Location, str string) (any, error) { switch str { case "-infinity": - if infinityTsEnabled { - return infinityTsNegative + if infinityTSEnabled { + return infinityTSNegative, nil } - return []byte(str) + return []byte(str), nil case "infinity": - if infinityTsEnabled { - return infinityTsPositive + if infinityTSEnabled { + return infinityTSPositive, nil } - return []byte(str) + return []byte(str), nil } t, err := ParseTimestamp(currentLocation, str) if err != nil { - panic(err) + err = errors.New("pq: " + err.Error()) } - return t + return t, err } // ParseTimestamp parses Postgres' text format. It returns a time.Time in @@ -488,15 +498,15 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro return t, p.err } -// formatTs formats t into a format postgres understands. 
-func formatTs(t time.Time) []byte { - if infinityTsEnabled { +// formatTS formats t into a format postgres understands. +func formatTS(t time.Time) []byte { + if infinityTSEnabled { // t <= -infinity : ! (t > -infinity) - if !t.After(infinityTsNegative) { + if !t.After(infinityTSNegative) { return []byte("-infinity") } // t >= infinity : ! (!t < infinity) - if !t.Before(infinityTsPositive) { + if !t.Before(infinityTSPositive) { return []byte("infinity") } } @@ -565,7 +575,7 @@ func parseBytea(s []byte) (result []byte, err error) { } r, err := strconv.ParseUint(string(s[1:4]), 8, 8) if err != nil { - return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + return nil, fmt.Errorf("could not parse bytea value: %w", err) } result = append(result, byte(r)) s = s[4:] @@ -586,47 +596,17 @@ func parseBytea(s []byte) (result []byte, err error) { return result, nil } -func encodeBytea(serverVersion int, v []byte) (result []byte) { - if serverVersion >= 90000 { - // Use the hex format if we know that the server supports it - result = make([]byte, 2+hex.EncodedLen(len(v))) - result[0] = '\\' - result[1] = 'x' - hex.Encode(result[2:], v) - } else { - // .. or resort to "escape" - for _, b := range v { - if b == '\\' { - result = append(result, '\\', '\\') - } else if b < 0x20 || b > 0x7e { - result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) - } else { - result = append(result, b) - } - } - } - +func encodeBytea(v []byte) (result []byte) { + result = make([]byte, 2+hex.EncodedLen(len(v))) + result[0] = '\\' + result[1] = 'x' + hex.Encode(result[2:], v) return result } -// NullTime represents a time.Time that may be null. NullTime implements the -// sql.Scanner interface so it can be used as a scan destination, similar to -// sql.NullString. -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. 
-func (nt *NullTime) Scan(value interface{}) error { - nt.Time, nt.Valid = value.(time.Time) - return nil -} - -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} +// NullTime represents a [time.Time] that may be null. +// NullTime implements the [sql.Scanner] interface so +// it can be used as a scan destination, similar to [sql.NullString]. +// +// Deprecated: this is an alias for [sql.NullTime]. +type NullTime = sql.NullTime diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index f67c5a5fa..234d39e2c 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -6,9 +6,12 @@ import ( "io" "net" "runtime" + "strconv" + "strings" + "unicode/utf8" ) -// Error severities +// [pq.Error.Severity] values. const ( Efatal = "FATAL" Epanic = "PANIC" @@ -21,25 +24,99 @@ const ( // Error represents an error communicating with the server. // +// The [Error] method only returns the error message and error code: +// +// pq: invalid input syntax for type json (22P02) +// +// The [ErrorWithDetail] method also includes the error Detail, Hint, and +// location context (if any): +// +// ERROR: invalid input syntax for type json (22P02) +// DETAIL: Token "asd" is invalid. +// CONTEXT: line 5, column 8: +// +// 3 | 'def', +// 4 | 123, +// 5 | 'foo', 'asd'::jsonb +// ^ +// // See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields type Error struct { - Severity string - Code ErrorCode - Message string - Detail string - Hint string - Position string + // [Efatal], [Epanic], [Ewarning], [Enotice], [Edebug], [Einfo], or [Elog]. + // Always present. + Severity string + + // SQLSTATE code. Always present. + Code ErrorCode + + // Primary human-readable error message. This should be accurate but terse + // (typically one line). Always present. 
+ Message string + + // Optional secondary error message carrying more detail about the problem. + // Might run to multiple lines. + Detail string + + // Optional suggestion what to do about the problem. This is intended to + // differ from Detail in that it offers advice (potentially inappropriate) + // rather than hard facts. Might run to multiple lines. + Hint string + + // error position as an index into the original query string, as decimal + // ASCII integer. The first character has index 1, and positions are + // measured in characters not bytes. + Position string + + // This is defined the same as the Position field, but it is used when the + // cursor position refers to an internally generated command rather than the + // one submitted by the client. The InternalQuery field will always appear + // when this field appears. InternalPosition string - InternalQuery string - Where string - Schema string - Table string - Column string - DataTypeName string - Constraint string - File string - Line string - Routine string + + // Text of a failed internally-generated command. This could be, for + // example, an SQL query issued by a PL/pgSQL function. + InternalQuery string + + // An indication of the context in which the error occurred. Presently this + // includes a call stack traceback of active procedural language functions + // and internally-generated queries. The trace is one entry per line, most + // recent first. + Where string + + // If the error was associated with a specific database object, the name of + // the schema containing that object, if any. + Schema string + + // If the error was associated with a specific table, the name of the table. + // (Refer to the schema name field for the name of the table's schema.) + Table string + + // If the error was associated with a specific table column, the name of the + // column. (Refer to the schema and table name fields to identify the + // table.) 
+ Column string + + // If the error was associated with a specific data type, the name of the + // data type. (Refer to the schema name field for the name of the data + // type's schema.) + DataTypeName string + + // If the error was associated with a specific constraint, the name of the + // constraint. Refer to fields listed above for the associated table or + // domain. (For this purpose, indexes are treated as constraints, even if + // they weren't created with constraint syntax.) + Constraint string + + // File name of the source-code location where the error was reported. + File string + + // Line number of the source-code location where the error was reported. + Line string + + // Name of the source-code routine reporting the error. + Routine string + + query string } // ErrorCode is a five-character error code. @@ -353,8 +430,8 @@ var errorCodeNames = map[ErrorCode]string{ "XX002": "index_corrupted", } -func parseError(r *readBuf) *Error { - err := new(Error) +func parseError(r *readBuf, q string) *Error { + err := &Error{query: q} for t := r.byte(); t != 0; t = r.byte() { msg := r.string() switch t { @@ -398,126 +475,163 @@ func parseError(r *readBuf) *Error { } // Fatal returns true if the Error Severity is fatal. -func (err *Error) Fatal() bool { - return err.Severity == Efatal +func (e *Error) Fatal() bool { + return e.Severity == Efatal } // SQLState returns the SQLState of the error. -func (err *Error) SQLState() string { - return string(err.Code) +func (e *Error) SQLState() string { + return string(e.Code) } -// Get implements the legacy PGError interface. New code should use the fields -// of the Error struct directly. 
-func (err *Error) Get(k byte) (v string) { - switch k { - case 'S': - return err.Severity - case 'C': - return string(err.Code) - case 'M': - return err.Message - case 'D': - return err.Detail - case 'H': - return err.Hint - case 'P': - return err.Position - case 'p': - return err.InternalPosition - case 'q': - return err.InternalQuery - case 'W': - return err.Where - case 's': - return err.Schema - case 't': - return err.Table - case 'c': - return err.Column - case 'd': - return err.DataTypeName - case 'n': - return err.Constraint - case 'F': - return err.File - case 'L': - return err.Line - case 'R': - return err.Routine +func (e *Error) Error() string { + msg := e.Message + if e.query != "" && e.Position != "" { + pos, err := strconv.Atoi(e.Position) + if err == nil { + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + if len(lines) == 1 { + msg += " at column " + strconv.Itoa(col) + } else { + msg += " at position " + strconv.Itoa(line) + ":" + strconv.Itoa(col) + } + } } - return "" -} -func (err *Error) Error() string { - return "pq: " + err.Message + if e.Code != "" { + return "pq: " + msg + " (" + string(e.Code) + ")" + } + return "pq: " + msg } -// PGError is an interface used by previous versions of pq. It is provided -// only to support legacy code. New code should use the Error type. -type PGError interface { - Error() string - Fatal() bool - Get(k byte) (v string) -} +// ErrorWithDetail returns the error message with detailed information and +// location context (if any). +// +// See the documentation on [Error]. 
+func (e *Error) ErrorWithDetail() string { + b := new(strings.Builder) + b.Grow(len(e.Message) + len(e.Detail) + len(e.Hint) + 30) + b.WriteString("ERROR: ") + b.WriteString(e.Message) + if e.Code != "" { + b.WriteString(" (") + b.WriteString(string(e.Code)) + b.WriteByte(')') + } + if e.Detail != "" { + b.WriteString("\nDETAIL: ") + b.WriteString(e.Detail) + } + if e.Hint != "" { + b.WriteString("\nHINT: ") + b.WriteString(e.Hint) + } -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} + if e.query != "" && e.Position != "" { + b.Grow(512) + pos, err := strconv.Atoi(e.Position) + if err != nil { + return b.String() + } + lines := strings.Split(e.query, "\n") + line, col := posToLine(pos, lines) + + fmt.Fprintf(b, "\nCONTEXT: line %d, column %d:\n\n", line, col) + if line > 2 { + fmt.Fprintf(b, "% 7d | %s\n", line-2, expandTab(lines[line-3])) + } + if line > 1 { + fmt.Fprintf(b, "% 7d | %s\n", line-1, expandTab(lines[line-2])) + } + /// Expand tabs, so that the ^ is at at the correct position, but leave + /// "column 10-13" intact. Adjusting this to the visual column would be + /// better, but we don't know the tabsize of the user in their editor, + /// which can be 8, 4, 2, or something else. We can't know. So leaving + /// it as the character index is probably the "most correct". + expanded := expandTab(lines[line-1]) + diff := len(expanded) - len(lines[line-1]) + fmt.Fprintf(b, "% 7d | %s\n", line, expanded) + fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col-1+diff), "^") + } -// TODO(ainar-g) Rename to errorf after removing panics. 
-func fmterrorf(s string, args ...interface{}) error { - return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) + return b.String() } -func errRecoverNoErrBadConn(err *error) { - e := recover() - if e == nil { - // Do nothing - return +func posToLine(pos int, lines []string) (line, col int) { + read := 0 + for i := range lines { + line++ + ll := utf8.RuneCountInString(lines[i]) + 1 // +1 for the removed newline + if read+ll >= pos { + col = pos - read + if col < 1 { // Should never happen, but just in case. + col = 1 + } + break + } + read += ll } - var ok bool - *err, ok = e.(error) - if !ok { - *err = fmt.Errorf("pq: unexpected error: %#v", e) + return line, col +} + +func expandTab(s string) string { + var ( + b strings.Builder + l int + fill = func(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = ' ' + } + return string(b) + } + ) + b.Grow(len(s)) + for _, r := range s { + switch r { + case '\t': + tw := 8 - l%8 + b.WriteString(fill(tw)) + l += tw + default: + b.WriteRune(r) + l += 1 + } } + return b.String() } -func (cn *conn) errRecover(err *error) { - e := recover() - switch v := e.(type) { +func (cn *conn) handleError(reported error, query ...string) error { + switch err := reported.(type) { case nil: - // Do nothing - case runtime.Error: - cn.err.set(driver.ErrBadConn) - panic(v) - case *Error: - if v.Fatal() { - *err = driver.ErrBadConn - } else { - *err = v - } - case *net.OpError: + return nil + case runtime.Error, *net.OpError: cn.err.set(driver.ErrBadConn) - *err = v case *safeRetryError: cn.err.set(driver.ErrBadConn) - *err = driver.ErrBadConn + reported = driver.ErrBadConn + case *Error: + if len(query) > 0 && query[0] != "" { + err.query = query[0] + reported = err + } + if err.Fatal() { + reported = driver.ErrBadConn + } case error: - if v == io.EOF || v.Error() == "remote error: handshake failure" { - *err = driver.ErrBadConn - } else { - *err = v + if err == io.EOF || err.Error() == "remote error: handshake failure" { + 
reported = driver.ErrBadConn } - default: cn.err.set(driver.ErrBadConn) - panic(fmt.Sprintf("unknown error: %#v", e)) + reported = fmt.Errorf("pq: unknown error %T: %[1]s", err) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. - if *err == driver.ErrBadConn { + if reported == driver.ErrBadConn { cn.err.set(driver.ErrBadConn) } + return reported } diff --git a/vendor/github.com/lib/pq/internal/pgpass/pgpass.go b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go new file mode 100644 index 000000000..002631da7 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pgpass/pgpass.go @@ -0,0 +1,71 @@ +package pgpass + +import ( + "bufio" + "os" + "path/filepath" + "strings" + + "github.com/lib/pq/internal/pqutil" +) + +func PasswordFromPgpass(passfile, user, password, host, port, dbname string, passwordSet bool) string { + // Do not process .pgpass if a password was supplied. + if passwordSet { + return password + } + + filename := pqutil.Pgpass(passfile) + if filename == "" { + return "" + } + + fp, err := os.Open(filename) + if err != nil { + return "" + } + defer fp.Close() + + scan := bufio.NewScanner(fp) + for scan.Scan() { + line := scan.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := splitFields(line) + if len(split) != 5 { + continue + } + + socket := host == "" || filepath.IsAbs(host) || strings.HasPrefix(host, "@") + if (split[0] == "*" || split[0] == host || (split[0] == "localhost" && socket)) && + (split[1] == "*" || split[1] == port) && + (split[2] == "*" || split[2] == dbname) && + (split[3] == "*" || split[3] == user) { + return split[4] + } + } + + return "" +} + +func splitFields(s string) []string { + var ( + fs = make([]string, 0, 5) + f = make([]rune, 0, len(s)) + esc bool + ) + for _, c := range s { + switch { + case esc: + f, esc = append(f, c), false + case c == '\\': + esc = true + case c == ':': + fs, f = append(fs, string(f)), f[:0] + default: + f = 
append(f, c) + } + } + return append(fs, string(f)) +} diff --git a/vendor/github.com/lib/pq/internal/pqsql/copy.go b/vendor/github.com/lib/pq/internal/pqsql/copy.go new file mode 100644 index 000000000..ccb688f63 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqsql/copy.go @@ -0,0 +1,37 @@ +package pqsql + +// StartsWithCopy reports if the SQL strings start with "copy", ignoring +// whitespace, comments, and casing. +func StartsWithCopy(query string) bool { + if len(query) < 4 { + return false + } + var linecmt, blockcmt bool + for i := 0; i < len(query); i++ { + c := query[i] + if linecmt { + linecmt = c != '\n' + continue + } + if blockcmt { + blockcmt = !(c == '/' && query[i-1] == '*') + continue + } + if c == '-' && len(query) > i+1 && query[i+1] == '-' { + linecmt = true + continue + } + if c == '/' && len(query) > i+1 && query[i+1] == '*' { + blockcmt = true + continue + } + if c == ' ' || c == '\t' || c == '\r' || c == '\n' { + continue + } + + // First non-comment and non-whitespace. + return len(query) > i+3 && c|0x20 == 'c' && query[i+1]|0x20 == 'o' && + query[i+2]|0x20 == 'p' && query[i+3]|0x20 == 'y' + } + return false +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/path.go b/vendor/github.com/lib/pq/internal/pqutil/path.go new file mode 100644 index 000000000..e6827a96f --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/path.go @@ -0,0 +1,65 @@ +package pqutil + +import ( + "fmt" + "os" + "os/user" + "path/filepath" + "runtime" +) + +// Home gets the user's home directory. Matches pqGetHomeDirectory() from +// PostgreSQL +// +// https://github.com/postgres/postgres/blob/2b117bb/src/interfaces/libpq/fe-connect.c#L8214 +func Home() string { + if runtime.GOOS == "windows" { + // pq uses SHGetFolderPath(), which is deprecated but x/sys/windows has + // KnownFolderPath(). We don't really want to pull that in though, so + // use APPDATA env. 
This is also what PostgreSQL uses in some other + // codepaths (get_home_path() for example). + ad := os.Getenv("APPDATA") + if ad == "" { + return "" + } + return filepath.Join(ad, "postgresql") + } + + home, _ := os.UserHomeDir() + if home == "" { + u, err := user.Current() + if err != nil { + return "" + } + home = u.HomeDir + } + return home +} + +// Pgpass gets the filepath to the pgpass file to use, returning "" if a pgpass +// file shouldn't be used. +func Pgpass(passfile string) string { + // Get passfile from the options. + if passfile == "" { + home := Home() + if home == "" { + return "" + } + passfile = filepath.Join(home, ".pgpass") + } + + // On Win32, the directory is protected, so we don't have to check the file. + if runtime.GOOS != "windows" { + fi, err := os.Stat(passfile) + if err != nil { + return "" + } + if fi.Mode().Perm()&(0x77) != 0 { + fmt.Fprintf(os.Stderr, + "WARNING: password file %q has group or world access; permissions should be u=rw (0600) or less\n", + passfile) + return "" + } + } + return passfile +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm.go b/vendor/github.com/lib/pq/internal/pqutil/perm.go new file mode 100644 index 000000000..fdfa94a07 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm.go @@ -0,0 +1,64 @@ +//go:build !windows && !plan9 + +package pqutil + +import ( + "errors" + "os" + "syscall" +) + +var ( + ErrSSLKeyUnknownOwnership = errors.New("pq: could not get owner information for private key, may not be properly protected") + ErrSSLKeyHasWorldPermissions = errors.New("pq: private key has world access; permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") +) + +// SSLKeyPermissions checks the permissions on user-supplied SSL key files, +// which should have very little access. libpq does not check key file +// permissions on Windows. +// +// If the file is owned by the same user the process is running as, the file +// should only have 0600. 
If the file is owned by root, and the group matches +// the group that the process is running in, the permissions cannot be more than +// 0640. The file should never have world permissions. +// +// Returns an error when the permission check fails. +func SSLKeyPermissions(sslkey string) error { + fi, err := os.Stat(sslkey) + if err != nil { + return err + } + + return checkPermissions(fi) +} + +func checkPermissions(fi os.FileInfo) error { + // The maximum permissions that a private key file owned by a regular user + // is allowed to have. This translates to u=rw. Regardless of if we're + // running as root or not, 0600 is acceptable, so we return if we match the + // regular user permission mask. + if fi.Mode().Perm()&os.FileMode(0777)^0600 == 0 { + return nil + } + + // We need to pull the Unix file information to get the file's owner. + // If we can't access it, there's some sort of operating system level error + // and we should fail rather than attempting to use faulty information. + sys, ok := fi.Sys().(*syscall.Stat_t) + if !ok { + return ErrSSLKeyUnknownOwnership + } + + // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what + // Postgres does. + if sys.Uid == 0 { + // The maximum permissions that a private key file owned by root is + // allowed to have. This translates to u=rw,g=r. 
+ if fi.Mode().Perm()&os.FileMode(0777)^0640 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil + } + + return ErrSSLKeyHasWorldPermissions +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go new file mode 100644 index 000000000..3ce759576 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/perm_unsupported.go @@ -0,0 +1,12 @@ +//go:build windows || plan9 + +package pqutil + +import "errors" + +var ( + ErrSSLKeyUnknownOwnership = errors.New("unused") + ErrSSLKeyHasWorldPermissions = errors.New("unused") +) + +func SSLKeyPermissions(sslkey string) error { return nil } diff --git a/vendor/github.com/lib/pq/internal/pqutil/pqutil.go b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go new file mode 100644 index 000000000..ca869e9cc --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/pqutil.go @@ -0,0 +1,32 @@ +package pqutil + +import ( + "strconv" + "strings" +) + +// ParseBool is like strconv.ParseBool, but also accepts "yes"/"no" and +// "on"/"off". 
+func ParseBool(str string) (bool, error) { + switch str { + case "1", "t", "T", "true", "TRUE", "True", "yes", "on": + return true, nil + case "0", "f", "F", "false", "FALSE", "False", "no", "off": + return false, nil + } + return false, &strconv.NumError{Func: "ParseBool", Num: str, Err: strconv.ErrSyntax} +} + +func Join[S ~[]E, E ~string](s S) string { + var b strings.Builder + for i := range s { + if i > 0 { + b.WriteString(", ") + } + if i == len(s)-1 { + b.WriteString("or ") + } + b.WriteString(string(s[i])) + } + return b.String() +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_other.go b/vendor/github.com/lib/pq/internal/pqutil/user_other.go new file mode 100644 index 000000000..09e4f8dff --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_other.go @@ -0,0 +1,9 @@ +//go:build js || android || hurd || zos || wasip1 || appengine + +package pqutil + +import "errors" + +func User() (string, error) { + return "", errors.New("pqutil.User: not supported on current platform") +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_posix.go b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go new file mode 100644 index 000000000..bd0ece6da --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_posix.go @@ -0,0 +1,25 @@ +//go:build !windows && !js && !android && !hurd && !zos && !wasip1 && !appengine + +package pqutil + +import ( + "os" + "os/user" + "runtime" +) + +func User() (string, error) { + env := "USER" + if runtime.GOOS == "plan9" { + env = "user" + } + if n := os.Getenv(env); n != "" { + return n, nil + } + + u, err := user.Current() + if err != nil { + return "", err + } + return u.Username, nil +} diff --git a/vendor/github.com/lib/pq/internal/pqutil/user_windows.go b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go new file mode 100644 index 000000000..960cb8055 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/pqutil/user_windows.go @@ -0,0 +1,28 @@ +//go:build windows && !appengine + 
+package pqutil + +import ( + "path/filepath" + "syscall" +) + +func User() (string, error) { + // Perform Windows user name lookup identically to libpq. + // + // The PostgreSQL code makes use of the legacy Win32 function GetUserName, + // and that function has not been imported into stock Go. GetUserNameEx is + // available though, the difference being that a wider range of names are + // available. To get the output to be the same as GetUserName, only the + // base (or last) component of the result is returned. + var ( + name = make([]uint16, 128) + pwnameSz = uint32(len(name)) - 1 + ) + err := syscall.GetUserNameEx(syscall.NameSamCompatible, &name[0], &pwnameSz) + if err != nil { + return "", err + } + s := syscall.UTF16ToString(name) + return filepath.Base(s), nil +} diff --git a/vendor/github.com/lib/pq/internal/proto/proto.go b/vendor/github.com/lib/pq/internal/proto/proto.go new file mode 100644 index 000000000..318d180a6 --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/proto.go @@ -0,0 +1,186 @@ +// From src/include/libpq/protocol.h and src/include/libpq/pqcomm.h – PostgreSQL 18.1 + +package proto + +import ( + "fmt" + "strconv" +) + +// Constants from pqcomm.h +const ( + ProtocolVersion30 = (3 << 16) | 0 //lint:ignore SA4016 x + ProtocolVersion32 = (3 << 16) | 2 // PostgreSQL ≥18; not yet supported. + CancelRequestCode = (1234 << 16) | 5678 + NegotiateSSLCode = (1234 << 16) | 5679 + NegotiateGSSCode = (1234 << 16) | 5680 +) + +// Constants from fe-connect.c +const ( + MaxErrlen = 30_000 // https://github.com/postgres/postgres/blob/c6a10a89f/src/interfaces/libpq/fe-connect.c#L4067 +) + +// RequestCode is a request codes sent by the frontend. +type RequestCode byte + +// These are the request codes sent by the frontend. 
+const ( + Bind = RequestCode('B') + Close = RequestCode('C') + Describe = RequestCode('D') + Execute = RequestCode('E') + FunctionCall = RequestCode('F') + Flush = RequestCode('H') + Parse = RequestCode('P') + Query = RequestCode('Q') + Sync = RequestCode('S') + Terminate = RequestCode('X') + CopyFail = RequestCode('f') + GSSResponse = RequestCode('p') + PasswordMessage = RequestCode('p') + SASLInitialResponse = RequestCode('p') + SASLResponse = RequestCode('p') + CopyDoneRequest = RequestCode('c') + CopyDataRequest = RequestCode('d') +) + +func (r RequestCode) String() string { + s, ok := map[RequestCode]string{ + Bind: "Bind", + Close: "Close", + Describe: "Describe", + Execute: "Execute", + FunctionCall: "FunctionCall", + Flush: "Flush", + Parse: "Parse", + Query: "Query", + Sync: "Sync", + Terminate: "Terminate", + CopyFail: "CopyFail", + // These are all the same :-/ + //GSSResponse: "GSSResponse", + PasswordMessage: "PasswordMessage", + //SASLInitialResponse: "SASLInitialResponse", + //SASLResponse: "SASLResponse", + CopyDoneRequest: "CopyDone", + CopyDataRequest: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", string(r)) + } + return "(" + c + ") " + s +} + +// ResponseCode is a response codes sent by the backend. +type ResponseCode byte + +// These are the response codes sent by the backend. 
+const ( + ParseComplete = ResponseCode('1') + BindComplete = ResponseCode('2') + CloseComplete = ResponseCode('3') + NotificationResponse = ResponseCode('A') + CommandComplete = ResponseCode('C') + DataRow = ResponseCode('D') + ErrorResponse = ResponseCode('E') + CopyInResponse = ResponseCode('G') + CopyOutResponse = ResponseCode('H') + EmptyQueryResponse = ResponseCode('I') + BackendKeyData = ResponseCode('K') + NoticeResponse = ResponseCode('N') + AuthenticationRequest = ResponseCode('R') + ParameterStatus = ResponseCode('S') + RowDescription = ResponseCode('T') + FunctionCallResponse = ResponseCode('V') + CopyBothResponse = ResponseCode('W') + ReadyForQuery = ResponseCode('Z') + NoData = ResponseCode('n') + PortalSuspended = ResponseCode('s') + ParameterDescription = ResponseCode('t') + NegotiateProtocolVersion = ResponseCode('v') + CopyDoneResponse = ResponseCode('c') + CopyDataResponse = ResponseCode('d') +) + +func (r ResponseCode) String() string { + s, ok := map[ResponseCode]string{ + ParseComplete: "ParseComplete", + BindComplete: "BindComplete", + CloseComplete: "CloseComplete", + NotificationResponse: "NotificationResponse", + CommandComplete: "CommandComplete", + DataRow: "DataRow", + ErrorResponse: "ErrorResponse", + CopyInResponse: "CopyInResponse", + CopyOutResponse: "CopyOutResponse", + EmptyQueryResponse: "EmptyQueryResponse", + BackendKeyData: "BackendKeyData", + NoticeResponse: "NoticeResponse", + AuthenticationRequest: "AuthRequest", + ParameterStatus: "ParamStatus", + RowDescription: "RowDescription", + FunctionCallResponse: "FunctionCallResponse", + CopyBothResponse: "CopyBothResponse", + ReadyForQuery: "ReadyForQuery", + NoData: "NoData", + PortalSuspended: "PortalSuspended", + ParameterDescription: "ParamDescription", + NegotiateProtocolVersion: "NegotiateProtocolVersion", + CopyDoneResponse: "CopyDone", + CopyDataResponse: "CopyData", + }[r] + if !ok { + s = "" + } + c := string(r) + if r <= 0x1f || r == 0x7f { + c = fmt.Sprintf("0x%x", 
string(r)) + } + return "(" + c + ") " + s +} + +// AuthCode are authentication request codes sent by the backend. +type AuthCode int32 + +// These are the authentication request codes sent by the backend. +const ( + AuthReqOk = AuthCode(0) // User is authenticated + AuthReqKrb4 = AuthCode(1) // Kerberos V4. Not supported any more. + AuthReqKrb5 = AuthCode(2) // Kerberos V5. Not supported any more. + AuthReqPassword = AuthCode(3) // Password + AuthReqCrypt = AuthCode(4) // crypt password. Not supported any more. + AuthReqMD5 = AuthCode(5) // md5 password + _ = AuthCode(6) // 6 is available. It was used for SCM creds, not supported any more. + AuthReqGSS = AuthCode(7) // GSSAPI without wrap() + AuthReqGSSCont = AuthCode(8) // Continue GSS exchanges + AuthReqSSPI = AuthCode(9) // SSPI negotiate without wrap() + AuthReqSASL = AuthCode(10) // Begin SASL authentication + AuthReqSASLCont = AuthCode(11) // Continue SASL authentication + AuthReqSASLFin = AuthCode(12) // Final SASL message +) + +func (a AuthCode) String() string { + s, ok := map[AuthCode]string{ + AuthReqOk: "ok", + AuthReqKrb4: "krb4", + AuthReqKrb5: "krb5", + AuthReqPassword: "password", + AuthReqCrypt: "crypt", + AuthReqMD5: "md5", + AuthReqGSS: "GDD", + AuthReqGSSCont: "GSSCont", + AuthReqSSPI: "SSPI", + AuthReqSASL: "SASL", + AuthReqSASLCont: "SASLCont", + AuthReqSASLFin: "SASLFin", + }[a] + if !ok { + s = "" + } + return s + " (" + strconv.Itoa(int(a)) + ")" +} diff --git a/vendor/github.com/lib/pq/internal/proto/sz_32.go b/vendor/github.com/lib/pq/internal/proto/sz_32.go new file mode 100644 index 000000000..68065591b --- /dev/null +++ b/vendor/github.com/lib/pq/internal/proto/sz_32.go @@ -0,0 +1,7 @@ +//go:build 386 || arm || mips || mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxInt diff --git a/vendor/github.com/lib/pq/internal/proto/sz_64.go b/vendor/github.com/lib/pq/internal/proto/sz_64.go new file mode 100644 index 000000000..2b8ad8975 --- /dev/null +++ 
b/vendor/github.com/lib/pq/internal/proto/sz_64.go @@ -0,0 +1,7 @@ +//go:build !386 && !arm && !mips && !mipsle + +package proto + +import "math" + +const MaxUint32 = math.MaxUint32 diff --git a/vendor/github.com/lib/pq/notice.go b/vendor/github.com/lib/pq/notice.go index 70ad122a7..61b90f81e 100644 --- a/vendor/github.com/lib/pq/notice.go +++ b/vendor/github.com/lib/pq/notice.go @@ -1,6 +1,3 @@ -//go:build go1.10 -// +build go1.10 - package pq import ( diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index 5c421fdb8..920f0486d 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -1,16 +1,16 @@ package pq -// Package pq is a pure Go Postgres driver for the database/sql package. -// This module contains support for Postgres LISTEN/NOTIFY. - import ( "context" "database/sql/driver" "errors" "fmt" + "net" "sync" "sync/atomic" "time" + + "github.com/lib/pq/internal/proto" ) // Notification represents a single notification from the database. @@ -93,7 +93,7 @@ const ( ) type message struct { - typ byte + typ proto.ResponseCode err error } @@ -102,19 +102,13 @@ var errListenerConnClosed = errors.New("pq: ListenerConn has been closed") // ListenerConn is a low-level interface for waiting for notifications. You // should use Listener instead. type ListenerConn struct { - // guards cn and err - connectionLock sync.Mutex - cn *conn - err error - - connState int32 - - // the sending goroutine will be holding this lock - senderLock sync.Mutex - + connectionLock sync.Mutex // guards cn and err + senderLock sync.Mutex // the sending goroutine will be holding this lock + cn *conn + err error + connState int32 notificationChan chan<- *Notification - - replyChan chan message + replyChan chan message } // NewListenerConn creates a new ListenerConn. Use NewListener instead. 
@@ -189,8 +183,6 @@ func (l *ListenerConn) setState(newState int32) bool { // away or should be discarded because we couldn't agree on the state with the // server backend. func (l *ListenerConn) listenerConnLoop() (err error) { - defer errRecoverNoErrBadConn(&err) - r := &readBuf{} for { t, err := l.cn.recvMessage(r) @@ -199,43 +191,43 @@ func (l *ListenerConn) listenerConnLoop() (err error) { } switch t { - case 'A': + case proto.NotificationResponse: // recvNotification copies all the data so we don't need to worry // about the scratch buffer being overwritten. l.notificationChan <- recvNotification(r) - case 'T', 'D': + case proto.RowDescription, proto.DataRow: // only used by tests; ignore - case 'E': + case proto.ErrorResponse: // We might receive an ErrorResponse even when not in a query; it // is expected that the server will close the connection after // that, but we should make sure that the error we display is the // one from the stray ErrorResponse, not io.ErrUnexpectedEOF. if !l.setState(connStateExpectReadyForQuery) { - return parseError(r) + return parseError(r, "") } - l.replyChan <- message{t, parseError(r)} + l.replyChan <- message{t, parseError(r, "")} - case 'C', 'I': + case proto.CommandComplete, proto.EmptyQueryResponse: if !l.setState(connStateExpectReadyForQuery) { // protocol out of sync return fmt.Errorf("unexpected CommandComplete") } // ExecSimpleQuery doesn't need to know about this message - case 'Z': + case proto.ReadyForQuery: if !l.setState(connStateIdle) { // protocol out of sync return fmt.Errorf("unexpected ReadyForQuery") } l.replyChan <- message{t, nil} - case 'S': + case proto.ParameterStatus: // ignore - case 'N': + case proto.NoticeResponse: if n := l.cn.noticeHandler; n != nil { - n(parseError(r)) + n(parseError(r, "")) } default: return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t) @@ -262,7 +254,7 @@ func (l *ListenerConn) listenerConnMain() { if l.err == nil { l.err = err } - l.cn.Close() + _ = 
l.cn.Close() l.connectionLock.Unlock() // There might be a query in-flight; make sure nobody's waiting for a @@ -290,15 +282,14 @@ func (l *ListenerConn) UnlistenAll() (bool, error) { return l.ExecSimpleQuery("UNLISTEN *") } -// Ping the remote server to make sure it's alive. Non-nil error means the +// Ping the remote server to make sure it's alive. Non-nil error means the // connection has failed and should be abandoned. func (l *ListenerConn) Ping() error { sent, err := l.ExecSimpleQuery("") if !sent { return err } - if err != nil { - // shouldn't happen + if err != nil { // shouldn't happen panic(err) } return nil @@ -309,11 +300,9 @@ func (l *ListenerConn) Ping() error { // The caller must be holding senderLock (see acquireSenderLock and // releaseSenderLock). func (l *ListenerConn) sendSimpleQuery(q string) (err error) { - defer errRecoverNoErrBadConn(&err) - - // must set connection state before sending the query + // Must set connection state before sending the query if !l.setState(connStateExpectResponse) { - panic("two queries running at the same time") + return errors.New("pq: two queries running at the same time") } // Can't use l.cn.writeBuf here because it uses the scratch buffer which @@ -323,18 +312,16 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { pos: 1, } b.string(q) - l.cn.send(b) - - return nil + return l.cn.send(b) } // ExecSimpleQuery executes a "simple query" (i.e. one with no bindable // parameters) on the connection. The possible return values are: -// 1) "executed" is true; the query was executed to completion on the -// database server. If the query failed, err will be set to the error -// returned by the database, otherwise err will be nil. -// 2) If "executed" is false, the query could not be executed on the remote -// server. err will be non-nil. +// 1. "executed" is true; the query was executed to completion on the +// database server. 
If the query failed, err will be set to the error +// returned by the database, otherwise err will be nil. +// 2. If "executed" is false, the query could not be executed on the remote +// server. err will be non-nil. // // After a call to ExecSimpleQuery has returned an executed=false value, the // connection has either been closed or will be closed shortly thereafter, and @@ -356,7 +343,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { l.err = err } l.connectionLock.Unlock() - l.cn.c.Close() + _ = l.cn.c.Close() return false, err } @@ -372,7 +359,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { return false, err } switch m.typ { - case 'Z': + case proto.ReadyForQuery: // sanity check if m.err != nil { panic("m.err != nil") @@ -380,7 +367,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { // done; err might or might not be set return true, err - case 'E': + case proto.ErrorResponse: // sanity check if m.err == nil { panic("m.err == nil") @@ -414,8 +401,6 @@ func (l *ListenerConn) Err() error { return l.err } -var errListenerClosed = errors.New("pq: Listener has been closed") - // ErrChannelAlreadyOpen is returned from Listen when a channel is already // open. var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") @@ -541,12 +526,12 @@ func (l *Listener) NotificationChannel() <-chan *Notification { // connection can not be re-established. // // Listen will only fail in three conditions: -// 1) The channel is already open. The returned error will be -// ErrChannelAlreadyOpen. -// 2) The query was executed on the remote server, but PostgreSQL returned an -// error message in response to the query. The returned error will be a -// pq.Error containing the information the server supplied. -// 3) Close is called on the Listener before the request could be completed. +// 1. The channel is already open. The returned error will be +// ErrChannelAlreadyOpen. 
+// 2. The query was executed on the remote server, but PostgreSQL returned an +// error message in response to the query. The returned error will be a +// pq.Error containing the information the server supplied. +// 3. Close is called on the Listener before the request could be completed. // // The channel name is case-sensitive. func (l *Listener) Listen(channel string) error { @@ -554,7 +539,7 @@ func (l *Listener) Listen(channel string) error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } // The server allows you to issue a LISTEN on a channel which is already @@ -585,7 +570,7 @@ func (l *Listener) Listen(channel string) error { l.reconnectCond.Wait() // we let go of the mutex for a while if l.isClosed { - return errListenerClosed + return net.ErrClosed } } @@ -604,7 +589,7 @@ func (l *Listener) Unlisten(channel string) error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } // Similarly to LISTEN, this is not an error in Postgres, but it seems @@ -638,7 +623,7 @@ func (l *Listener) UnlistenAll() error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn != nil { @@ -663,7 +648,7 @@ func (l *Listener) Ping() error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn == nil { return errors.New("no connection") @@ -689,7 +674,7 @@ func (l *Listener) disconnectCleanup() error { } err := l.cn.Err() - l.cn.Close() + _ = l.cn.Close() l.cn = nil return err } @@ -748,25 +733,28 @@ func (l *Listener) closed() bool { } func (l *Listener) connect() error { + l.lock.Lock() + defer l.lock.Unlock() + if l.isClosed { + return net.ErrClosed + } + notificationChan := make(chan *Notification, 32) - cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) + + var err error + l.cn, err = newDialListenerConn(l.dialer, l.name, notificationChan) if err != nil { return err } - l.lock.Lock() - 
defer l.lock.Unlock() - - err = l.resync(cn, notificationChan) + err = l.resync(l.cn, notificationChan) if err != nil { - cn.Close() + _ = l.cn.Close() return err } - l.cn = cn l.connNotificationChan = notificationChan l.reconnectCond.Broadcast() - return nil } @@ -778,11 +766,11 @@ func (l *Listener) Close() error { defer l.lock.Unlock() if l.isClosed { - return errListenerClosed + return net.ErrClosed } if l.cn != nil { - l.cn.Close() + _ = l.cn.Close() } l.isClosed = true diff --git a/vendor/github.com/lib/pq/oid/doc.go b/vendor/github.com/lib/pq/oid/doc.go index caaede248..a48650663 100644 --- a/vendor/github.com/lib/pq/oid/doc.go +++ b/vendor/github.com/lib/pq/oid/doc.go @@ -1,5 +1,6 @@ -// Package oid contains OID constants -// as defined by the Postgres server. +//go:generate go run ./gen.go + +// Package oid contains OID constants as defined by the Postgres server. package oid // Oid is a Postgres Object ID. diff --git a/vendor/github.com/lib/pq/quote.go b/vendor/github.com/lib/pq/quote.go new file mode 100644 index 000000000..909e41ecb --- /dev/null +++ b/vendor/github.com/lib/pq/quote.go @@ -0,0 +1,71 @@ +package pq + +import ( + "bytes" + "strings" +) + +// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be +// used as part of an SQL statement. For example: +// +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// +// Any double quotes in name will be escaped. The quoted identifier will be case +// sensitive when used in a query. If the input string contains a zero byte, the +// result will be truncated immediately before it. +func QuoteIdentifier(name string) string { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + return `"` + strings.Replace(name, `"`, `""`, -1) + `"` +} + +// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a +// byte buffer. 
+func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { + // TODO(v2): this should have accepted an io.Writer, not *bytes.Buffer. + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + buffer.WriteRune('"') + buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) + buffer.WriteRune('"') +} + +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. 
Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} diff --git a/vendor/github.com/lib/pq/rows.go b/vendor/github.com/lib/pq/rows.go index c6aa5b9a3..2029bfed2 100644 --- a/vendor/github.com/lib/pq/rows.go +++ b/vendor/github.com/lib/pq/rows.go @@ -1,13 +1,182 @@ package pq import ( + "database/sql/driver" + "fmt" + "io" "math" "reflect" "time" + "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" ) +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { return 0, errNoLastInsertID } +func (noRows) RowsAffected() (int64, error) { return 0, errNoRowsAffected } + +type ( + rowsHeader struct { + colNames []string + colTyps []fieldDesc + colFmts []format + } + rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader + } +) + +func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } + // no need to look at cn.bad as Next() will + for { + err := rs.Next(nil) + switch err { + case nil: + case io.EOF: + // rs.Next can return io.EOF on both ReadyForQuery and + // RowDescription (used with HasNextResultSet). We need to fetch + // messages until we hit a ReadyForQuery, which is done by waiting + // for done to be set. 
+ if rs.done { + return nil + } + default: + return err + } + } +} + +func (rs *rows) Columns() []string { + return rs.colNames +} + +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + +func (rs *rows) Next(dest []driver.Value) (resErr error) { + if rs.done { + return io.EOF + } + if err := rs.cn.err.getForNext(); err != nil { + return err + } + + for { + t, err := rs.cn.recv1Buf(&rs.rb) + if err != nil { + return rs.cn.handleError(err) + } + switch t { + case proto.ErrorResponse: + resErr = parseError(&rs.rb, "") + case proto.CommandComplete, proto.EmptyQueryResponse: + if t == proto.CommandComplete { + rs.result, rs.tag, err = rs.cn.parseComplete(rs.rb.string()) + if err != nil { + return rs.cn.handleError(err) + } + } + continue + case proto.ReadyForQuery: + rs.cn.processReadyForQuery(&rs.rb) + rs.done = true + if resErr != nil { + return rs.cn.handleError(resErr) + } + return io.EOF + case proto.DataRow: + n := rs.rb.int16() + if resErr != nil { + rs.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected DataRow after error %s", resErr) + } + if n < len(dest) { + dest = dest[:n] + } + for i := range dest { + l := rs.rb.int32() + if l == -1 { + dest[i] = nil + continue + } + dest[i], err = decode(&rs.cn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) + if err != nil { + return rs.cn.handleError(err) + } + } + return rs.cn.handleError(resErr) + case proto.RowDescription: + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next + return io.EOF + default: + return fmt.Errorf("pq: unexpected message after execute: %q", t) + } + } +} + +func (rs *rows) HasNextResultSet() bool { + hasNext := rs.next != nil && !rs.done + return hasNext +} + +func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil + return nil +} + +// ColumnTypeScanType returns the value type 
that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} + const headerSize = 4 type fieldDesc struct { @@ -29,7 +198,11 @@ func (fd fieldDesc) Type() reflect.Type { return reflect.TypeOf(int32(0)) case oid.T_int2: return reflect.TypeOf(int16(0)) - case oid.T_varchar, oid.T_text: + case oid.T_float8: + return reflect.TypeOf(float64(0)) + case oid.T_float4: + return reflect.TypeOf(float32(0)) + case oid.T_varchar, oid.T_text, oid.T_varbit, oid.T_bit: return reflect.TypeOf("") case oid.T_bool: return reflect.TypeOf(false) @@ -38,7 +211,7 @@ func (fd fieldDesc) Type() reflect.Type { case oid.T_bytea: return reflect.TypeOf([]byte(nil)) default: - return reflect.TypeOf(new(interface{})).Elem() + return reflect.TypeOf(new(any)).Elem() } } @@ -52,6 +225,8 @@ func (fd fieldDesc) Length() (length int64, ok bool) { return math.MaxInt64, true case oid.T_varchar, oid.T_bpchar: return int64(fd.Mod - headerSize), true + case oid.T_varbit, oid.T_bit: + return int64(fd.Mod), true default: return 0, false } @@ -68,26 +243,3 @@ func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { return 0, 0, false } } - -// ColumnTypeScanType returns the value type that can be used to 
scan types into. -func (rs *rows) ColumnTypeScanType(index int) reflect.Type { - return rs.colTyps[index].Type() -} - -// ColumnTypeDatabaseTypeName return the database system type name. -func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { - return rs.colTyps[index].Name() -} - -// ColumnTypeLength returns the length of the column type if the column is a -// variable length type. If the column is not a variable length type ok -// should return false. -func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { - return rs.colTyps[index].Length() -} - -// ColumnTypePrecisionScale should return the precision and scale for decimal -// types. If not applicable, ok should be false. -func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - return rs.colTyps[index].PrecisionScale() -} diff --git a/vendor/github.com/lib/pq/scram/scram.go b/vendor/github.com/lib/pq/scram/scram.go index 477216b60..7ed7a9939 100644 --- a/vendor/github.com/lib/pq/scram/scram.go +++ b/vendor/github.com/lib/pq/scram/scram.go @@ -25,7 +25,6 @@ // Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. 
// // http://tools.ietf.org/html/rfc5802 -// package scram import ( @@ -43,17 +42,16 @@ import ( // // A Client may be used within a SASL conversation with logic resembling: // -// var in []byte -// var client = scram.NewClient(sha1.New, user, pass) -// for client.Step(in) { -// out := client.Out() -// // send out to server -// in := serverOut -// } -// if client.Err() != nil { -// // auth failed -// } -// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } type Client struct { newHash func() hash.Hash @@ -73,8 +71,7 @@ type Client struct { // // For SCRAM-SHA-256, for example, use: // -// client := scram.NewClient(sha256.New, user, pass) -// +// client := scram.NewClient(sha256.New, user, pass) func NewClient(newHash func() hash.Hash, user, pass string) *Client { c := &Client{ newHash: newHash, @@ -133,7 +130,7 @@ func (c *Client) step1(in []byte) error { const nonceLen = 16 buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) if _, err := rand.Read(buf[:nonceLen]); err != nil { - return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %w", err) } c.clientNonce = buf[nonceLen:] b64.Encode(c.clientNonce, buf[:nonceLen]) diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go index 36b61ba45..3aea110eb 100644 --- a/vendor/github.com/lib/pq/ssl.go +++ b/vendor/github.com/lib/pq/ssl.go @@ -3,22 +3,71 @@ package pq import ( "crypto/tls" "crypto/x509" - "io/ioutil" + "errors" + "fmt" "net" "os" - "os/user" "path/filepath" + "runtime" "strings" + "sync" + "syscall" + + "github.com/lib/pq/internal/pqutil" +) + +// Registry for custom tls.Configs +var ( + tlsConfs = make(map[string]*tls.Config) + tlsConfsMu sync.RWMutex ) +// RegisterTLSConfig registers a 
custom [tls.Config]. They are used by using +// sslmode=pqgo-«key» in the connection string. +// +// Set the config to nil to remove a configuration. +func RegisterTLSConfig(key string, config *tls.Config) error { + key = strings.TrimPrefix(key, "pqgo-") + if config == nil { + tlsConfsMu.Lock() + delete(tlsConfs, key) + tlsConfsMu.Unlock() + return nil + } + + tlsConfsMu.Lock() + tlsConfs[key] = config + tlsConfsMu.Unlock() + return nil +} + +func hasTLSConfig(key string) bool { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + _, ok := tlsConfs[key] + return ok +} + +func getTLSConfigClone(key string) *tls.Config { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + if v, ok := tlsConfs[key]; ok { + return v.Clone() + } + return nil +} + // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. -func ssl(o values) (func(net.Conn) (net.Conn, error), error) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o["sslmode"]; mode { +func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) { + var ( + verifyCaOnly = false + tlsConf = &tls.Config{} + mode = cfg.SSLMode + ) + switch { // "require" is the default. - case "", "require": + case mode == "" || mode == SSLModeRequire: // We must skip TLS's own verification since it requires full // verification since Go 1.3. tlsConf.InsecureSkipVerify = true @@ -31,41 +80,46 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // server certificate is validated against the CA. Relying on this // behavior is discouraged, and applications that need certificate // validation should always use verify-ca or verify-full. 
- if sslrootcert, ok := o["sslrootcert"]; ok { - if _, err := os.Stat(sslrootcert); err == nil { + if cfg.SSLRootCert != "" { + if _, err := os.Stat(cfg.SSLRootCert); err == nil { verifyCaOnly = true } else { - delete(o, "sslrootcert") + cfg.SSLRootCert = "" } } - case "verify-ca": + case mode == SSLModeVerifyCA: // We must skip TLS's own verification since it requires full // verification since Go 1.3. tlsConf.InsecureSkipVerify = true verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o["host"] - case "disable": + case mode == SSLModeVerifyFull: + tlsConf.ServerName = cfg.Host + case mode == SSLModeDisable: return nil, nil + case strings.HasPrefix(string(mode), "pqgo-"): + tlsConf = getTLSConfigClone(string(mode[5:])) + if tlsConf == nil { + return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode) + } default: - return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + return nil, fmt.Errorf( + `pq: unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, + mode) } // Set Server Name Indication (SNI), if enabled by connection parameters. - // By default SNI is on, any value which is not starting with "1" disables - // SNI -- that is the same check vanilla libpq uses. - if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { + if cfg.SSLSNI { // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so // just always set ServerName here and let crypto/tls do the filtering. 
- tlsConf.ServerName = o["host"] + tlsConf.ServerName = cfg.Host } - err := sslClientCertificates(&tlsConf, o) + err := sslClientCertificates(tlsConf, cfg) if err != nil { return nil, err } - err = sslCertificateAuthority(&tlsConf, o) + err = sslCertificateAuthority(tlsConf, cfg) if err != nil { return nil, err } @@ -78,25 +132,34 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { - client := tls.Client(conn, &tlsConf) + client := tls.Client(conn, tlsConf) if verifyCaOnly { - err := sslVerifyCertificateAuthority(client, &tlsConf) + err := client.Handshake() if err != nil { - return nil, err + return client, err + } + var ( + certs = client.ConnectionState().PeerCertificates + opts = x509.VerifyOptions{Intermediates: x509.NewCertPool(), Roots: tlsConf.RootCAs} + ) + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) } + _, err = certs[0].Verify(opts) + return client, err } return client, nil }, nil } // sslClientCertificates adds the certificate specified in the "sslcert" and +// // "sslkey" settings, or if they aren't set, from the .postgresql directory // in the user's home directory. The configured files must exist and have // the correct permissions. -func sslClientCertificates(tlsConf *tls.Config, o values) error { - sslinline := o["sslinline"] - if sslinline == "true" { - cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) +func sslClientCertificates(tlsConf *tls.Config, cfg Config) error { + if cfg.SSLInline { + cert, err := tls.X509KeyPair([]byte(cfg.SSLCert), []byte(cfg.SSLKey)) if err != nil { return err } @@ -104,39 +167,48 @@ func sslClientCertificates(tlsConf *tls.Config, o values) error { return nil } - // user.Current() might fail when cross-compiling. We have to ignore the - // error and continue without home directory defaults, since we wouldn't - // know from where to load them. 
- user, _ := user.Current() + home := pqutil.Home() // In libpq, the client certificate is only loaded if the setting is not blank. // // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 - sslcert := o["sslcert"] - if len(sslcert) == 0 && user != nil { - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslcert := cfg.SSLCert + if len(sslcert) == 0 && home != "" { + if runtime.GOOS == "windows" { + sslcert = filepath.Join(sslcert, "postgresql.crt") + } else { + sslcert = filepath.Join(home, ".postgresql/postgresql.crt") + } } // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 if len(sslcert) == 0 { return nil } // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 - if _, err := os.Stat(sslcert); os.IsNotExist(err) { - return nil - } else if err != nil { + _, err := os.Stat(sslcert) + if err != nil { + perr := new(os.PathError) + if errors.As(err, &perr) && (perr.Err == syscall.ENOENT || perr.Err == syscall.ENOTDIR) { + return nil + } return err } // In libpq, the ssl key is only loaded if the setting is not blank. // // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 - sslkey := o["sslkey"] - if len(sslkey) == 0 && user != nil { - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + sslkey := cfg.SSLKey + if len(sslkey) == 0 && home != "" { + if runtime.GOOS == "windows" { + sslkey = filepath.Join(home, "postgresql.key") + } else { + sslkey = filepath.Join(home, ".postgresql/postgresql.key") + } } if len(sslkey) > 0 { - if err := sslKeyPermissions(sslkey); err != nil { + err := pqutil.SSLKeyPermissions(sslkey) + if err != nil { return err } } @@ -151,54 +223,28 @@ func sslClientCertificates(tlsConf *tls.Config, o values) error { } // sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. 
-func sslCertificateAuthority(tlsConf *tls.Config, o values) error { +func sslCertificateAuthority(tlsConf *tls.Config, cfg Config) error { // In libpq, the root certificate is only loaded if the setting is not blank. // // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 - if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + if sslrootcert := cfg.SSLRootCert; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - sslinline := o["sslinline"] - var cert []byte - if sslinline == "true" { + if cfg.SSLInline { cert = []byte(sslrootcert) } else { var err error - cert, err = ioutil.ReadFile(sslrootcert) + cert, err = os.ReadFile(sslrootcert) if err != nil { return err } } if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { - return fmterrorf("couldn't parse pem in sslrootcert") + return errors.New("pq: couldn't parse pem in sslrootcert") } } return nil } - -// sslVerifyCertificateAuthority carries out a TLS handshake to the server and -// verifies the presented certificate against the CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. 
-func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { - err := client.Handshake() - if err != nil { - return err - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - return err -} diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go deleted file mode 100644 index d587f102e..000000000 --- a/vendor/github.com/lib/pq/ssl_permissions.go +++ /dev/null @@ -1,93 +0,0 @@ -//go:build !windows -// +build !windows - -package pq - -import ( - "errors" - "os" - "syscall" -) - -const ( - rootUserID = uint32(0) - - // The maximum permissions that a private key file owned by a regular user - // is allowed to have. This translates to u=rw. - maxUserOwnedKeyPermissions os.FileMode = 0600 - - // The maximum permissions that a private key file owned by root is allowed - // to have. This translates to u=rw,g=r. - maxRootOwnedKeyPermissions os.FileMode = 0640 -) - -var ( - errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") - errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") -) - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(sslkey string) error { - info, err := os.Stat(sslkey) - if err != nil { - return err - } - - err = hasCorrectPermissions(info) - - // return ErrSSLKeyHasWorldPermissions for backwards compatability with - // existing code. 
- if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { - err = ErrSSLKeyHasWorldPermissions - } - return err -} - -// hasCorrectPermissions checks the file info (and the unix-specific stat_t -// output) to verify that the permissions on the file are correct. -// -// If the file is owned by the same user the process is running as, -// the file should only have 0600 (u=rw). If the file is owned by root, -// and the group matches the group that the process is running in, the -// permissions cannot be more than 0640 (u=rw,g=r). The file should -// never have world permissions. -// -// Returns an error when the permission check fails. -func hasCorrectPermissions(info os.FileInfo) error { - // if file's permission matches 0600, allow access. - userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) - - // regardless of if we're running as root or not, 0600 is acceptable, - // so we return if we match the regular user permission mask. - if info.Mode().Perm()&userPermissionMask == 0 { - return nil - } - - // We need to pull the Unix file information to get the file's owner. - // If we can't access it, there's some sort of operating system level error - // and we should fail rather than attempting to use faulty information. - sysInfo := info.Sys() - if sysInfo == nil { - return ErrSSLKeyUnknownOwnership - } - - unixStat, ok := sysInfo.(*syscall.Stat_t) - if !ok { - return ErrSSLKeyUnknownOwnership - } - - // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what - // Postgres does. 
- if unixStat.Uid == rootUserID { - rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) - if info.Mode().Perm()&rootPermissionMask != 0 { - return errSSLKeyHasUnacceptableRootPermissions - } - return nil - } - - return errSSLKeyHasUnacceptableUserPermissions -} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go deleted file mode 100644 index 73663c8f1..000000000 --- a/vendor/github.com/lib/pq/ssl_windows.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build windows -// +build windows - -package pq - -// sslKeyPermissions checks the permissions on user-supplied ssl key files. -// The key file should have very little access. -// -// libpq does not check key file permissions on Windows. -func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/staticcheck.conf b/vendor/github.com/lib/pq/staticcheck.conf new file mode 100644 index 000000000..83abe48e5 --- /dev/null +++ b/vendor/github.com/lib/pq/staticcheck.conf @@ -0,0 +1,5 @@ +checks = [ + 'all', + '-ST1000', # "Must have at least one package comment" + '-ST1003', # "func EnableInfinityTs should be EnableInfinityTS" +] diff --git a/vendor/github.com/lib/pq/stmt.go b/vendor/github.com/lib/pq/stmt.go new file mode 100644 index 000000000..ca6ecc896 --- /dev/null +++ b/vendor/github.com/lib/pq/stmt.go @@ -0,0 +1,150 @@ +package pq + +import ( + "context" + "database/sql/driver" + "fmt" + "os" + + "github.com/lib/pq/internal/proto" + "github.com/lib/pq/oid" +) + +type stmt struct { + cn *conn + name string + rowsHeader + colFmtData []byte + paramTyps []oid.Oid + closed bool +} + +func (st *stmt) Close() error { + if st.closed { + return nil + } + if err := st.cn.err.get(); err != nil { + return err + } + + w := st.cn.writeBuf(proto.Close) + w.byte(proto.Sync) + w.string(st.name) + err := st.cn.send(w) + if err != nil { + return st.cn.handleError(err) + } + err = st.cn.send(st.cn.writeBuf(proto.Sync)) + if err != nil { + return 
st.cn.handleError(err) + } + + t, _, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.CloseComplete { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: unexpected close response: %q", t) + } + st.closed = true + + t, r, err := st.cn.recv1() + if err != nil { + return st.cn.handleError(err) + } + if t != proto.ReadyForQuery { + st.cn.err.set(driver.ErrBadConn) + return fmt.Errorf("pq: expected ready for query, but got: %q", t) + } + st.cn.processReadyForQuery(r) + + return nil +} + +func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { + return st.query(toNamedValue(v)) +} + +func (st *stmt) query(v []driver.NamedValue) (*rows, error) { + if err := st.cn.err.get(); err != nil { + return nil, err + } + + err := st.exec(v) + if err != nil { + return nil, st.cn.handleError(err) + } + return &rows{ + cn: st.cn, + rowsHeader: st.rowsHeader, + }, nil +} + +func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { + return st.ExecContext(context.Background(), toNamedValue(v)) +} + +func (st *stmt) exec(v []driver.NamedValue) error { + if debugProto { + fmt.Fprintf(os.Stderr, " START stmt.exec\n") + defer fmt.Fprintf(os.Stderr, " END stmt.exec\n") + } + if len(v) >= 65536 { + return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) + } + if len(v) != len(st.paramTyps) { + return fmt.Errorf("pq: got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) + } + + cn := st.cn + w := cn.writeBuf(proto.Bind) + w.byte(0) // unnamed portal + w.string(st.name) + + if cn.cfg.BinaryParameters { + err := cn.sendBinaryParameters(w, v) + if err != nil { + return err + } + } else { + w.int16(0) + w.int16(len(v)) + for i, x := range v { + if x.Value == nil { + w.int32(-1) + } else { + b, err := encode(x.Value, st.paramTyps[i]) + if err != nil { + return err + } + if b == nil { + w.int32(-1) + } else { + w.int32(len(b)) + w.bytes(b) + } + } + } + } + 
w.bytes(st.colFmtData) + + w.next(proto.Execute) + w.byte(0) + w.int32(0) + + w.next(proto.Sync) + err := cn.send(w) + if err != nil { + return err + } + err = cn.readBindResponse() + if err != nil { + return err + } + return cn.postExecuteWorkaround() +} + +func (st *stmt) NumInput() int { + return len(st.paramTyps) +} diff --git a/vendor/github.com/lib/pq/url.go b/vendor/github.com/lib/pq/url.go deleted file mode 100644 index aec6e95be..000000000 --- a/vendor/github.com/lib/pq/url.go +++ /dev/null @@ -1,76 +0,0 @@ -package pq - -import ( - "fmt" - "net" - nurl "net/url" - "sort" - "strings" -) - -// ParseURL no longer needs to be used by clients of this library since supplying a URL as a -// connection string to sql.Open() is now supported: -// -// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") -// -// It remains exported here for backwards-compatibility. -// -// ParseURL converts a url to a connection string for driver.Open. -// Example: -// -// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" -// -// converts to: -// -// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" -// -// A minimal example: -// -// "postgres://" -// -// This will be blank, causing driver.Open to use all of the defaults -func ParseURL(url string) (string, error) { - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - var kvs []string - escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - if host, port, err := net.SplitHostPort(u.Host); err != nil { - accrue("host", u.Host) - } else { - accrue("host", host) - accrue("port", port) - 
} - - if u.Path != "" { - accrue("dbname", u.Path[1:]) - } - - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil -} diff --git a/vendor/github.com/lib/pq/user_other.go b/vendor/github.com/lib/pq/user_other.go deleted file mode 100644 index 3dae8f557..000000000 --- a/vendor/github.com/lib/pq/user_other.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build js || android || hurd || zos -// +build js android hurd zos - -package pq - -func userCurrent() (string, error) { - return "", ErrCouldNotDetectUsername -} diff --git a/vendor/github.com/lib/pq/user_posix.go b/vendor/github.com/lib/pq/user_posix.go deleted file mode 100644 index 5f2d439bc..000000000 --- a/vendor/github.com/lib/pq/user_posix.go +++ /dev/null @@ -1,25 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. - -//go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos -// +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos - -package pq - -import ( - "os" - "os/user" -) - -func userCurrent() (string, error) { - u, err := user.Current() - if err == nil { - return u.Username, nil - } - - name := os.Getenv("USER") - if name != "" { - return name, nil - } - - return "", ErrCouldNotDetectUsername -} diff --git a/vendor/github.com/lib/pq/user_windows.go b/vendor/github.com/lib/pq/user_windows.go deleted file mode 100644 index 2b691267b..000000000 --- a/vendor/github.com/lib/pq/user_windows.go +++ /dev/null @@ -1,27 +0,0 @@ -// Package pq is a pure Go Postgres driver for the database/sql package. -package pq - -import ( - "path/filepath" - "syscall" -) - -// Perform Windows user name lookup identically to libpq. 
-// -// The PostgreSQL code makes use of the legacy Win32 function -// GetUserName, and that function has not been imported into stock Go. -// GetUserNameEx is available though, the difference being that a -// wider range of names are available. To get the output to be the -// same as GetUserName, only the base (or last) component of the -// result is returned. -func userCurrent() (string, error) { - pw_name := make([]uint16, 128) - pwname_size := uint32(len(pw_name)) - 1 - err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) - if err != nil { - return "", ErrCouldNotDetectUsername - } - s := syscall.UTF16ToString(pw_name) - u := filepath.Base(s) - return u, nil -} diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go deleted file mode 100644 index 9a1b9e074..000000000 --- a/vendor/github.com/lib/pq/uuid.go +++ /dev/null @@ -1,23 +0,0 @@ -package pq - -import ( - "encoding/hex" - "fmt" -) - -// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. 
-func decodeUUIDBinary(src []byte) ([]byte, error) { - if len(src) != 16 { - return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) - } - - dst := make([]byte, 36) - dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' - hex.Encode(dst[0:], src[0:4]) - hex.Encode(dst[9:], src[4:6]) - hex.Encode(dst[14:], src[6:8]) - hex.Encode(dst[19:], src[8:10]) - hex.Encode(dst[24:], src[10:16]) - - return dst, nil -} diff --git a/vendor/modules.txt b/vendor/modules.txt index aa6f0ffa0..5d927c16b 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -90,9 +90,13 @@ github.com/google/uuid # github.com/gorilla/handlers v1.5.2 ## explicit; go 1.20 github.com/gorilla/handlers -# github.com/lib/pq v1.10.9 -## explicit; go 1.13 +# github.com/lib/pq v1.11.1 +## explicit; go 1.21 github.com/lib/pq +github.com/lib/pq/internal/pgpass +github.com/lib/pq/internal/pqsql +github.com/lib/pq/internal/pqutil +github.com/lib/pq/internal/proto github.com/lib/pq/oid github.com/lib/pq/scram # github.com/lomik/og-rek v0.0.0-20170411191824-628eefeb8d80