Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TT-9464] New encoding URL function #80

Merged
merged 13 commits into from
Aug 7, 2023
166 changes: 161 additions & 5 deletions persistent/internal/driver/mongo/life_cycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package mongo
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"

"github.com/TykTechnologies/storage/persistent/internal/helper"
Expand All @@ -12,7 +15,6 @@ import (
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"

"github.com/TykTechnologies/storage/persistent/internal/types"
)
Expand All @@ -26,17 +28,23 @@ type lifeCycle struct {

var _ types.StorageLifecycle = &lifeCycle{}

const (
MongoPrefix = "mongodb://"
MongoSRVPrefix = "mongodb+srv://"
)

// Connect connects to the mongo database given the ClientOpts.
func (lc *lifeCycle) Connect(opts *types.ClientOpts) error {
var err error
var client *mongo.Client

// we check if the connection string is valid before building the connOpts.
cs, err := connstring.ParseAndValidate(opts.ConnectionString)
url, cs, err := parseURL(opts.ConnectionString)
if err != nil {
return errors.New("invalid connection string")
return err
}

opts.ConnectionString = url

connOpts, err := mongoOptsBuilder(opts)
if err != nil {
return errors.New(err.Error())
Expand All @@ -50,12 +58,160 @@ func (lc *lifeCycle) Connect(opts *types.ClientOpts) error {
}

lc.connectionString = opts.ConnectionString
lc.database = cs.Database
lc.database = cs.db
lc.client = client

return lc.client.Ping(context.Background(), nil)
}

type urlInfo struct {
addrs []string
user string
pass string
db string
options []urlOptions
}

// urlOptions is a key/value pair representing a single option in a URL.
// we need to use this struct instead of a map to avoid flaky tests due to the order of the options
type urlOptions struct {
key string
val string
}

func isOptSep(c rune) bool {
return c == ';' || c == '&'
}

func parseURL(s string) (string, *urlInfo, error) {
var info *urlInfo
prefix := ""

if strings.HasPrefix(s, MongoPrefix) {
prefix = MongoPrefix
} else if strings.HasPrefix(s, MongoSRVPrefix) {
prefix = MongoSRVPrefix
}

switch prefix {
case MongoPrefix:
s = strings.TrimPrefix(s, MongoPrefix)
case MongoSRVPrefix:
s = strings.TrimPrefix(s, MongoSRVPrefix)
default:
return "", info, errors.New("invalid connection string, no prefix found")
}

info, err := extractURL(s)
if err != nil {
return "", info, err
}

var connString string
connString += prefix

if info.user != "" {
info.user = url.QueryEscape(info.user)
connString += info.user

if info.pass != "" {
info.pass = url.QueryEscape(info.pass)
connString += ":" + info.pass
}

connString += "@"
}

connString += strings.Join(info.addrs, ",")

connString += "/" + info.db

if len(info.options) > 0 {
connString += "?"
for _, v := range info.options {
connString += v.key + "=" + v.val + "&"
}

connString = connString[:len(connString)-1]
}

return connString, info, nil
}

func extractURL(s string) (*urlInfo, error) {
info := &urlInfo{options: make([]urlOptions, 0)}
var err error

if s, err = extractOptions(s, info); err != nil {
return nil, err
}

if s, err = extractCredentials(s, info); err != nil {
return nil, err
}

if s, err = extractDatabase(s, info); err != nil {
return nil, err
}

info.addrs = strings.Split(s, ",")

return info, nil
}

func extractOptions(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "?"); c != -1 {
for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) {
l := strings.SplitN(pair, "=", 2)
if len(l) != 2 || l[0] == "" || l[1] == "" {
return s, errors.New("connection option must be key=value: " + pair)
}

info.options = append(info.options, urlOptions{key: l[0], val: l[1]})
}

s = s[:c]
}

return s, nil
}

func extractCredentials(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "@"); c != -1 {
pair := strings.SplitN(s[:c], ":", 2)
if len(pair) > 2 || pair[0] == "" {
return s, errors.New("credentials must be provided as user:pass@host")
}

var err error

info.user, err = url.QueryUnescape(pair[0])
if err != nil {
return s, fmt.Errorf("cannot unescape username in URL: %q", pair[0])
}

if len(pair) > 1 {
info.pass, err = url.QueryUnescape(pair[1])
if err != nil {
return s, fmt.Errorf("cannot unescape password in URL")
}
}

s = s[c+1:]
}

return s, nil
}

func extractDatabase(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "/"); c != -1 {
info.db = s[c+1:]
s = s[:c]
}

return s, nil
}

