diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 509d392e..201da6e4 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -1,11 +1,14 @@ package mysql import ( + "crypto/tls" + "crypto/x509" "database/sql" "fmt" "io" "io/ioutil" nurl "net/url" + "strconv" "strings" "github.com/go-sql-driver/mysql" @@ -23,6 +26,7 @@ var ( ErrDatabaseDirty = fmt.Errorf("database is dirty") ErrNilConfig = fmt.Errorf("no config") ErrNoDatabaseName = fmt.Errorf("no database name") + ErrAppendPEM = fmt.Errorf("failed to append PEM") ) type Config struct { @@ -94,6 +98,42 @@ func (m *Mysql) Open(url string) (database.Driver, error) { migrationsTable = DefaultMigrationsTable } + // use custom TLS? + ctls := purl.Query().Get("tls") + if len(ctls) > 0 { + if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { + rootCertPool := x509.NewCertPool() + pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) + if err != nil { + return nil, err + } + + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + return nil, ErrAppendPEM + } + + certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key")) + if err != nil { + return nil, err + } + + insecureSkipVerify := false + if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { + x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) + if err != nil { + return nil, err + } + insecureSkipVerify = x + } + + mysql.RegisterTLSConfig(ctls, &tls.Config{ + RootCAs: rootCertPool, + Certificates: []tls.Certificate{certs}, + InsecureSkipVerify: insecureSkipVerify, + }) + } + } + mx, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, @@ -270,3 +310,18 @@ func (m *Mysql) ensureVersionTable() error { } return nil } + +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +}