diff --git a/persistent/internal/driver/mongo/life_cycle.go b/persistent/internal/driver/mongo/life_cycle.go index c10eda3b..de8aa164 100644 --- a/persistent/internal/driver/mongo/life_cycle.go +++ b/persistent/internal/driver/mongo/life_cycle.go @@ -3,6 +3,9 @@ package mongo import ( "context" "errors" + "fmt" + "net/url" + "strings" "time" "github.com/TykTechnologies/storage/persistent/internal/helper" @@ -12,7 +15,6 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" - "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" "github.com/TykTechnologies/storage/persistent/internal/types" ) @@ -26,17 +28,23 @@ type lifeCycle struct { var _ types.StorageLifecycle = &lifeCycle{} +const ( + MongoPrefix = "mongodb://" + MongoSRVPrefix = "mongodb+srv://" +) + // Connect connects to the mongo database given the ClientOpts. func (lc *lifeCycle) Connect(opts *types.ClientOpts) error { var err error var client *mongo.Client - // we check if the connection string is valid before building the connOpts. - cs, err := connstring.ParseAndValidate(opts.ConnectionString) + url, cs, err := parseURL(opts.ConnectionString) if err != nil { - return errors.New("invalid connection string") + return err } + opts.ConnectionString = url + connOpts, err := mongoOptsBuilder(opts) if err != nil { return errors.New(err.Error()) @@ -50,12 +58,160 @@ func (lc *lifeCycle) Connect(opts *types.ClientOpts) error { } lc.connectionString = opts.ConnectionString - lc.database = cs.Database + lc.database = cs.db lc.client = client return lc.client.Ping(context.Background(), nil) } +type urlInfo struct { + addrs []string + user string + pass string + db string + options []urlOptions +} + +// urlOptions is a key/value pair representing a single option in a URL. +// we need to use this struct instead of a map to avoid flaky tests due to the order of the options +type urlOptions struct { + key string + val string +} + +func isOptSep(c rune) bool { + return c == ';' || c == '&' +} + +func parseURL(s string) (string, *urlInfo, error) { + var info *urlInfo + prefix := "" + + if strings.HasPrefix(s, MongoPrefix) { + prefix = MongoPrefix + } else if strings.HasPrefix(s, MongoSRVPrefix) { + prefix = MongoSRVPrefix + } + + switch prefix { + case MongoPrefix: + s = strings.TrimPrefix(s, MongoPrefix) + case MongoSRVPrefix: + s = strings.TrimPrefix(s, MongoSRVPrefix) + default: + return "", info, errors.New("invalid connection string, no prefix found") + } + + info, err := extractURL(s) + if err != nil { + return "", info, err + } + + var connString string + connString += prefix + + if info.user != "" { + info.user = url.QueryEscape(info.user) + connString += info.user + + if info.pass != "" { + info.pass = url.QueryEscape(info.pass) + connString += ":" + info.pass + } + + connString += "@" + } + + connString += strings.Join(info.addrs, ",") + + connString += "/" + info.db + + if len(info.options) > 0 { + connString += "?" + for _, v := range info.options { + connString += v.key + "=" + v.val + "&" + } + + connString = connString[:len(connString)-1] + } + + return connString, info, nil +} + +func extractURL(s string) (*urlInfo, error) { + info := &urlInfo{options: make([]urlOptions, 0)} + var err error + + if s, err = extractOptions(s, info); err != nil { + return nil, err + } + + if s, err = extractCredentials(s, info); err != nil { + return nil, err + } + + if s, err = extractDatabase(s, info); err != nil { + return nil, err + } + + info.addrs = strings.Split(s, ",") + + return info, nil +} + +func extractOptions(s string, info *urlInfo) (string, error) { + if c := strings.Index(s, "?"); c != -1 { + for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { + l := strings.SplitN(pair, "=", 2) + if len(l) != 2 || l[0] == "" || l[1] == "" { + return s, errors.New("connection option must be key=value: " + pair) + } + + info.options = append(info.options, urlOptions{key: l[0], val: l[1]}) + } + + s = s[:c] + } + + return s, nil +} + +func extractCredentials(s string, info *urlInfo) (string, error) { + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return s, errors.New("credentials must be provided as user:pass@host") + } + + var err error + + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return s, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return s, fmt.Errorf("cannot unescape password in URL") + } + } + + s = s[c+1:] + } + + return s, nil +} + +func extractDatabase(s string, info *urlInfo) (string, error) { + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] + } + + return s, nil +} + // Close finish the session. func (lc *lifeCycle) Close() error { if lc.client != nil { diff --git a/persistent/internal/driver/mongo/life_cycle_test.go b/persistent/internal/driver/mongo/life_cycle_test.go index 465713c1..195ba374 100644 --- a/persistent/internal/driver/mongo/life_cycle_test.go +++ b/persistent/internal/driver/mongo/life_cycle_test.go @@ -1,6 +1,3 @@ -//go:build mongo -// +build mongo - package mongo import ( @@ -160,7 +157,7 @@ func TestConnect(t *testing.T) { UseSSL: false, Type: "mongodb", }, - want: errors.New("invalid connection string"), + want: errors.New("invalid connection string, no prefix found"), }, { name: "valid connection_string and invalid tls config", @@ -185,6 +182,112 @@ func TestConnect(t *testing.T) { } } +func TestParseURL(t *testing.T) { + tests := []struct { + name string + url string + want string + wantErr bool + }{ + { + name: "valid connection_string with special characters", + url: "mongodb://lt_tyk:6}3cZQU.9KvM/hVR4qkm-hHqZTu3yg=G@localhost:27017/tyk_analytics", + want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics", + }, + { + name: "already encoded valid url", + url: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics", + want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics", + }, + { + name: "invalid connection_string", + url: "invalid_conn_string", + want: "", + wantErr: true, + }, + { + name: "valid connection string with @", + url: "mongodb://user:p@ssword@localhost:27017", + want: "mongodb://user:p@ssword@localhost:27017/", + }, + { + name: "valid connection string with @ and /", + url: "mongodb://u=s@r:p@sswor/d@localhost:27017/test", + want: "mongodb://u%3Ds@r:p@sswor/d@localhost:27017/test", + }, + { + name: "valid connection string with @ and / and '?' outside of the credentials part", + url: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin", + want: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin", + }, + { + name: "special characters and multiple hosts", + url: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin", + want: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin", + }, + { + name: "url without credentials", + url: "mongodb://localhost:27017/test?authSource=admin", + want: "mongodb://localhost:27017/test?authSource=admin", + }, + { + name: "invalid connection string", + url: "test", + want: "", + wantErr: true, + }, + { + name: "srv connection string", + url: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority", + want: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority", + }, + { + name: "srv connection string with special characters", + url: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority", + want: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority", + }, + { + name: "connection string without username", + url: "mongodb://:password@localhost:27017/test", + want: "", + wantErr: true, + }, + { + name: "connection string without password", + url: "mongodb://user:@localhost:27017/test", + want: "mongodb://user@localhost:27017/test", + }, + { + name: "connection string without host", + url: "mongodb://user:password@/test", + want: "mongodb://user:password@/test", + }, + { + name: "connection string without database", + url: "mongodb://user:password@localhost:27017", + want: "mongodb://user:password@localhost:27017/", + }, + { + name: "cosmosdb url", + url: "mongodb+srv://4-0-qa:zFAQ==@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000", + want: "mongodb+srv://4-0-qa:zFAQ%3D%3D@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000", + }, + { + name: "cosmosdb url without database with options", + url: "mongodb+srv://tyk:6}3c.9KvM/hVR4qkm-hu3yg=G@clu0.zl.mongodb.net/?retryWrites=true&w=majority", + want: "mongodb+srv://tyk:6%7D3c.9KvM%2FhVR4qkm-hu3yg%3DG@clu0.zl.mongodb.net/?retryWrites=true&w=majority", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + parsedURL, _, err := parseURL(test.url) + assert.Equal(t, test.want, parsedURL) + assert.Equal(t, test.wantErr, err != nil) + }) + } +} + func TestClose(t *testing.T) { lc := &lifeCycle{} opts := &types.ClientOpts{ @@ -216,3 +319,27 @@ func TestDBType(t *testing.T) { dbType := lc.DBType() assert.Equal(t, utils.StandardMongo, dbType) } + +func TestIsOptSep(t *testing.T) { + tests := []struct { + input rune + want bool + }{ + {';', true}, + {'&', true}, + {':', false}, + {'a', false}, + {'1', false}, + {' ', false}, + {'\t', false}, + {'\n', false}, + {'!', false}, + } + + for _, test := range tests { + got := isOptSep(test.input) + if got != test.want { + t.Errorf("isOptSep(%q) = %v, want %v", test.input, got, test.want) + } + } +} diff --git a/persistent/internal/driver/mongo/mongo_test.go b/persistent/internal/driver/mongo/mongo_test.go index 1c448a90..ba473e79 100644 --- a/persistent/internal/driver/mongo/mongo_test.go +++ b/persistent/internal/driver/mongo/mongo_test.go @@ -98,7 +98,7 @@ func TestNewMongoDriver(t *testing.T) { }) assert.NotNil(t, err) - assert.Equal(t, "invalid connection string", err.Error()) + assert.Equal(t, "invalid connection string, no prefix found", err.Error()) assert.Nil(t, newDriver) }) t.Run("new driver without connection string", func(t *testing.T) {