Skip to content

Commit 92c7358

Browse files
authored
Merge pull request #1 from better/derek/add-host-variable
Add host option to pass down to driver
2 parents 70d9550 + ad91ae2 commit 92c7358

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

pkg/provider/provider.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ func Provider() *schema.Provider {
131131
Required: true,
132132
DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_REGION", "us-west-2"),
133133
},
134+
"host": {
135+
Type: schema.TypeString,
136+
Description: "Supports passing in a custom host value to the snowflake go driver for use with privatelink",
137+
Optional: true,
138+
DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_HOST", nil),
139+
},
134140
},
135141
ResourcesMap: getResources(),
136142
DataSourcesMap: getDataSources(),
@@ -259,6 +265,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) {
259265
oauthClientSecret := s.Get("oauth_client_secret").(string)
260266
oauthEndpoint := s.Get("oauth_endpoint").(string)
261267
oauthRedirectURL := s.Get("oauth_redirect_url").(string)
268+
host := s.Get("host").(string)
262269

263270
if oauthRefreshToken != "" {
264271
accessToken, err := GetOauthAccessToken(oauthEndpoint, oauthClientID, oauthClientSecret, GetOauthData(oauthRefreshToken, oauthRedirectURL))
@@ -279,6 +286,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) {
279286
oauthAccessToken,
280287
region,
281288
role,
289+
host,
282290
)
283291
if err != nil {
284292
return nil, errors.Wrap(err, "could not build dsn for snowflake connection")
@@ -302,7 +310,8 @@ func DSN(
302310
privateKeyPassphrase,
303311
oauthAccessToken,
304312
region,
305-
role string) (string, error) {
313+
role,
314+
host string) (string, error) {
306315

307316
// us-west-2 is their default region, but if you actually specify that it won't trigger their default code
308317
// https://github.com/snowflakedb/gosnowflake/blob/52137ce8c32eaf93b0bd22fc5c7297beff339812/dsn.go#L61
@@ -317,6 +326,12 @@ func DSN(
317326
Role: role,
318327
}
319328

329+
// If host is set trust it and do not use the region value
330+
if host != "" {
331+
config.Region = ""
332+
config.Host = host
333+
}
334+
320335
if privateKeyPath != "" {
321336
privateKeyBytes, err := ReadPrivateKeyFile(privateKeyPath)
322337
if err != nil {

pkg/provider/provider_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,27 @@ func TestDSN(t *testing.T) {
3333
password string
3434
browserAuth bool
3535
region,
36-
role string
36+
role,
37+
host string
3738
}
3839
tests := []struct {
3940
name string
4041
args args
4142
want string
4243
wantErr bool
4344
}{
44-
{"simple", args{"acct", "user", "pass", false, "region", "role"},
45+
{"simple", args{"acct", "user", "pass", false, "region", "role", ""},
4546
"user:pass@acct.region.snowflakecomputing.com:443?ocspFailOpen=true&region=region&role=role&validateDefaultParameters=true", false},
46-
{"us-west-2 special case", args{"acct2", "user2", "pass2", false, "us-west-2", "role2"},
47+
{"us-west-2 special case", args{"acct2", "user2", "pass2", false, "us-west-2", "role2", ""},
4748
"user2:pass2@acct2.snowflakecomputing.com:443?ocspFailOpen=true&role=role2&validateDefaultParameters=true", false},
49+
{"customhostwregion", args{"acct3", "user3", "pass3", false, "", "role3", "zha123.us-east-1.privatelink.snowflakecomputing.com"},
50+
"user3:pass3@zha123.us-east-1.privatelink.snowflakecomputing.com:443?account=acct3&ocspFailOpen=true&role=role3&validateDefaultParameters=true", false},
51+
{"customhostignoreregion", args{"acct4", "user4", "pass4", false, "fakeregion", "role4", "zha1234.us-east-1.privatelink.snowflakecomputing.com"},
52+
"user4:pass4@zha1234.us-east-1.privatelink.snowflakecomputing.com:443?account=acct4&ocspFailOpen=true&role=role4&validateDefaultParameters=true", false},
4853
}
4954
for _, tt := range tests {
5055
t.Run(tt.name, func(t *testing.T) {
51-
got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role)
56+
got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role, tt.args.host)
5257
if (err != nil) != tt.wantErr {
5358
t.Errorf("DSN() error = %v, wantErr %v", err, tt.wantErr)
5459
return
@@ -84,7 +89,7 @@ func TestOAuthDSN(t *testing.T) {
8489
}
8590
for _, tt := range tests {
8691
t.Run(tt.name, func(t *testing.T) {
87-
got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role)
92+
got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role, "")
8893

8994
if (err != nil) != tt.wantErr {
9095
t.Errorf("DSN() error = %v, dsn = %v, wantErr %v", err, got, tt.wantErr)

0 commit comments

Comments
 (0)