// Close finish the session.
func (lc *lifeCycle) Close() error {
if lc.client != nil {
Expand Down
135 changes: 131 additions & 4 deletions persistent/internal/driver/mongo/life_cycle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
//go:build mongo
// +build mongo

package mongo

import (
Expand Down Expand Up @@ -160,7 +157,7 @@ func TestConnect(t *testing.T) {
UseSSL: false,
Type: "mongodb",
},
want: errors.New("invalid connection string"),
want: errors.New("invalid connection string, no prefix found"),
},
{
name: "valid connection_string and invalid tls config",
Expand All @@ -185,6 +182,112 @@ func TestConnect(t *testing.T) {
}
}

func TestParseURL(t *testing.T) {
tests := []struct {
name string
url string
want string
wantErr bool
}{
{
name: "valid connection_string with special characters",
url: "mongodb://lt_tyk:6}3cZQU.9KvM/hVR4qkm-hHqZTu3yg=G@localhost:27017/tyk_analytics",
want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
},
{
name: "already encoded valid url",
url: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
},
{
name: "invalid connection_string",
url: "invalid_conn_string",
want: "",
wantErr: true,
},
{
name: "valid connection string with @",
url: "mongodb://user:p@ssword@localhost:27017",
want: "mongodb://user:p@ssword@localhost:27017/",
},
{
name: "valid connection string with @ and /",
url: "mongodb://u=s@r:p@sswor/d@localhost:27017/test",
want: "mongodb://u%3Ds@r:p@sswor/d@localhost:27017/test",
},
{
name: "valid connection string with @ and / and '?' outside of the credentials part",
url: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin",
want: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin",
},
{
name: "special characters and multiple hosts",
url: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin",
want: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin",
},
{
name: "url without credentials",
url: "mongodb://localhost:27017/test?authSource=admin",
want: "mongodb://localhost:27017/test?authSource=admin",
},
{
name: "invalid connection string",
url: "test",
want: "",
wantErr: true,
},
{
name: "srv connection string",
url: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority",
want: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority",
},
{
name: "srv connection string with special characters",
url: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority",
want: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority",
},
{
name: "connection string without username",
url: "mongodb://:password@localhost:27017/test",
want: "",
wantErr: true,
},
{
name: "connection string without password",
url: "mongodb://user:@localhost:27017/test",
want: "mongodb://user@localhost:27017/test",
},
{
name: "connection string without host",
url: "mongodb://user:password@/test",
want: "mongodb://user:password@/test",
},
{
name: "connection string without database",
url: "mongodb://user:password@localhost:27017",
want: "mongodb://user:password@localhost:27017/",
},
{
name: "cosmosdb url",
url: "mongodb+srv://4-0-qa:zFAQ==@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000",
want: "mongodb+srv://4-0-qa:zFAQ%3D%3D@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000",
},
{
name: "cosmosdb url without database with options",
url: "mongodb+srv://tyk:6}3c.9KvM/hVR4qkm-hu3yg=G@clu0.zl.mongodb.net/?retryWrites=true&w=majority",
want: "mongodb+srv://tyk:6%7D3c.9KvM%2FhVR4qkm-hu3yg%3DG@clu0.zl.mongodb.net/?retryWrites=true&w=majority",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parsedURL, _, err := parseURL(test.url)
assert.Equal(t, test.want, parsedURL)
assert.Equal(t, test.wantErr, err != nil)
})
}
}

func TestClose(t *testing.T) {
lc := &lifeCycle{}
opts := &types.ClientOpts{
Expand Down Expand Up @@ -216,3 +319,27 @@ func TestDBType(t *testing.T) {
dbType := lc.DBType()
assert.Equal(t, utils.StandardMongo, dbType)
}

func TestIsOptSep(t *testing.T) {
tests := []struct {
input rune
want bool
}{
{';', true},
{'&', true},
{':', false},
{'a', false},
{'1', false},
{' ', false},
{'\t', false},
{'\n', false},
{'!', false},
}

for _, test := range tests {
got := isOptSep(test.input)
if got != test.want {
t.Errorf("isOptSep(%q) = %v, want %v", test.input, got, test.want)
}
}
}
2 changes: 1 addition & 1 deletion persistent/internal/driver/mongo/mongo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestNewMongoDriver(t *testing.T) {
})

assert.NotNil(t, err)
assert.Equal(t, "invalid connection string", err.Error())
assert.Equal(t, "invalid connection string, no prefix found", err.Error())
assert.Nil(t, newDriver)
})
t.Run("new driver without connection string", func(t *testing.T) {
Expand Down
Loading