diff --git a/aws_config_test.go b/aws_config_test.go index 26718eb8..77348f89 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -3075,6 +3075,322 @@ web_identity_token_file = no-such-file } } +func TestStsEndpoint(t *testing.T) { + type settype int + const ( + setNone settype = iota + setValid + setInvalid + ) + testcases := map[string]struct { + Config Config + SetServiceEndpoint settype + SetEnv string + SetInvalidEnv string + // Use string at index 1 for valid endpoint url and index 2 for invalid endpoint url + ConfigFile string + ExpectedCredentials aws.Credentials + }{ + "service config": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetServiceEndpoint: setValid, + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service config overrides service envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetServiceEndpoint: setValid, + SetInvalidEnv: "AWS_ENDPOINT_URL_STS", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service config overrides base envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetServiceEndpoint: setValid, + SetInvalidEnv: "AWS_ENDPOINT_URL", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service config overrides service config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +services = sts-test + +[services sts-test] +sts = + endpoint_url = %[2]s +`, + SetServiceEndpoint: setValid, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service config overrides base config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %[2]s +`, + SetServiceEndpoint: setValid, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "base envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service envvar overrides base envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + SetInvalidEnv: "AWS_ENDPOINT_URL", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +services = sts-test + +[services sts-test] +sts = + endpoint_url = %[1]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service config_file overrides base config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +services = sts-test +endpoint_url = %[2]s + +[services sts-test] +sts = + endpoint_url = %[1]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service envvar overrides service config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +services = sts-test + +[services sts-test] +sts = + endpoint_url = %[2]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "base envvar overrides service config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL", + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +services = sts-test + +[services sts-test] +sts = + endpoint_url = %[2]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "base config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %[1]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "base envvar overrides base config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL", + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %[2]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service envvar overrides base config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %[2]s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + + t.Run(name, func(t *testing.T) { + servicemocks.InitSessionTestEnv(t) + + ctx := context.Background() + + ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityValidEndpoint, + }) + defer ts.Close() + stsEndpoint := ts.URL + + invalidTS := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityInvalidEndpointAccessDenied, + }) + defer invalidTS.Close() + stsInvalidEndpoint := invalidTS.URL + + if testcase.SetServiceEndpoint == setValid { + testcase.Config.StsEndpoint = stsEndpoint + } + if testcase.SetEnv != "" { + t.Setenv(testcase.SetEnv, stsEndpoint) + } + if testcase.SetInvalidEnv != "" { + t.Setenv(testcase.SetInvalidEnv, stsInvalidEndpoint) + } + if testcase.ConfigFile != "" { + tempDir := t.TempDir() + filename := writeSharedConfigFile(t, &testcase.Config, tempDir, fmt.Sprintf(testcase.ConfigFile, stsEndpoint, stsInvalidEndpoint)) + testcase.ExpectedCredentials.Source = sharedConfigCredentialsSource(filename) + } + + ctx, awsConfig, diags := GetAwsConfig(ctx, &testcase.Config) + + if diff := cmp.Diff(diags, diag.Diagnostics{}); diff != "" { + t.Errorf("Unexpected response (+wanted, -got): %s", diff) + } + if diags.HasError() { + return + } + + credentialsValue, err := awsConfig.Credentials.Retrieve(ctx) + if err != nil { + t.Fatalf("unexpected credentials Retrieve() error: %s", err) + } + + if diff := cmp.Diff(credentialsValue, testcase.ExpectedCredentials, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { + t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) + } + }) + } +} + var _ configtesting.TestDriver = &testDriver{} type testDriver struct { @@ -4006,3 +4322,21 @@ func configureHcLogger(name string, output io.Writer) hclog.Logger { return logger } + +func writeSharedConfigFile(t *testing.T, config *Config, tempDir, content string) string { + t.Helper() + + file, err := os.Create(filepath.Join(tempDir, "aws-sdk-go-base-shared-configuration-file")) + if err != nil { + t.Fatalf("creating shared configuration file: %s", err) + } + + _, err = file.WriteString(content) + if err != nil { + t.Fatalf(" writing shared configuration file: %s", err) + } + + config.SharedConfigFiles = append(config.SharedConfigFiles, file.Name()) + + return file.Name() +}